from NuRadioReco.modules.base.module import register_run
from NuRadioReco.utilities import units, fft
# For typing
import NuRadioReco.framework.event
import NuRadioReco.framework.station
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from scipy.signal import lfilter
import numpy as np
import sys
import time
import logging
logger = logging.getLogger("NuRadioReco.modules.channelSinewaveSubtraction")
"""
This module provides a class for continuous wave (CW) noise filtering using sine subtraction.
This is in contrast to the module channelCWNotchFilter, which uses a notch filter to remove CW noise.
"""
[docs]
class channelSinewaveSubtraction:
    """
    Continuous wave (CW) filter module.

    For every channel of a station, sinusoids are fitted to the trace
    (scipy ``curve_fit``, through ``sinewave_subtraction``) and subtracted from it.
    """

    def __init__(self):
        # Placeholders only: begin() is called right away and assigns the defaults.
        self.freq_band = None
        self.save_filtered_freqs = None
        self.begin()

    def begin(self, save_filtered_freqs: bool = False, freq_band: tuple[float, float] = (0.1, 0.7)) -> None:
        """
        Configure the CW filter module.

        Parameters
        ----------
        save_filtered_freqs: bool (default: False)
            If True, the noise frequencies identified for each processed channel are
            collected in a list (see `get_filtered_frequencies`). If False, nothing is stored.
        freq_band: tuple (default: (0.1, 0.7))
            Frequency band used to calculate the baseline RMS of the fft spectrum, which
            serves as the reference for identifying noise peaks.
            0.1 to 0.7 GHz is the default for RNO-G, based on bandpass.
        """
        if save_filtered_freqs:
            # One entry gets appended per processed channel.
            self.save_filtered_freqs = []
        else:
            self.save_filtered_freqs = None
        self.freq_band = freq_band

    @register_run()
    def run(self, event: NuRadioReco.framework.event.Event, station: NuRadioReco.framework.station.Station,
            det=None, peak_prominence: float = 4.0) -> None:
        """
        Apply the CW filter to all channels of a station.

        Removes all the CW peaks > peak_prominence * RMS. The trace of every channel
        is replaced by the output of ``sinewave_subtraction``.

        Parameters
        ----------
        event: `NuRadioReco.framework.event.Event`
            Event object to process. Not accessed by this method.
        station: `NuRadioReco.framework.station.Station`
            Station whose channels are filtered in place.
        det: `NuRadioReco.detector.detector.Detector` (default: None)
            Detector object. Not accessed by this method.
        peak_prominence: float (default: 4.0)
            Threshold for identifying prominent peaks in the FFT spectrum.
        """
        for ch in station.iter_channels():
            rate = ch.get_sampling_rate()
            cleaned = sinewave_subtraction(
                ch.get_trace(),
                peak_prominence,
                sampling_rate=rate,
                saved_noise_freqs=self.save_filtered_freqs,
                freq_band=self.freq_band,
            )
            ch.set_trace(cleaned, rate)

    def get_filtered_frequencies(self):
        """
        Return the noise frequencies collected so far.

        Returns
        -------
        list or None
            One entry per processed channel, or None if saving was not enabled in `begin`.
        """
        return self.save_filtered_freqs
[docs]
def guess_amplitude(wf: np.ndarray, target_freq: float, sampling_rate: float = 3.2):
    """
    Estimate the amplitude of a specific harmonic in the waveform.

    The amplitude is read from the bin of the one-sided FFT that lies closest to
    ``target_freq``.

    Parameters
    ----------
    wf: np.ndarray
        Input waveform (1D array).
    target_freq: float
        Target frequency (GHz) for which to estimate amplitude.
    sampling_rate: float (default: 3.2)
        Sampling rate of the waveform (GHz).

    Returns
    -------
    amplitude: float
        Estimated amplitude of the target frequency.

    Raises
    ------
    ValueError
        If the waveform is empty or ``target_freq`` lies outside [0, sampling_rate / 2].
    """
    if wf.size == 0:
        raise ValueError("Input waveform is empty.")
    if target_freq < 0 or target_freq > sampling_rate / 2:
        raise ValueError("Target frequency is out of range (0 to Nyquist frequency).")

    n_samples = len(wf)
    # Frequency axis of the one-sided spectrum, taken from the same numpy FFT
    # convention as the spectrum below.
    frequencies = np.fft.rfftfreq(n_samples, d=1.0 / sampling_rate)

    # Here we intentionally use a different FFT normalization which retains the amplitude
    # in the time domain: a sinusoid of amplitude A sitting on a bin yields A (for bins other
    # than DC and Nyquist). The second positional argument of np.fft.rfft is the number of
    # points to transform, so the sampling rate must not be passed to it.
    amplitude_spectrum = np.abs(np.fft.rfft(wf)) * 2 / n_samples

    bin_index = np.argmin(np.abs(frequencies - target_freq))
    return amplitude_spectrum[bin_index]
