Source code for NuRadioReco.modules.channelCWNotchFilter

import logging
logger = logging.getLogger("NuRadioReco.channelCWNotchFilter")
import time
import numpy as np
from scipy import signal
from NuRadioReco.utilities import units
from NuRadioReco.utilities import fft

"""
Contains a module that filters continuous-wave (CW) signals out of a trace
by applying notch filters at the peaks in its frequency spectrum
"""


def find_frequency_peaks_from_trace(trace: np.ndarray, fs: float, threshold: float = 4):
    """
    Function to find the frequency peaks in the real fourier transform of the input trace.

    Parameters
    ----------
    trace : np.ndarray
        Waveform
    fs : float
        Sampling frequency (input should be taken from the channel object)
    threshold : float, default = 4
        Threshold for peak definition. A peak is defined as a point in the
        frequency spectrum that exceeds threshold * rms(real fourier transform)

    Returns
    -------
    freq_peaks : np.ndarray
        Frequencies at which a peak was found
    """
    freq = np.fft.rfftfreq(len(trace), d=1 / fs)
    ft = fft.time2freq(trace, fs)
    # find_frequency_peaks has no sampling-rate parameter (passing fs= raised a
    # TypeError); the sampling rate already entered through `freq` above.
    freq_peaks = find_frequency_peaks(freq, ft, threshold=threshold)
    return freq_peaks
def find_frequency_peaks(freq: np.ndarray, spectrum : np.ndarray, threshold : float = 4):
    """
    Pick out the frequencies whose spectral amplitude exceeds a multiple of the spectrum rms.

    Parameters
    ----------
    freq : np.ndarray
        Frequencies belonging to the entries of `spectrum`
    spectrum : np.ndarray
        Frequency spectrum of a time trace (only its absolute value is used)
    threshold : float, default = 4
        Threshold for peak definition. A peak is any point whose amplitude
        is strictly larger than threshold * rms(|spectrum|)

    Returns
    -------
    freq : np.ndarray
        Frequencies at which a peak was found (empty if there is none)
    """
    amplitude = np.abs(spectrum)
    # root-mean-square of the amplitude spectrum sets the reference level
    reference_level = np.sqrt(np.mean(amplitude**2))
    above_cut = amplitude > threshold * reference_level
    selected = np.nonzero(above_cut)[0]
    return freq[selected]
def get_filter(freq: float, fs, quality_factor=1e3, cache=None):
    """
    Function to get single notch filter for a given frequency.

    Parameters
    ----------
    freq : float
        Frequency to remove
    fs : float
        Sampling frequency, in the same units as `freq`
    quality_factor : float, default = 1000
        Quality factor of the notch filter, defined as the ratio f0/bw, where f0 is the
        centre frequency and bw the bandwidth of the filter at (f0, -3 dB)
    cache : dict, default = None
        Optional caching dictionary. Filters are stored under the key
        ``(freq, fs, quality_factor)``, so a cached entry is only reused when all
        three design parameters match.

    Returns
    -------
    filter : tuple of two np.ndarray
        Numerator (b) and denominator (a) coefficients of the second order IIR
        notch filter at frequency `freq`, as returned by scipy.signal.iirnotch
    """
    # Key on every design parameter: keying on the frequency alone would hand back
    # coefficients computed for another sampling rate or quality factor.
    cache_key = (freq, fs, quality_factor)

    if cache is not None:
        try:
            return cache[cache_key]
        except KeyError:
            pass

    notch_filter = signal.iirnotch(freq, quality_factor, fs=fs)

    # Cap the number of cached entries so the dictionary cannot grow without bound
    if cache is not None and len(cache) < 1e5:
        cache[cache_key] = notch_filter

    return notch_filter
