Source code for NuRadioReco.modules.RNO_G.stationHitFilter

"""
Hit Filter Module (RNO-G specific, can be upgraded to more general use in the future)
The main purpose is to reject thermal events in data.
"""

from NuRadioReco.modules.base.module import register_run
from NuRadioReco.utilities import trace_utilities, units

from collections import defaultdict
import numpy as np
import logging
import math
import copy
import time


[docs] class stationHitFilter: def __init__(self, complete_time_check=False, complete_hit_check=False, time_window=10.0*units.ns, threshold_multiplier=6.5, select_trigger=None): """ Passes event through "hit filter". Looks for temporal coincidence in multiple channel pairs. Currently this module is designed specifically for the deep components of an RNO-G station and uses only the deep in-ice channels. The module checks for temporal coincidences in multiple channel pairs (uses only antennas at ~100 depth; i.e., not the shallow v-pols). Adjacent channels are put into groups: G0: (0,1,2,3); G1: (9,10); G2: (23,22); G3: (8,4). To determine the timing and amplitude of a "hit" in every channel, the Hilbert Transform is applied to the waveform and the maximum is found. The time of the maximum is then used to check for coincidences in each group. This "time checker" requires at least 1 coincident pair in G0 (PA), and another coincident pair in any other group (including G0 - with an additional requirement that channel pairs need to be connected). When the time check fails, by default, a "hit checker" will see if there's any in-ice channel that has a high hit (maximum > threshold_multiplier * noise_RMS) and whenever there's a high hit the event passes the module. Parameters ---------- complete_time_check: bool (default: False) If False, time checker will stop early whenever criteria are satisfied. If True, time checker will run through all channel groups no matter what. complete_hit_check: bool (default: False) If False, only run the hit checker when the time checker fails. If True, always run the hit checker. time_window: float (default: 10.0*units.ns) Coincidence window for two adjacent channels. threshold_multiplier: float (default: 6.5) High hit threshold multiplier, where a hit threshold is the multiplier times the noise RMS. select_trigger: str (default: None) Select a specific trigger type to run this filter on. If None, all triggers will be evaluated. If you select a specific trigger, events with other triggers will be treated as if they have passed the module (but not counted). """ self._complete_time_check = complete_time_check self._complete_hit_check = complete_hit_check self._dT = time_window self._threshold_multiplier = threshold_multiplier self._select_trigger = select_trigger self._in_ice_channels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 21, 22, 23] self._in_ice_channel_groups = ([0, 1, 2, 3], [9, 10], [23, 22], [8, 4]) self._channel_pairs_in_PA = ([0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]) self._n_channels_in_ice = len(self._in_ice_channels) self._envelope_max_time_index = None self._envelope_max_time = None self._noise_RMS = None self._times = None self._traces = None self._envelope_traces = None self._hit_thresholds = None self._is_over_hit_threshold = None self._passed_hit_filter = None self._passed_time_checker = None self._passed_hit_checker = None ### 2-D list ### # This list contains sub-lists for all groups. # Each sub-list contains bool(s) indicating the COINCIDENCE of channel pair(s) in each group # self._is_in_time_window[0] is the PA, it has 6 pairs (6 bools) # Initiating each element with None, it will later be replaced with bool when running the time checker self._is_in_time_window_template = [[None for _ in range(math.factorial(len(group)-1))] for group in self._in_ice_channel_groups] def _channel_mapping(self, channel_in): """ Channel Mappings for specifically channels on Helper String 2 The Hit Filter works only on all 15 in-ice channels. This function maps channels 21, 22, 23 to indices 12, 13, 14 (Skip surface channels 12 ~ 20). Parameters ---------- channel_in: int Input channel number Returns ------- channel_out: int Output channel number """ channel_out = channel_in - 9 if channel_in >= 21 else channel_in return channel_out def _time_checker(self): """ Check the hit times between channels in groups to select coincident pairs. See if there are at least 2 coincident pairs in time sequence in Group 0 (PA), and if there's only 1 pair in PA, then find the other pair in other groups; otherwise, the input event fails the time checker. Returns ------- self._passed_time_checker: bool Event passed the time checker or not """ self._passed_time_checker = False envelope_max_time = self.get_envelope_max_time() coincidences_in_time_sequence_PA = [False, False, False] self._is_in_time_window = copy.deepcopy(self._is_in_time_window_template) for i_group, group in enumerate(self._in_ice_channel_groups): # Group 0 is special because we require at least one coincident pair in this group if i_group == 0: # For e.g, channel 3 and 0 take 3 times the dT. dT_multipliers = np.diff(self._channel_pairs_in_PA).flatten() for i_channel_pair, channel_pair in enumerate(self._channel_pairs_in_PA): hit_time_difference = abs(envelope_max_time[channel_pair[1]] - envelope_max_time[channel_pair[0]]) self._is_in_time_window[i_group][i_channel_pair] = hit_time_difference <= dT_multipliers[i_channel_pair] * self._dT # This condition assures that the conicident pairs are in time sequence. For example: # Having a conicdence only in pairs (0, 1) and (0, 3) is not valid. But having (0, 1) and (1, 3) is valid. if self._is_in_time_window[i_group][i_channel_pair]: coincidences_in_time_sequence_PA[channel_pair[0]] = self._is_in_time_window[i_group][i_channel_pair] if np.sum(coincidences_in_time_sequence_PA) >= 2: self._passed_time_checker = True else: # If the time checker already passed, we can skip the rest of the groups when not completing the time checking if self._passed_time_checker and not self._complete_time_check: break # If there are no coincidences in PA, we can skip the rest of the groups if not np.any(coincidences_in_time_sequence_PA) and not self._complete_time_check: break if len(group) != 2: raise NotImplementedError("For any channel group other than Group 0 (PA), only 2 channels are supported for now.") hit_time_difference = abs(envelope_max_time[self._channel_mapping(group[0])] - envelope_max_time[self._channel_mapping(group[1])]) self._is_in_time_window[i_group][0] = hit_time_difference <= self._dT if self._is_in_time_window[i_group][0] and np.sum(coincidences_in_time_sequence_PA) >= 1: self._passed_time_checker = True if not self._complete_time_check: break return self._passed_time_checker def _hit_checker(self): """ Find a high hit within all 15 in-ice channels. See if there's at least 1 high hit within channels, if yes then the event passes the hit checker. Returns ------- self._passed_hit_checker: bool Event passed the hit checker or not """ self._is_over_hit_threshold = np.amax(self.get_envelope_traces(), axis=-1) > self.get_hit_thresholds() self._passed_hit_checker = np.any(self._is_over_hit_threshold) return self._passed_hit_checker
[docs] def begin(self, log_level=logging.INFO): """ Set up logging info. Parameters ---------- log_level: enum Set verbosity level of logger (default: logging.INFO) """ self.logger = logging.getLogger('NuRadioReco.RNO_G.stationHitFitter') self.logger.setLevel(log_level) self.__counting_dict = defaultdict(int) self.__is_wanted_trigger_type = None self.__total_run_time = 0
[docs] def set_up(self, set_of_traces, set_of_times, noise_RMS): """ Set things up before passing to the Hit Filter. This setup function calculates the noise RMS, sets the high hit threshold for each channel, gets the Hilbert envelope, and finds the time when the maximum happens (hit). IMPORTANT: If you use this function by yourself, make sure your inputs come from all 15 in-ice channels in order only, rather than all 24 channels. Parameters ---------- set_of_traces: 2-D array of floats A set of input trace arrays of all 15 in-ice channels set_of_times: 2-D array of floats A set of input times arrays of all 15 in-ice channels noise_RMS: 1-D array of floats A set of input noise RMS values of all 15 in-ice channels """ self._passed_hit_checker = None if len(set_of_traces) != self._n_channels_in_ice or len(set_of_times) != self._n_channels_in_ice: raise NotImplementedError("Make sure your inputs come from all in-ice channels in order for Hit Filter to perform.") self._traces = set_of_traces self._times = set_of_times self._envelope_traces = trace_utilities.get_hilbert_envelope(self._traces) self._envelope_max_time_index = np.argmax(self._envelope_traces, axis=-1) self._envelope_max_time = self._times[range(len(self._times)), self._envelope_max_time_index] if noise_RMS is not None: if len(noise_RMS) != self._n_channels_in_ice: raise NotImplementedError("Make sure your input RMS values come from all in-ice channels in order for Hit Filter to perform.") self._noise_RMS = noise_RMS else: self._noise_RMS = np.array([trace_utilities.get_split_trace_noise_RMS(trace) for trace in self._traces]) self._hit_thresholds = self._noise_RMS * self.get_threshold_multiplier()
[docs] def apply_hit_filter(self): """ See if the input event will survive the Hit Filter or not. After the setup, it first checks with the time checker, if event passed the time checker then it passes the Hit Filter; if event failed the time checker, then checks with the hit checker to find a high hit. Returns ------- self._passed_hit_filter: bool Event passed the Hit Filter or not """ self._passed_hit_filter = self._time_checker() # If time checker failed or if you want to complete the hit checking regardless if not self._passed_hit_filter or self._complete_hit_check: self._hit_checker() # Only update _passed_hit_filter with the hit checker when the time checker did not pass if not self.is_passed_time_checker(): self._passed_hit_filter = self.is_passed_hit_checker() return self._passed_hit_filter
[docs] @register_run() def run(self, evt, station, det=None, noise_RMS_all=None): """ Runs the Hit Filter. Parameters ---------- evt: `NuRadioReco.framework.event.Event` Using the event object to get the trigger type station: `NuRadioReco.framework.station.Station` The station to use the Hit Filter on det: Detector object | None Detector object (not used in this method, included to have the same signature as other NuRadio classes) noise_RMS_all: 1-D numpy array (default: None) Noise RMS values of all 24 channels, if not given the Hit Filter will calculate them for the 15 in-ice channels Returns ------- self.is_passed_hit_filter(): bool Event passed the Hit Filter or not """ t0 = time.time() trigger_type = evt.get_station().get_first_trigger().get_name() # Only run the module on selected trigger type if self._select_trigger is not None and trigger_type != self._select_trigger: self._passed_hit_filter = True return True # This implicitly obeys the channel mapping traces = np.array([np.array(channel.get_trace()) for channel in station.iter_channels() if channel.get_id() in self._in_ice_channels]) times = np.array([np.array(channel.get_times()) for channel in station.iter_channels() if channel.get_id() in self._in_ice_channels]) if noise_RMS_all is not None: noise_RMS = noise_RMS_all[self._in_ice_channels] # HACK: use channel IDs to index noise_RMS else: noise_RMS = noise_RMS_all self.set_up(traces, times, noise_RMS) self.apply_hit_filter() self.__is_wanted_trigger_type = trigger_type == "LT" self.__counting_dict[trigger_type] += 1 self.__counting_dict[f"{trigger_type}_passed"] += int(self.is_passed_hit_filter()) self.__counting_dict["total"] += 1 self.__total_run_time += time.time() - t0 return self.is_passed_hit_filter()
[docs] def end(self): event_count = self.__counting_dict.pop("total") counts = (f"Processed Total: {event_count} events. Total run time: {self.__total_run_time:.2f} s, " f"Time per event: {self.__total_run_time / event_count * 1000:.2f} ms") trigger_types = np.unique([key.strip("_passed") for key in self.__counting_dict.keys()]) for trigger_type in trigger_types: counts += (f"\n{trigger_type:>10} triggers: {self.__counting_dict[f'{trigger_type}_passed']:4d} / {self.__counting_dict[trigger_type]:4d} events " f"({self.__counting_dict[f'{trigger_type}_passed'] / self.__counting_dict[trigger_type] * 100:.2f} %)") self.logger.info(counts)
##################### ###### Getters ###### #####################
[docs] def get_threshold_multiplier(self): """ Returns ------- self._threshold_multiplier: float High hit threshold multiplier (default: 6.5) """ return self._threshold_multiplier
[docs] def get_in_ice_channels(self): """ Returns ------- self._in_ice_channels: 1-D list of ints In-ice channel IDs in a list: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 21, 22, 23] """ return self._in_ice_channels
[docs] def get_in_ice_channel_groups(self): """ Returns ------- self._in_ice_channel_groups: Several 1-D lists of ints in a tuple In-ice channel groups in a tuple: ([0, 1, 2, 3], [9, 10], [23, 22], [8, 4]) """ return self._in_ice_channel_groups
[docs] def get_channel_pairs_in_PA(self): """ Returns ------- self._channel_pairs_in_PA: Several 1-D lists of ints in a tuple Channel pairs of the PA in a tuple: ([0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]) """ return self._channel_pairs_in_PA
[docs] def get_envelope_max_time_index(self): """ Returns ------- self._envelope_max_time_index: 1-D numpy array of ints Hit time index of each channel """ return self._envelope_max_time_index
[docs] def get_envelope_max_time(self): """ Returns ------- self._envelope_max_time: 1-D numpy array of floats Hit time of each channel """ return self._envelope_max_time
[docs] def get_in_ice_channels_noise_RMS(self): """ Returns ------- self._noise_RMS: 1-D numpy array of floats Noise RMS values for channels """ return self._noise_RMS
[docs] def get_times(self): """ Returns ------- self._times: 2-D numpy array of floats Arrays of times for channels """ return self._times
[docs] def get_traces(self): """ Returns ------- self._traces: 2-D numpy array of floats Arrays of traces for channels """ return self._traces
[docs] def get_envelope_traces(self): """ Returns ------- self._envelope_traces: 2-D numpy array of floats Arrays of envelope traces for channels """ return self._envelope_traces
[docs] def get_hit_thresholds(self): """ Returns ------- self._hit_thresholds: 1-D numpy array of floats Hit threshold of each channel """ return self._hit_thresholds
[docs] def is_passed_hit_filter(self): """ Returns ------- self._passed_hit_filter: bool See if event passed the Hit Filter """ return self._passed_hit_filter
[docs] def is_passed_time_checker(self): """ Returns ------- self._passed_time_checker: bool See if event passed the time checker """ return self._passed_time_checker
[docs] def is_passed_hit_checker(self): """ Returns ------- self._passed_hit_checker: bool See if event passed the hit checker """ if self._passed_hit_checker is not None: return self._passed_hit_checker else: raise NotImplementedError("Cannot call is_passed_hit_checker() when hit checking wasn't performed.")
[docs] def is_over_hit_threshold(self): """ Returns ------- self._is_over_hit_threshold: 1-D numpy array of bools See if there's a high hit in each channel """ if self._complete_hit_check: return self._is_over_hit_threshold else: raise NotImplementedError("Cannot call is_over_hit_threshold() when complete_hit_check is False.")
[docs] def is_in_time_window(self): """ Returns ------- self._is_in_time_window: 2-D list of bools See if a channel pair is coincident in the group """ if self._complete_time_check: return self._is_in_time_window else: raise NotImplementedError("Cannot call is_in_time_window() when complete_time_check is False.")
[docs] def is_wanted_trigger_type(self): """ Returns ------- self.__is_wanted_trigger_type: bool When we want to select only events with low threshold triggers. """ if self.__is_wanted_trigger_type is not None: return self.__is_wanted_trigger_type else: raise NotImplementedError("Cannot call is_wanted_trigger_type() when self.__is_wanted_trigger_type is None.")
[docs] def is_in_time_window_PA(self): """ Returns ------- dict: dictionary of bools See if channel pairs are coincident or not in Group 0 (PA) In the dictionary, there are 6 pairs: (0,1), (0,2), (0,3), (1,2), (1,3), (2,3) To see if a pair is coincident one can do, for example: is_in_time_window_PA[(0,1)], then it will be True or False """ dict = {} for i_pair, pair in enumerate(self._channel_pairs_in_PA): i = pair[0] j = pair[1] dict[(i,j)] = self._is_in_time_window[0][i_pair] return dict