[docs]
def guess_amplitude_iir(wf: np.ndarray, target_freq: float, sampling_rate: float = 3.2):
    """
    Estimate the amplitude at a single frequency with the Goertzel algorithm.

    The Goertzel recursion is evaluated as a second-order IIR filter (``scipy.signal.lfilter``).
    The target frequency is snapped to the nearest DFT bin of the waveform.

    Parameters
    ----------
    wf: np.ndarray
        Input waveform (1D array).
    target_freq: float
        Target frequency (GHz) to analyze.
    sampling_rate: float (default: 3.2)
        Sampling rate of the waveform (GHz).

    Returns
    -------
    amplitude: float
        Estimated amplitude at the target frequency.

    Raises
    ------
    ValueError
        If the waveform contains NaNs.
    """
    if np.isnan(wf).any():
        raise ValueError("Input signal contains NaNs!")

    n_samples = len(wf)
    # Nearest DFT bin and the corresponding angular frequency per sample
    nearest_bin = int(0.5 + (n_samples * target_freq / sampling_rate))
    omega = (2.0 * np.pi * nearest_bin) / n_samples
    cos_w = np.cos(omega)
    sin_w = np.sin(omega)

    # Goertzel's difference equation s[n] = x[n] + 2 cos(w) s[n-1] - s[n-2],
    # written as numerator / denominator coefficients of an IIR filter.
    numerator = [1.0, 0, 0.0]
    denominator = [1.0, -2.0 * cos_w, 1.0]
    state = lfilter(numerator, denominator, wf)

    # Only the final two filter states enter the estimate
    last = state[-1]
    second_last = state[-2]

    # Real and imaginary part of the signal component at the target bin
    real = last - second_last * cos_w
    imag = second_last * sin_w

    # Dividing the magnitude by N / 2 converts it into a time-domain amplitude
    return np.sqrt(real**2 + imag**2) / (n_samples / 2.0)
[docs]
def guess_phase(fft_spec: np.ndarray, freqs: np.ndarray, target_freq: float):
    """
    Estimate the phase of a specific frequency in the FFT spectrum.

    Parameters
    ----------
    fft_spec: np.ndarray
        Complex FFT spectrum of the waveform.
    freqs: np.ndarray
        Frequency array corresponding to the FFT spectrum.
    target_freq: float
        Target frequency (GHz) for which to estimate phase.

    Returns
    -------
    phase: float
        Phase (in radians, as returned by ``np.angle``) of the spectrum bin
        closest to the target frequency.
    """
    # Pick the bin whose frequency is closest to the requested one
    distance = np.abs(freqs - target_freq)
    nearest_bin = np.argmin(distance)
    return np.angle(fft_spec[nearest_bin])