def filter_cws(trace: np.ndarray, freq: np.ndarray, spectrum: np.ndarray, fs: float,
               quality_factor=1e3, threshold=4, cache: dict = None, filters: list = None):
    """
    Function that applies a notch filter at the frequency peaks of a given time trace
    using the scipy library

    Parameters
    ----------
    trace : np.ndarray
        Waveform
    freq : np.ndarray
        Frequencies of the trace's real fourier transform
    spectrum : np.ndarray
        The trace's real fourier transform
    fs : float
        Sampling frequency, in the same units as `freq`
    quality_factor : float, default = 1000
        Quality factor of the notch filter, defined as the ratio f0/bw, where f0 is the
        centre frequency and bw the bandwidth of the filter at (f0, -3 dB)
    threshold : float, default = 4
        Threshold for peak definition. A peak is defined as a point in the frequency
        spectrum that exceeds threshold * rms(real fourier transform)
    cache : dict, default = None
        Optional caching dictionary, handed on to get_filter
    filters : NoneType or list, default = None
        Optional list to which the filters used in this function are appended for
        future reference. Exactly one entry is appended per call: the array of notch
        filters, or an empty list when no peak was found.

    Returns
    -------
    trace : np.ndarray
        CW-filtered trace (the input trace itself when no peak was found)
    """
    peak_freqs = find_frequency_peaks(freq, spectrum, threshold=threshold)

    if not len(peak_freqs):
        # Still append a placeholder so `filters` keeps one entry per call
        # (i.e. one per channel when looping over channels)
        if filters is not None:
            filters.append([])
        return trace

    # The array is reshaped to (nr_of_filters, nr_of_coefficients). Since iirnotch is a
    # second order IIR, nr_of_coefficients is 6: 3 for the numerator followed by 3 for
    # the denominator, which is the second-order-sections layout sosfiltfilt expects.
    notch_filters = np.array(
        [get_filter(peak_freq, fs, quality_factor, cache=cache) for peak_freq in peak_freqs]
    ).reshape(-1, 6)

    if filters is not None:
        filters.append(notch_filters)

    # Use the module logger rather than the root logger so the message follows this
    # module's logging configuration
    logger.debug("Shape of notch filters for one channel is: %s", notch_filters.shape)

    trace_notched = signal.sosfiltfilt(notch_filters, trace, padtype=None)
    return trace_notched
def plot_trace(channel, ax, fs=3.2e9*units.Hz, label=None, plot_kwargs=None):
    """
    Function to plot trace of given channel

    Parameters
    ----------
    channel : NuRadio channel class
        Channel from which to get trace
    ax : matplotlib.axes
        Axes on which to plot
    fs : float, default = 3.2e9 * units.Hz
        Sampling frequency used to build the time axis
    label : string
        Plot label
    plot_kwargs : dict, default = None
        Extra keyword arguments forwarded to ``ax.plot``; None means no extra options
    """
    if plot_kwargs is None:
        plot_kwargs = {}

    trace = channel.get_trace()
    # Build the time axis from the actual number of samples instead of assuming a
    # fixed 2048-sample trace, so traces of any length can be plotted
    times = np.arange(len(trace)) / fs / units.ns

    legendloc = 2
    ax.plot(times, trace, label=label, **plot_kwargs)
    ax.set_xlabel("time / ns")
    ax.set_ylabel("trace / V")
    ax.legend(loc=legendloc)
def plot_ft(channel, ax, label=None, plot_kwargs=None):
    """
    Function to plot the absolute value of the frequency spectrum of given channel

    Parameters
    ----------
    channel : NuRadio channel class
        Channel from which to get the frequencies and frequency spectrum
    ax : matplotlib.axes
        Axes on which to plot
    label : string
        Plot label
    plot_kwargs : dict, default = None
        Extra keyword arguments forwarded to ``ax.plot``; None means no extra options
    """
    # Use a None sentinel instead of a dict created once at definition time
    if plot_kwargs is None:
        plot_kwargs = {}

    freqs = channel.get_frequencies()
    spec = channel.get_frequency_spectrum()

    legendloc = 2
    ax.plot(freqs, np.abs(spec), label=label, **plot_kwargs)
    ax.set_xlabel("freq / GHz")
    ax.set_ylabel("amplitude / V/GHz")
    ax.legend(loc=legendloc)
