# Source code for NuRadioReco.framework.base_trace

from __future__ import absolute_import, division, print_function
import numpy as np
import logging
import fractions
import decimal
import numbers
import functools
from NuRadioReco.utilities import fft, bandpass_filter
import NuRadioReco.detector.response
import scipy.signal
import copy
try:
    import cPickle as pickle
except ImportError:
    import pickle
logger = logging.getLogger("NuRadioReco.BaseTrace")


class BaseTrace:
    """
    Container for a sampled signal together with its sampling rate and start time.

    The signal is stored either as a time trace or as a frequency spectrum, never both
    at once: whichever representation was requested or set last is kept, and the other
    one is computed on demand with `NuRadioReco.utilities.fft`.
    """

    def __init__(self, trace=None, sampling_rate=None, trace_start_time=0):
        """
        Initialize the BaseTrace object.

        Parameters
        ----------
        trace : np.array of floats (default: None)
            The time trace. Can also be set later with the `set_trace` method.
        sampling_rate : float (default: None)
            The sampling rate of the trace, i.e., the inverse of the bin width.
            Only used if `trace` is given.
        trace_start_time : float (default: 0)
            The start time of the trace.
        """
        self._sampling_rate = None
        self._time_trace = None
        self._frequency_spectrum = None
        self.__time_domain_up_to_date = True
        self._trace_start_time = trace_start_time

        if trace is not None:
            self.set_trace(trace, sampling_rate)

    def _resolve_sampling_rate(self, sampling_rate, method_name):
        """
        Validate a `sampling_rate` argument and return the sampling rate to store.

        Parameters
        ----------
        sampling_rate : float or str
            New sampling rate, or the string "same" (case-insensitive) to keep the current one.
        method_name : str
            Name of the calling setter; only used in the error message.

        Returns
        -------
        sampling_rate
            The value that should be stored as the sampling rate.

        Raises
        ------
        ValueError
            If `sampling_rate` is None, or if "same" is requested but no sampling
            rate has been set before.
        """
        if isinstance(sampling_rate, str) and sampling_rate.lower() == "same":
            if self._sampling_rate is None:
                raise ValueError(
                    "You specified to keep the sampling rate but no value have been set previously.")
            return self._sampling_rate

        if sampling_rate is None:
            raise ValueError(
                "You have to specify a sampling rate for `BaseTrace.{}(...)`".format(method_name))

        return sampling_rate

    def get_trace(self):
        """
        Returns the time trace.

        If the frequency spectrum was modified before, an ifft is performed automatically
        to have the time domain representation up to date. The stored frequency spectrum
        is discarded afterwards.

        Returns
        -------
        trace: np.array of floats
            A copy of the time trace. If no trace has been set, this is the
            0-dimensional array `np.copy(None)` rather than None.
        """
        if not self.__time_domain_up_to_date:
            self._time_trace = fft.freq2time(self._frequency_spectrum, self._sampling_rate)
            self.__time_domain_up_to_date = True
            self._frequency_spectrum = None

        return np.copy(self._time_trace)

    def get_filtered_trace(self, passband, filter_type='butter', order=10, rp=None):
        """
        Returns the trace after applying a filter to it. This does not change the stored trace.

        Parameters
        ----------
        passband: list of floats
            lower and upper bound of the filter passband
        filter_type: string
            type of the applied filter. Options are rectangular, butter and butterabs
        order: int
            Order of the Butterworth filter, if the filter types butter or butterabs are chosen
        rp: float or None (default: None)
            Passed through unchanged to `bandpass_filter.get_filter_response`
            (NOTE(review): presumably the passband ripple of a Chebyshev-type filter - confirm there).

        Returns
        -------
        trace: np.array
            The filtered signal in the time domain.
        """
        # get_frequency_spectrum() already returns a copy, so it can be modified freely
        spectrum = self.get_frequency_spectrum()
        frequencies = self.get_frequencies()
        spectrum *= bandpass_filter.get_filter_response(frequencies, passband, filter_type, order, rp)
        return fft.freq2time(spectrum, self.get_sampling_rate())

    def get_frequency_spectrum(self, window_mask=None):
        """
        Returns the frequency spectrum.

        Without a `window_mask` the internal representation is switched to the frequency
        domain (the stored time trace is discarded). With a `window_mask` the spectrum of
        the selected samples is computed on the fly.

        Parameters
        ----------
        window_mask: array of bools (default: None)
            If not None, specifies the time window to be used for the FFT. Has to have
            the same length as the trace.

        Returns
        -------
        frequency_spectrum: np.array
            A copy of the frequency spectrum.
        """
        if window_mask is None:
            if self.__time_domain_up_to_date:
                self._frequency_spectrum = fft.time2freq(self._time_trace, self._sampling_rate)
                self._time_trace = None
                self.__time_domain_up_to_date = False
            return np.copy(self._frequency_spectrum)

        # get_trace() returns a copy; the double transpose allows to work with 1D and ND traces
        windowed_trace = self.get_trace().T[window_mask].T
        return fft.time2freq(windowed_trace, self._sampling_rate)

    def set_trace(self, trace, sampling_rate):
        """
        Sets the time trace.

        All arguments are validated before the object is modified, so a rejected call
        leaves the previously stored trace and sampling rate untouched.

        Parameters
        ----------
        trace : np.array of floats
            The time series
        sampling_rate : float or str
            The sampling rate of the trace, i.e., the inverse of the bin width.
            If `sampling_rate="same"`, sampling rate is not changed (requires previous initialisation).

        Raises
        ------
        ValueError
            If the trace has an odd number of samples along its last axis or if no
            valid sampling rate is available.
        """
        if trace is not None:
            n_samples = np.shape(trace)[-1]
            if n_samples % 2 != 0:
                raise ValueError(
                    f'Attempted to set trace with an uneven number ({n_samples}) '
                    'of samples. Only traces with an even number of samples are allowed.')

        new_sampling_rate = self._resolve_sampling_rate(sampling_rate, 'set_trace')

        self.__time_domain_up_to_date = True
        self._time_trace = np.copy(trace)
        self._frequency_spectrum = None
        self._sampling_rate = new_sampling_rate

    def set_frequency_spectrum(self, frequency_spectrum, sampling_rate):
        """
        Sets the frequency spectrum.

        The sampling rate is validated before the object is modified, so a rejected call
        leaves the previously stored data untouched.

        Parameters
        ----------
        frequency_spectrum : np.array
            The frequency spectrum
        sampling_rate : float or str
            The sampling rate of the trace, i.e., the inverse of the bin width.
            If `sampling_rate="same"`, sampling rate is not changed (requires previous initialisation).

        Raises
        ------
        ValueError
            If no valid sampling rate is available.
        """
        new_sampling_rate = self._resolve_sampling_rate(sampling_rate, 'set_frequency_spectrum')

        self.__time_domain_up_to_date = False
        self._frequency_spectrum = np.copy(frequency_spectrum)
        self._time_trace = None
        self._sampling_rate = new_sampling_rate

    def get_sampling_rate(self):
        """
        Returns the sampling rate of the trace.

        Returns
        -------
        sampling_rate: float
            sampling rate, i.e., the inverse of the bin width
        """
        return self._sampling_rate

    def get_times(self):
        """
        Returns the time of every sample of the trace.

        Returns
        -------
        times: np.array of floats
            `trace_start_time + i / sampling_rate` for each sample index `i`.
            An empty array is returned if no trace has been set.
        """
        try:
            n_samples = self.get_number_of_samples()
        except AttributeError:
            # no trace / spectrum stored yet
            return np.array([])

        # Building the axis from an integer index guarantees exactly n_samples entries;
        # a float-stepped np.arange can be off by one due to rounding.
        return np.arange(n_samples) * (1. / self._sampling_rate) + self._trace_start_time

    def set_trace_start_time(self, start_time):
        """
        Sets the start time of the trace.

        Parameters
        ----------
        start_time: float
            The new start time of the trace.
        """
        self._trace_start_time = start_time

    def add_trace_start_time(self, start_time):
        """
        Adds an offset to the current start time of the trace.

        Parameters
        ----------
        start_time: float
            The time that is added to the current trace start time.
        """
        self._trace_start_time += start_time

    def get_trace_start_time(self):
        """
        Returns the start time of the trace.

        Returns
        -------
        trace_start_time: float
            The start time of the trace.
        """
        return self._trace_start_time

    def get_frequencies(self, window_mask=None):
        """
        Returns the frequencies of the frequency spectrum.

        Parameters
        ----------
        window_mask: array of bools (default: None)
            If not None, used to determine the number of samples in the time domain
            used for the frequency spectrum.

        Returns
        -------
        frequencies: np.array of floats
            The frequencies of the frequency spectrum.
        """
        if window_mask is None:
            n_samples = self.get_number_of_samples()
        else:
            n_samples = int(np.sum(window_mask))
        return get_frequencies(n_samples, self._sampling_rate)

    def get_hilbert_envelope(self):
        """
        Returns the Hilbert envelope of the time trace.

        Returns
        -------
        envelope: np.array of floats
            Absolute value of the analytic signal, same shape as the trace
            (works for a 1D (N) trace as well as a (3,N) E-field).
        """
        analytic_signal = scipy.signal.hilbert(self.get_trace())
        return np.abs(analytic_signal)

    def get_hilbert_envelope_mag(self):
        """
        Returns the magnitude of the Hilbert envelope.

        Returns
        -------
        envelope_mag: np.array of floats
            Norm of the Hilbert envelope taken over axis 0 after promoting the envelope
            to 2D, i.e. one value per sample.
        """
        # ensure taking axis 0 of a 2D trace (trace might be (N) for analytic trace or (3,N) for E-field
        return np.linalg.norm(np.atleast_2d(self.get_hilbert_envelope()), axis=0)

    def get_number_of_samples(self):
        """
        Returns the number of samples in the time domain.

        Returns
        -------
        n_samples: int
            number of samples in time domain
        """
        if self.__time_domain_up_to_date:
            # the last axis is the time axis independent of the dimension of the array
            # (channels are 1dim, efields are 3dim)
            return self._time_trace.shape[-1]
        return (self._frequency_spectrum.shape[-1] - 1) * 2

    def apply_time_shift(self, delta_t, silent=False):
        """
        Uses the fourier shift theorem to apply a time shift to the trace

        Note that this is a cyclic shift, which means the trace will wrap around,
        which might lead to problems, especially for large time shifts.

        Parameters
        ----------
        delta_t: float
            Time by which the trace should be shifted
        silent: boolean (default:False)
            Turn off warnings if time shift is larger than 10% of trace length
            Only use this option if you are sure that your trace is long enough
            to acommodate the time shift
        """
        if not silent:
            trace_duration = self.get_number_of_samples() / self.get_sampling_rate()
            # a shift wraps around in either direction, so compare its magnitude
            if abs(delta_t) > .1 * trace_duration:
                logger.warning('Trace is shifted by more than 10% of its length')

        spectrum = self.get_frequency_spectrum()
        spectrum *= np.exp(-2.j * np.pi * delta_t * self.get_frequencies())
        self.set_frequency_spectrum(spectrum, self._sampling_rate)

    def resample(self, sampling_rate):
        """
        Resamples the trace to a new sampling rate.

        The ratio of new to old sampling rate is approximated by a fraction with a
        denominator of at most 5000; the trace is first upsampled by the numerator and
        then downsampled by the denominator using `scipy.signal.resample`. If the result
        has an odd number of samples, the last sample is dropped.

        Parameters
        ----------
        sampling_rate: float
            The new sampling rate.
        """
        if sampling_rate == self.get_sampling_rate():
            return

        resampling_factor = fractions.Fraction(
            decimal.Decimal(sampling_rate / self.get_sampling_rate())).limit_denominator(5000)
        resampled_trace = self.get_trace()

        # use axis -1 since trace might be either shape (N) for analytic trace or shape (3,N) for E-field
        if resampling_factor.numerator != 1:
            resampled_trace = scipy.signal.resample(
                resampled_trace, resampling_factor.numerator * self.get_number_of_samples(), axis=-1)

        if resampling_factor.denominator != 1:
            resampled_trace = scipy.signal.resample(
                resampled_trace, np.shape(resampled_trace)[-1] // resampling_factor.denominator, axis=-1)

        # set_trace only accepts an even number of samples
        if resampled_trace.shape[-1] % 2 != 0:
            resampled_trace = resampled_trace[..., :-1]

        self.set_trace(resampled_trace, sampling_rate)

    def serialize(self):
        """
        Serializes the trace with pickle.

        Returns
        -------
        data_pkl: bytes or None
            Pickled dictionary with the keys 'sampling_rate', 'time_trace' and
            'trace_start_time', or None if no trace is set.
        """
        time_trace = self.get_trace()
        # if there is no trace, the above will return np.array(None), which has an empty shape
        if not time_trace.shape:
            return None

        data = {'sampling_rate': self.get_sampling_rate(),
                'time_trace': time_trace,
                'trace_start_time': self.get_trace_start_time()}
        return pickle.dumps(data, protocol=4)

    def deserialize(self, data_pkl):
        """
        Restores the trace from the output of `serialize`.

        Parameters
        ----------
        data_pkl: bytes
            Pickled dictionary as produced by `serialize`.
        """
        # NOTE(security): pickle.loads executes arbitrary code from the payload;
        # only call this on data from a trusted source.
        data = pickle.loads(data_pkl)
        self.set_trace(data['time_trace'], data['sampling_rate'])
        if 'trace_start_time' in data:
            self.set_trace_start_time(data['trace_start_time'])

    def add_to_trace(self, channel):
        """
        Adds the trace of another channel to the trace of this channel. The trace is
        only added within the time window of "this" channel.

        If this channel is an empty trace with a defined _sampling_rate and
        _trace_start_time, and a _time_trace containing zeros, this function can be seen
        as recording a channel in the specified readout window.

        Parameters
        ----------
        channel: BaseTrace
            The channel whose trace is to be added to the trace of this channel.
        """
        assert self.get_number_of_samples() is not None, "No trace is set for this channel"
        assert self.get_sampling_rate() == channel.get_sampling_rate(), "Sampling rates of the two channels do not match"

        tt_readout = self.get_times()
        t0_readout = self.get_trace_start_time()
        t1_readout = tt_readout[-1]
        sampling_rate_readout = self.get_sampling_rate()
        n_samples_readout = self.get_number_of_samples()

        tt_channel = channel.get_times()
        t0_channel = channel.get_trace_start_time()
        t1_channel = tt_channel[-1]
        sampling_rate_channel = channel.get_sampling_rate()
        n_samples_channel = channel.get_number_of_samples()

        # We handle 1+2x2 cases:
        # 1. Channel is completely outside readout window:
        if t1_channel < t0_readout or t1_readout < t0_channel:
            return

        # 2. Channel starts before readout window:
        if t0_channel < t0_readout:
            i_start_readout = 0
            t_start_readout = t0_readout
            i_start_channel = int((t0_readout-t0_channel) * sampling_rate_channel) + 1  # The first bin of channel inside readout
            t_start_channel = tt_channel[i_start_channel]
        # 3. Channel starts after readout window:
        elif t0_channel >= t0_readout:
            i_start_readout = int((t0_channel-t0_readout) * sampling_rate_readout)  # The bin of readout right before channel starts
            t_start_readout = tt_readout[i_start_readout]
            i_start_channel = 0
            t_start_channel = t0_channel

        # NOTE(review): t_end_readout / t_end_channel are never used below; the lookups are
        # kept because indexing tt_channel / tt_readout raises IndexError for out-of-range bins.
        # 4. Channel ends after readout window:
        if t1_channel >= t1_readout:
            i_end_readout = n_samples_readout - 1
            t_end_readout = t1_readout
            i_end_channel = int((t1_readout - t0_channel) * sampling_rate_channel) + 1  # The bin of channel right after readout ends
            t_end_channel = tt_channel[i_end_channel]
        # 5. Channel ends before readout window:
        elif t1_channel < t1_readout:
            i_end_readout = int((t1_channel - t0_readout) * sampling_rate_readout)  # The bin of readout right before channel ends
            t_end_readout = tt_readout[i_end_readout]
            i_end_channel = n_samples_channel - 1
            t_end_channel = t1_channel

        # Determine the remaining time between the binning of the two traces and use time shift as interpolation:
        residual_time_offset = t_start_channel - t_start_readout
        tmp_channel = copy.deepcopy(channel)
        tmp_channel.apply_time_shift(residual_time_offset)
        trace_to_add = tmp_channel.get_trace()[i_start_channel:i_end_channel]

        # Add the trace to the original trace:
        # NOTE(review): both slices exclude their end index and their lengths are not checked
        # against each other; a mismatch surfaces as a numpy broadcasting error here.
        original_trace = self.get_trace()
        original_trace[i_start_readout:i_end_readout] += trace_to_add

        self.set_trace(original_trace, sampling_rate_readout)

    def __add__(self, x):
        """
        Redefine the "+" operator for BaseTrace objects.

        The operation will return a new BaseTrace object containing the sum of the two
        traces. If the two traces have different sampling rates, one of them is upsampled
        to the higher sampling rate. Neither operand's trace values are changed.

        Raises
        ------
        TypeError
            If `x` is not a BaseTrace.
        ValueError
            If one of the objects has no trace set or the traces differ in dimension.
        """
        # Some sanity checks
        if not isinstance(x, BaseTrace):
            raise TypeError('+ operator is only defined for 2 BaseTrace objects')

        own_trace = self.get_trace()
        other_trace = x.get_trace()
        # get_trace() never returns None: without data it yields np.copy(None), a 0-dim array
        if not own_trace.shape or not other_trace.shape:
            raise ValueError('One of the trace objects has no trace set')
        if own_trace.ndim != other_trace.ndim:
            raise ValueError('Traces have different dimensions')

        own_rate = self.get_sampling_rate()
        other_rate = x.get_sampling_rate()
        if own_rate == other_rate:
            trace_1, trace_2, sampling_rate = own_trace, other_trace, own_rate
        elif own_rate > other_rate:
            # Upsample trace with lower sampling rate in a temporary object
            # so we don't change the originals
            upsampled_trace = BaseTrace(other_trace, other_rate)
            upsampled_trace.resample(own_rate)
            trace_1, trace_2, sampling_rate = own_trace, upsampled_trace.get_trace(), own_rate
        else:
            upsampled_trace = BaseTrace(own_trace, own_rate)
            upsampled_trace.resample(other_rate)
            trace_1, trace_2, sampling_rate = upsampled_trace.get_trace(), other_trace, other_rate

        # Figure out which of the traces has the earlier trace start time
        if self.get_trace_start_time() <= x.get_trace_start_time():
            first_trace, second_trace = trace_1, trace_2
            trace_start = self.get_trace_start_time()
        else:
            first_trace, second_trace = trace_2, trace_1
            trace_start = x.get_trace_start_time()

        # Calculate the difference in the trace start time between the traces and the number of
        # samples that time difference corresponds to
        time_offset = np.abs(x.get_trace_start_time() - self.get_trace_start_time())
        i_start = int(round(time_offset * sampling_rate))

        # The time axis is always the last one, so 1D (channel) and 2D (E-field) traces
        # can be padded the same way.
        n_first = first_trace.shape[-1]
        n_second = second_trace.shape[-1]
        # Calculate length the new trace needs to hold both input traces
        trace_length = max(n_first, i_start + n_second)
        # Make sure trace has an even number of samples
        trace_length += trace_length % 2

        # Put both pulses at the start of their own traces for now.
        # We correct for different start times later
        early_trace = np.zeros(first_trace.shape[:-1] + (trace_length,))
        early_trace[..., :n_first] = first_trace
        late_trace = np.zeros(second_trace.shape[:-1] + (trace_length,))
        late_trace[..., :n_second] = second_trace

        # Correct for different trace start times by using fourier shift theorem to
        # shift the later trace by the start time difference.
        late_trace_object = BaseTrace(late_trace, sampling_rate)
        late_trace_object.apply_time_shift(time_offset, True)

        # Create new BaseTrace object holding the summed traces
        new_trace = BaseTrace(early_trace + late_trace_object.get_trace(), sampling_rate)
        new_trace.set_trace_start_time(trace_start)
        return new_trace

    def __mul__(self, x):
        """
        Multiply the stored trace (or spectrum) with a number or a detector response.

        For a number, this object itself is scaled and returned (no new object is
        created). For a `NuRadioReco.detector.response.Response` the operation is
        delegated to that class.

        Raises
        ------
        ValueError
            If neither a time trace nor a frequency spectrum is set.
        TypeError
            If `x` is of an unsupported type.
        """
        if isinstance(x, numbers.Number):
            # Rebind instead of `*=` so that e.g. an integer trace can be scaled by a
            # float (in-place multiplication would fail on the dtype cast).
            if self._time_trace is not None:
                self._time_trace = self._time_trace * x
                return self
            if self._frequency_spectrum is not None:
                self._frequency_spectrum = self._frequency_spectrum * x
                return self
            raise ValueError('Cant multiply baseTrace with number because no value is set for trace.')
        elif isinstance(x, NuRadioReco.detector.response.Response):
            return x * self  # operation defined in detector.response.Response
        else:
            raise TypeError('Multiplication of baseTrace object with object of type {} is not defined'.format(type(x)))

    def __rmul__(self, x):
        """Right-hand multiplication; identical to `__mul__`."""
        return self.__mul__(x)

    def __truediv__(self, x):
        """
        Divide the stored trace (or spectrum) by a number.

        This object itself is scaled and returned (no new object is created).

        Raises
        ------
        ValueError
            If neither a time trace nor a frequency spectrum is set.
        TypeError
            If `x` is not a number.
        """
        if isinstance(x, numbers.Number):
            if self._time_trace is not None:
                self._time_trace = self._time_trace / x
                return self
            if self._frequency_spectrum is not None:
                self._frequency_spectrum = self._frequency_spectrum / x
                return self
            raise ValueError('Cant divide baseTrace by number because no value is set for trace.')
        else:
            raise TypeError('Division of baseTrace object with object of type {} is not defined'.format(type(x)))
@functools.lru_cache(maxsize=1024)
def _cached_frequencies(length, sampling_rate):
    """Compute (and memoize) the rfft frequency axis; the cached array is read-only."""
    frequencies = np.fft.rfftfreq(length, d=1. / sampling_rate)
    # The same object is handed out on every cache hit, so protect it against modification.
    frequencies.flags.writeable = False
    return frequencies


def get_frequencies(length, sampling_rate):
    """
    Returns the frequencies of the real FFT of a trace.

    The frequency axis is cached per `(length, sampling_rate)` pair. Each call returns
    a fresh, writable copy, so callers may modify the result without affecting the
    cache or other callers.

    Parameters
    ----------
    length: int
        Number of samples of the trace in the time domain.
    sampling_rate: float
        The sampling rate of the trace, i.e., the inverse of the bin width.

    Returns
    -------
    frequencies: np.array of floats
        Output of `np.fft.rfftfreq(length, d=1. / sampling_rate)`.
    """
    return np.copy(_cached_frequencies(length, sampling_rate))