[docs]
def sinewave_subtraction(wf: np.ndarray, peak_prominence: float = 4.0, sampling_rate: float = 3.2,
                         saved_noise_freqs: list = None, freq_band: tuple = (0.1, 0.7)):
    """
    Perform sine subtraction on a waveform to remove CW noise.

    Bins of the amplitude spectrum exceeding ``peak_prominence`` times the RMS of the
    spectrum inside ``freq_band`` are treated as CW noise. Neighboring bins are merged
    into one noise frequency (their mean). For each noise frequency a sinusoid is fitted
    to the (mean-subtracted) input waveform and subtracted, unless the fit fails or the
    subtraction would raise the total power above that of the input waveform.

    Parameters
    ----------
    wf: np.ndarray
        Input waveform (1D array).
    peak_prominence: float (default: 4.0)
        Threshold for identifying prominent peaks in the FFT spectrum.
    sampling_rate: float (default: 3.2)
        Sampling rate of the waveform (GHz).
    saved_noise_freqs: list (default: None)
        A list to store identified noise frequencies for each channel. If given, exactly
        one entry is appended per call: an array of noise frequencies, or an empty list
        if no peak was found. If None, nothing is stored.
    freq_band: tuple (default for RNO-G: (0.1, 0.7))
        Frequency band to calculate baseline RMS of fft spectrum. Used to identify noise peaks.

    Returns
    -------
    np.ndarray
        Corrected waveform with CW noise removed. The mean of the input waveform is
        subtracted as well.
    """
    dt = 1 / sampling_rate  # in ns
    # Build the time axis from integer sample indices. A float-step
    # np.arange(0, len(wf) * dt, dt) may return len(wf) + 1 samples because of
    # rounding, which would make the fit below fail with a shape mismatch.
    t = np.arange(len(wf)) * dt  # in ns

    # zero meaning, just in case
    wf = wf - np.mean(wf)

    def sinusoid(t, amplitude, noise_frequency, phase):
        # The pi / 2 offset turns the sine into a cosine, so that the phase of an
        # FFT bin (np.angle) can be used directly as initial guess.
        return amplitude * np.sin(2 * np.pi * noise_frequency * t + phase + np.pi / 2)

    spec_complex = fft.time2freq(wf, sampling_rate)  # need later to estimate phase
    spec = abs(spec_complex)
    freqs = fft.freqs(len(wf), sampling_rate)

    # total power of the original waveform, reference for the power check below
    power_orig = np.sum(spec ** 2)

    # RMS of the spectrum within the band defined by the bandpass
    f_min, f_max = freq_band
    band_mask = (freqs >= f_min) & (freqs <= f_max)
    rms_band = np.sqrt(np.mean(spec[band_mask] ** 2))

    # Noise peaks are all bins (also outside the band) above the band-limited threshold
    peak_idxs = np.where(spec > peak_prominence * rms_band)[0]

    corrected_waveform = wf.copy()

    if len(peak_idxs) == 0:
        # Nothing to subtract. Only record the (empty) result when the caller asked
        # for bookkeeping; saved_noise_freqs defaults to None.
        if saved_noise_freqs is not None:
            saved_noise_freqs.append([])
        return corrected_waveform

    # Split the sorted peak indices wherever two consecutive indices are not neighbors
    # and represent each group of neighboring bins by its mean frequency.
    group_starts = np.flatnonzero(np.diff(peak_idxs) != 1) + 1
    noise_freqs = np.array([np.mean(freqs[group]) for group in np.split(peak_idxs, group_starts)])

    if saved_noise_freqs is not None:
        saved_noise_freqs.append(noise_freqs)

    for noise_freq in noise_freqs:
        ampl_guess = guess_amplitude_iir(wf, noise_freq, sampling_rate)
        phase = guess_phase(spec_complex, freqs, noise_freq)
        initial_guess = [ampl_guess, noise_freq, phase]

        # Fit the sinusoidal model to the (mean-subtracted) input waveform
        try:
            params, covariance = curve_fit(sinusoid, t, wf, p0=initial_guess)

            # Check if any parameters are NaN or Inf
            if np.any(np.isnan(params)) or np.any(np.isinf(params)):
                raise RuntimeError("Fit returned invalid parameters.")

            # Check if the covariance matrix is invalid
            if np.all(np.isinf(covariance)) or np.all(np.isnan(covariance)):
                raise RuntimeError("Fit covariance matrix is invalid, fit may not have converged.")
        except RuntimeError:
            logger.error(f"Curve fitting failed for frequency: {noise_freq / units.MHz} MHz")
            continue

        estimated_amplitude, estimated_freq, estimated_phase = params

        # Generate the estimated CW noise
        estimated_cw_noise = sinusoid(t, estimated_amplitude, estimated_freq, estimated_phase)
        logger.info(f"Subtract sinewave with a frequency: {estimated_freq / units.MHz:.1f} MHz, "
                    f"an amplitude: {estimated_amplitude:.1e} V/GHz and a phase: {estimated_phase / units.deg:.1f} deg")

        # Subtract into a candidate first, so that a rejected subtraction leaves the
        # accepted waveform bit-for-bit untouched.
        candidate = corrected_waveform - estimated_cw_noise
        power_after_subtraction = np.sum(abs(fft.time2freq(candidate, sampling_rate)) ** 2)
        logger.info(f"Power reduction: {100 * (1 - power_after_subtraction / power_orig):.1f}%")

        # The reference is the power of the original waveform, not the power before
        # this particular subtraction.
        if power_orig < power_after_subtraction:
            logger.warning("Power increased after subtraction. Skipping this frequency.")
            continue

        corrected_waveform = candidate

    return corrected_waveform