class channelCWNotchFilter():
    """
    Continuous wave (CW) filter module. Uses notch filters from the scipy library
    """

    def __init__(self):
        # Apply the default settings right away so that run() also works when the
        # user never calls begin() explicitly
        self.begin()

    def begin(self, quality_factor=1e3, threshold=4, save_filters=False):
        """
        Configure the module.

        Parameters
        ----------
        quality_factor : float, default = 1000
            Quality factor handed to the notch filters
        threshold : float, default = 4
            Threshold for the peak definition in the frequency spectrum
        save_filters : bool, default = False
            If True, the filters applied in run() are collected in ``self.filters``
            (one entry per processed channel); otherwise ``self.filters`` is None
        """
        self.quality_factor = quality_factor
        self.threshold = threshold
        self.filters = [] if save_filters else None
        # dictionary to cache known notch filters at specific frequencies
        self.filter_cache = {}

    def run(self, event, station, det):
        """
        Notch-filter the trace of every channel in the station and store the result
        back on the channel.

        Parameters
        ----------
        event : NuRadio event class
            Not used by this module
        station : NuRadio station class
            Station whose channels are filtered
        det : detector
            Not used by this module
        """
        for channel in station.iter_channels():
            fs = channel.get_sampling_rate()
            freq = channel.get_frequencies()
            spectrum = channel.get_frequency_spectrum()
            trace = channel.get_trace()

            trace_fil = filter_cws(
                trace, freq, spectrum, fs,
                quality_factor=self.quality_factor,
                threshold=self.threshold,
                cache=self.filter_cache,
                filters=self.filters)

            channel.set_trace(trace_fil, fs)
# Standard test for people playing around with module settings, applies the module as one would in a data reading pipeline
# using one event in RNO_G_DATA (choose station and run) as a test
if __name__ == "__main__":
    import os
    import logging
    import argparse
    import matplotlib.pyplot as plt
    from NuRadioReco.modules.io.RNO_G.readRNOGDataMattak import readRNOGData

    parser = argparse.ArgumentParser(prog="%(prog)s", usage="cw filter test")
    parser.add_argument("--station", type=int, default=24)
    parser.add_argument("--channel", type=int, default=0)
    parser.add_argument("--run", type=int, default=1)
    # float so that non-integer settings can be given on the command line
    # (the defaults used by the module are floats as well)
    parser.add_argument("--quality_factor", type=float, default=1e3)
    parser.add_argument("--threshold", type=float, default=4)
    parser.add_argument("--fs", type=float, default=3.2e9 * units.Hz)
    parser.add_argument("--save_dir", type=str, default=None,
                        help="Directory where to save plot produced by the test. "
                             "If None, saves to NuRadioReco test directory")
    args = parser.parse_args()

    # The data location is taken from the RNO_G_DATA environment variable
    data_dir = os.environ["RNO_G_DATA"]
    rnog_reader = readRNOGData(log_level=logging.DEBUG)

    root_dirs = f"{data_dir}/station{args.station}/run{args.run}"
    rnog_reader.begin(root_dirs,
                      # linear voltage calibration
                      convert_to_voltage=True,
                      mattak_kwargs=dict(backend="uproot"))

    # Use a separate name for the instance so the class itself is not overwritten
    cw_filter = channelCWNotchFilter()
    cw_filter.begin(quality_factor=args.quality_factor, threshold=args.threshold)

    for event in rnog_reader.run():
        station_id = event.get_station_ids()[0]
        station = event.get_station(station_id)

        fig, axs = plt.subplots(1, 2, figsize=(14, 6))
        # --fs is forwarded so the command line option actually sets the time axis
        plot_trace(station.get_channel(args.channel), axs[0], fs=args.fs, label="before")
        plot_ft(station.get_channel(args.channel), axs[1], label="before")

        t0 = time.time()
        cw_filter.run(event, station, det=0)
        logger.debug(f"Filter took {time.time() - t0} s to run.")

        plot_trace(station.get_channel(args.channel), axs[0], fs=args.fs, label="after")
        plot_ft(station.get_channel(args.channel), axs[1], label="after")

        if args.save_dir is None:
            fig_dir = os.path.abspath(f"{__file__}/../../test")
        else:
            fig_dir = args.save_dir

        fig.savefig(f"{fig_dir}/test_cw_filter", bbox_inches="tight")
        # only the first event is used for this test
        break