[docs]
def plot_ft(channel, ax, label=None, plot_kwargs=None):
    """
    Function to plot real frequency spectrum of given channel.

    The magnitude of the spectrum is drawn on a logarithmic y-axis whose lower limit
    is set to 1/100 of the mean magnitude.

    Parameters
    ----------
    channel: `NuRadioReco.framework.channel.Channel`
        Channel from which to get the frequencies and the frequency spectrum
    ax: matplotlib.axes
        ax on which to plot
    label: string
        plotlabel
    plot_kwargs: dict (default: None)
        options for plotting, forwarded to ``ax.plot``. None means no extra options.
    """
    # None instead of a mutable default: the default dict would be shared between calls
    if plot_kwargs is None:
        plot_kwargs = {}

    freqs = channel.get_frequencies()
    magnitude = np.abs(channel.get_frequency_spectrum())

    ax.plot(freqs / units.MHz, magnitude, label=label, **plot_kwargs)
    ax.set_xlabel("freq / MHz")
    ax.set_ylabel("amplitude / V/GHz")
    ax.set_yscale("log")
    ax.set_ylim(np.mean(magnitude) / 100, None)
    ax.legend(loc=2)
[docs]
def plot_trace(channel, ax, label=None, plot_kwargs=None):
    """
    Function to plot trace of given channel.

    Parameters
    ----------
    channel: `NuRadioReco.framework.channel.Channel`
        Channel from which to get the times and the trace
    ax: matplotlib.axes
        ax on which to plot
    label: string
        plotlabel
    plot_kwargs: dict (default: None)
        options for plotting, forwarded to ``ax.plot``. None means no extra options.
    """
    # None instead of a mutable default: the default dict would be shared between calls
    if plot_kwargs is None:
        plot_kwargs = {}

    times = channel.get_times()
    trace = channel.get_trace()

    ax.plot(times, trace, label=label, **plot_kwargs)
    ax.set_xlabel("time / ns")
    ax.set_ylabel("trace / V")
    ax.legend(loc=2)
if __name__ == "__main__":
    # Manual test: read one RNO-G run, apply the sinewave subtraction to a single event
    # and save a before/after comparison (trace and spectrum) of one channel as a figure.
    from NuRadioReco.modules.io.RNO_G.readRNOGDataMattak import readRNOGData
    import argparse
    import os

    cli = argparse.ArgumentParser(prog="%(prog)s", usage="cw filter test")
    cli.add_argument("--station", type=int, default=13)
    cli.add_argument("--channel", type=int, default=0)
    cli.add_argument("--run", type=int, default=104)
    args = cli.parse_args()

    # The data location is taken from the environment (used deep CR burn sample..)
    data_dir = os.environ["RNO_G_DATA"]

    rnog_reader = readRNOGData(log_level=logging.INFO)
    run_directory = f"{data_dir}/station{args.station}/run{args.run}"
    rnog_reader.begin(
        run_directory,
        # linear voltage calibration
        convert_to_voltage=False,
        mattak_kwargs=dict(backend="uproot"),
    )

    sub = channelSinewaveSubtraction()
    sub.begin(save_filtered_freqs=True)

    target_event_id = 66
    logger.setLevel(logging.DEBUG)

    def _draw_channel(station, axes, tag, line_width):
        # Draw trace (left) and spectrum (right) of the selected channel.
        plot_trace(station.get_channel(args.channel), axes[0], label=tag, plot_kwargs={"lw": line_width})
        plot_ft(station.get_channel(args.channel), axes[1], label=tag, plot_kwargs={"lw": line_width})

    for event in rnog_reader.run():
        if event.get_id() != target_event_id:
            continue

        first_station_id = event.get_station_ids()[0]
        station = event.get_station(first_station_id)

        fig, axs = plt.subplots(1, 2, figsize=(14, 6))
        _draw_channel(station, axs, "before", 2)
        sub.run(event, station, det=0)
        _draw_channel(station, axs, "after", 1)

        # save plot into the directory of this file
        current_dir = os.path.dirname(os.path.abspath(__file__))
        fig.savefig(f"{current_dir}/test_cw_filter", bbox_inches="tight")