Source code for NuRadioReco.modules.LOFAR.planeWaveDirectionFitter_LOFAR

"""
This module has been adapted from pycrtools.modules.tasks.directionfitplanewave and
NuRadioReco.modules.LOFAR.beamformingDirectionFitter_LOFAR

.. moduleauthor:: Philipp Laub <philipp.laub@fau.de>
"""

import logging
import numpy as np
import matplotlib.pyplot as plt
import radiotools.helper as hp


from NuRadioReco.utilities import units
from NuRadioReco.framework.parameters import stationParameters, channelParameters, showerParameters
from NuRadioReco.modules.base.module import register_run
from NuRadioReco.modules.LOFAR.beamforming_utilities import geometric_delay_far_field, lightspeed


def average_direction(event, detector, mode='normal'):
    """
    Calculate the average direction for an event based on the plane wave directions of the
    individual stations.

    Only triggered stations which carry a reconstructed direction (the `cr_zenith` and
    `cr_azimuth` station parameters) contribute. Zenith angles are averaged arithmetically.
    Azimuth angles are averaged as circular quantities (through the mean of their sines and
    cosines), so that e.g. 359 deg and 1 deg average to 0 deg instead of 180 deg.

    Parameters
    ----------
    event : Event object
        The event for which to calculate the average direction.
    detector : Detector object
        The detector description, used to get the number of channels per station.
    mode : str, default='normal'
        The mode to use for the calculation. Can be 'normal' (every station has the same weight)
        or 'weighted' (the number of good antennas, i.e. number of channels minus number of
        flagged channels, is used as weight per station).

    Returns
    -------
    avg_zenith : float
        The average zenith angle for the event (NaN if no station contributes).
    avg_azimuth : float
        The average azimuth angle for the event, in the [0, 360 deg) interval (NaN if no station
        contributes).

    Raises
    ------
    ValueError
        If `mode` is neither 'normal' nor 'weighted'.
    """
    if mode not in ('normal', 'weighted'):
        raise ValueError(f"Unknown mode: {mode}")

    zeniths = []
    azimuths = []
    num_good_antennas = []
    for station in event.get_stations():
        if not station.get_parameter(stationParameters.triggered):
            continue
        # planeWaveDirectionFitter.run() only sets cr_zenith/cr_azimuth when the fit result differs
        # from the LORA direction, so a triggered station may not have them
        if not (station.has_parameter(stationParameters.cr_zenith)
                and station.has_parameter(stationParameters.cr_azimuth)):
            continue

        flagged_channels = station.get_parameter(stationParameters.flagged_channels)
        num_good_antennas.append(
            detector.get_number_of_channels(station.get_id()) - len(flagged_channels)
        )
        zeniths.append(station.get_parameter(stationParameters.cr_zenith))
        azimuths.append(station.get_parameter(stationParameters.cr_azimuth))

    if len(zeniths) == 0:
        logging.getLogger("NuRadioReco.planeWaveDirectionFitter").warning(
            "No triggered station with a reconstructed direction found, cannot average the direction."
        )
        return np.nan, np.nan

    zeniths = np.asarray(zeniths, dtype=float)
    azimuths_rad = np.asarray(azimuths, dtype=float) / units.rad

    if mode == 'weighted':
        weights = np.asarray(num_good_antennas, dtype=float)
    else:
        weights = np.ones_like(zeniths)

    avg_zenith = np.sum(weights * zeniths) / np.sum(weights)

    # Circular (weighted) mean: average the azimuthal unit vectors and take the angle of the result
    avg_azimuth = np.arctan2(
        np.sum(weights * np.sin(azimuths_rad)),
        np.sum(weights * np.cos(azimuths_rad)),
    )

    return avg_zenith, np.mod(avg_azimuth * units.rad, 360 * units.deg)
class planeWaveDirectionFitter:
    """
    Fits the direction per station using timing differences of channels under the assumption of an
    incoming plane wave.
    """

    def __init__(self):
        self.logger = logging.getLogger("NuRadioReco.planeWaveDirectionFitter")

        # all configuration is set in begin()
        self.__cr_snr = None
        self.__logger_level = None
        self.__debug = None
        self.__window_size = None
        self.__ignore_non_horizontal_array = None
        self.__rmsfactor = None
        self.__min_amp = None
        self.__max_iter = None
        self.__min_number_good_antennas = None

    def begin(self, max_iter=10, cr_snr=6.5, min_amp=None, rmsfactor=2.0, force_horizontal_array=True,
              debug=False, logger_level=logging.NOTSET, min_number_good_antennas=4):
        """
        Set the parameters for the plane wave fit.

        Parameters
        ----------
        max_iter : int, default=10
            The maximum number of iterations to use during the fitting procedure.
        cr_snr : float, default=6.5
            The minimum SNR a channel should have to be considered having a cosmic ray signal.
            Ignored if min_amp is not None.
        min_amp : float, default=None
            The minimum amplitude (maximum of the absolute trace, in the units of the channel
            traces) a channel should have to be considered having a cosmic ray signal.
            Set to None if you want to use the SNR instead.
        rmsfactor : float, default=2.0
            How many sigma (times RMS) above the average can a delay deviate from the expected
            timelag (from latest fit iteration) before it is considered bad and removed as outlier.
        force_horizontal_array : bool, default=True
            Set to True when you know the array is non-horizontal (z > 0.5) but want to use the
            horizontal approximation anyway. If False, the fit raises a ValueError when any antenna
            has abs(z) > 0.5. Recommended to set to True.
        debug : bool, default=False
            Set to True to enable debug plots (saved as PNG files in the working directory).
        logger_level : int, default=logging.NOTSET
            The logging level to use for the module.
        min_number_good_antennas : int, default=4
            The minimum number of good antennas that should be present in a station to consider it
            for the fit.
        """
        self.__max_iter = max_iter
        self.__cr_snr = cr_snr
        self.__min_amp = min_amp
        self.__rmsfactor = rmsfactor
        self.__ignore_non_horizontal_array = force_horizontal_array
        self.__debug = debug
        self.__logger_level = logger_level
        self.logger.setLevel(logger_level)
        self.__min_number_good_antennas = min_number_good_antennas

    @staticmethod
    def _get_timelags(station, channel_ids_dominant_pol):
        """
        Get timing differences between signals in antennas with respect to some reference antenna
        (the first one in the list of ids).

        The signal time of each channel is read from the `signal_time` channel parameter, which is
        expected to have been set by stationPulseFinder.

        Parameters
        ----------
        station : Station object
            The station for which to get the time lags
        channel_ids_dominant_pol : list of int
            The list of channel ids to return the time lags for (usually the dominant polarisation)

        Returns
        -------
        timelags : np.ndarray
            The timelags (in internal units) for each channel in the list, with respect to the
            first one. Empty if no channel ids are given.
        """
        timelags = np.array(
            [
                station.get_channel(channel_id).get_parameter(channelParameters.signal_time)
                for channel_id in channel_ids_dominant_pol
            ],
            dtype=float,
        )
        if timelags.size == 0:
            return timelags

        # get timelags wrt 1st antenna (on an array: `list -= float` raises a TypeError)
        return timelags - timelags[0]

    @staticmethod
    def _direction_horizontal_array(positions: np.ndarray, times: np.ndarray, ignore_z_coordinate=False):
        r"""
        --- adapted from pycrtools.modules.scrfind ---

        Given N antenna positions, and (pulse) arrival times for each antenna, get a direction of
        arrival (azimuth, zenith) assuming a source at infinity (plane wave).

        Here, we find the direction assuming all antennas are placed in the z=0 plane. If all
        antennas are co-planar, the best-fitting solution can be found using a 2D-linear fit.
        We find the best-fitting A and B in:

        .. math::

            c t = A x + B y + C

        where t is the array of times; x and y are arrays of coordinates of the antennas.
        The C is the overall time offset in the data, that has to be subtracted out.
        The optimal value of C has to be determined in the fit process (it's not just the average
        time, nor the time at antenna 0). This is done using :mod:`numpy.linalg.lstsq`.

        The (azimuth, zenith) follows from:

        .. math::

            A = \sin(\mathrm{zenith}) \cos(\mathrm{azimuth})

            B = \sin(\mathrm{zenith}) \sin(\mathrm{azimuth})

        where the azimuth is computed from (-A, -B), because we want the direction of the incoming
        vector (pointing towards the sky, not away from it).

        Parameters
        ----------
        positions : np.ndarray
            Positions (x,y,z) of the antennas (shape: (N_antennas, 3))
        times : array, float
            Measured pulse arrival times for all antennas
        ignore_z_coordinate : bool, default=False
            If False, a ValueError is raised when any antenna has abs(z) > 0.5.

        Returns
        -------
        zenith : float
            Zenith in the [0, pi/2] interval (given in internal units). NaN if the fitted time
            gradient is too steep for a plane wave, i.e. A**2 + B**2 > 1.
        azimuth : float
            Azimuth in the [0, 2pi) interval (given in internal units)

        Raises
        ------
        ValueError
            If `ignore_z_coordinate` is False and any abs(z) > 0.5.
        """
        # make x, y arrays out of the input position array
        x = positions[:, 0]
        y = positions[:, 1]

        # now a crude test for nonzero z-input, |z| > 0.5
        z = positions[:, 2]
        if not ignore_z_coordinate and max(abs(z)) > 0.5:
            raise ValueError("Input values of z are nonzero ( > 0.5) !")

        M = np.vstack([x, y, np.ones(len(x))]).T  # says the linalg.lstsq doc
        A, B, C = np.linalg.lstsq(M, lightspeed * times, rcond=None)[0]

        # sqrt(A**2 + B**2) is sin(zenith) and must be <= 1. If the measured delays imply a larger
        # value (e.g. because of timing outliers), arcsin returns NaN and emits a RuntimeWarning.
        # The caller (run) handles the NaN by falling back to the LORA direction.
        zenith = np.arcsin(np.sqrt(A ** 2 + B ** 2))
        # note minus sign as we want the direction of the _incoming_ vector
        # (from the sky, not towards it)
        azimuth = np.arctan2(-B, -A)

        return np.mod(zenith * units.rad, 360 * units.deg), np.mod(azimuth * units.rad, 360 * units.deg)

    @register_run()
    def run(self, event, detector):
        """
        Run the plane wave fit for the given event and detector.

        For every triggered station, the dominant polarisation channel of every antenna whose SNR
        (or amplitude, if `min_amp` was set in `begin`) passes the threshold is used in an iterative
        plane wave fit. After each iteration, antennas whose timing residual deviates by more than
        `rmsfactor` standard deviations from the mean residual are removed. The iteration stops when
        the number of good antennas no longer changes, when fewer than `min_number_good_antennas`
        remain, or after `max_iter` iterations.

        The resulting direction is stored in the zenith/azimuth station parameters. If it differs
        from the LORA direction (which is used as fallback), it is also stored in the
        cr_zenith/cr_azimuth station parameters. Channels that were not used in the final fit are
        flagged with "planewavefit_timing_outlier".

        Parameters
        ----------
        event : Event object
            The event for which to run the plane wave fit.
        detector : Detector object
            The detector for which to run the plane wave fit.
        """
        for station in event.get_stations():
            station_id = station.get_id()
            if not station.get_parameter(stationParameters.triggered):
                self.logger.debug(f"Station CS{station_id:03d} did not trigger, skipping...")
                continue

            self.logger.debug(f"Running over station CS{station_id:03d}")

            # get LORA initial guess for the direction (also used as fallback if the fit fails)
            lora_shower = event.get_hybrid_information().get_hybrid_shower("LORA")
            lora_zenith = lora_shower.get_parameter(showerParameters.zenith)
            lora_azimuth = lora_shower.get_parameter(showerParameters.azimuth)

            # Get all group IDs which are still present in the station
            station_channel_group_ids = set(channel.get_group_id() for channel in station.iter_channels())
            n_groups = len(station_channel_group_ids)

            # Get the dominant polarisation orientation as calculated by stationPulseFinder
            dominant_orientation = station.get_parameter(stationParameters.cr_dominant_polarisation)

            # Collect the positions of 'good' antennas.
            # Column 0 of the pair ids holds the dominant polarisation channel, column 1 the other one.
            good_channel_pair_ids = np.zeros((n_groups, 2), dtype=int)
            relative_position_array = np.zeros((n_groups, 3))
            good_amp_or_snr = np.zeros(n_groups, dtype=bool)
            for ind, channel_group_id in enumerate(station_channel_group_ids):
                relative_position_array[ind] = detector.get_relative_position(station_id, channel_group_id)
                for channel in station.iter_channel_group(channel_group_id):
                    orientation = detector.get_antenna_orientation(station_id, channel.get_id())
                    if np.all(orientation == dominant_orientation):
                        good_channel_pair_ids[ind, 0] = channel.get_id()
                    else:
                        good_channel_pair_ids[ind, 1] = channel.get_id()

                # Check if dominant channel has acceptable SNR or acceptable amplitude (if desired)
                dominant_channel = station.get_channel(good_channel_pair_ids[ind, 0])
                if self.__min_amp is None:
                    good_amp_or_snr[ind] = dominant_channel.get_parameter(channelParameters.SNR) > self.__cr_snr
                else:
                    good_amp_or_snr[ind] = np.max(np.abs(dominant_channel.get_trace())) >= self.__min_amp

            num_good_antennas = int(np.sum(good_amp_or_snr))
            mask_good_antennas = np.full(num_good_antennas, True)

            # the dominant antennas are good_antennas[:, 0]
            good_antennas = good_channel_pair_ids[good_amp_or_snr]
            position_array = relative_position_array[good_amp_or_snr]

            # iteratively do the plane wave fit and remove outliers (controlled by rmsfactor)
            # until the number of good antennas remains constant
            niter = 0
            zenith, azimuth = lora_zenith, lora_azimuth
            while niter < self.__max_iter:  # TODO: maybe add additional condition?
                niter += 1

                # if only three antennas (or less) remain, fit should not be trusted as it always
                # has a solution (fails)
                if num_good_antennas < self.__min_number_good_antennas:
                    self.logger.warning(f"Only {num_good_antennas:d} good antennas remaining!")
                    self.logger.error("Too few good antennas for direction fit!")
                    break

                # update arrays to use only previously found "good" antennas:
                position_array = position_array[mask_good_antennas]
                good_antennas = good_antennas[mask_good_antennas]

                # get time lags from the dominant antennas only
                times = self._get_timelags(station, good_antennas[:, 0])

                zenith, azimuth = self._direction_horizontal_array(
                    position_array, times, self.__ignore_non_horizontal_array
                )

                # get residuals
                expected_delays = geometric_delay_far_field(
                    position_array, hp.spherical_to_cartesian(zenith / units.rad, azimuth / units.rad)
                )
                expected_delays -= expected_delays[0]  # get delays wrt 1st antenna
                residual_delays = times - expected_delays

                # Debug plots if required
                if self.__debug:
                    self.debug_plots(
                        event, expected_delays, good_antennas, niter, position_array, residual_delays,
                        station, times
                    )

                if np.isnan(zenith) or np.isnan(azimuth):
                    self.logger.error(
                        'Plane wave fit returns NaN. I will try to recover by setting zenith and azimuth '
                        'to the LORA estimate and recalculating the residual delays.'
                    )
                    zenith = lora_zenith
                    azimuth = lora_azimuth

                    expected_delays = geometric_delay_far_field(
                        position_array, hp.spherical_to_cartesian(zenith / units.rad, azimuth / units.rad)
                    )
                    expected_delays -= expected_delays[0]
                    residual_delays = times - expected_delays

                    # keep the antennas around the most populated residual bin; the bin width
                    # corresponds to the light travel time across the array in x
                    x_extent = position_array[:, 0].max() - position_array[:, 0].min()
                    if x_extent > 0:
                        bins = int(
                            (residual_delays.max() - residual_delays.min()) * lightspeed / x_extent
                        )
                    else:
                        bins = 1
                    hist, edges = np.histogram(residual_delays, bins=max(bins, 1))

                    max_time = np.argmax(hist)
                    self.logger.debug(f"histogram filled: {hist}")
                    self.logger.debug(f"edges: {edges}")
                    self.logger.debug(f"maximum at: {max_time}")

                    upper = edges[min(max_time + 2, len(edges) - 1)]
                    lower = edges[max(max_time - 1, 0)]
                    self.logger.debug(f"Selecting between lower {lower} and upper {upper}")

                    mask_good_antennas = (residual_delays > lower) & (residual_delays < upper)
                else:
                    # remove > k-sigma outliers and iterate
                    spread = np.std(residual_delays)
                    k = self.__rmsfactor
                    mask_good_antennas = abs(residual_delays - np.mean(residual_delays)) < k * spread
                    # gives subset of 'good_antennas' that is 'good' after this iteration

                self.logger.debug(f"station {station_id}:")
                self.logger.debug(f"iteration {niter:d}:")
                self.logger.debug(
                    f'azimuth = {azimuth / units.deg:.3f}, zenith = {zenith / units.deg:.3f}'
                )
                self.logger.debug(f'number of good antennas = {num_good_antennas:d}')

                # if the next iteration has the same number of good antennae the loop is terminated
                num_remaining = int(np.sum(mask_good_antennas))
                if num_remaining == num_good_antennas:
                    break
                num_good_antennas = num_remaining

            self.logger.status(
                f"Azimuth (counterclockwise wrt to East) and zenith for station CS{station_id:03d}:"
            )
            self.logger.status(f"{azimuth / units.deg}, {zenith / units.deg}")
            self.logger.status(
                f"Azimuth (clockwise wrt to North) and elevation for station CS{station_id:03d}:"
            )
            self.logger.status(f"{(90 - azimuth / units.deg) % 360}, {90 - zenith / units.deg}")

            # Set stationParameters.zenith/azimuth because voltageToEfieldConverter uses these to convert
            # NOTE: these can be the LORA direction (in case the fit failed)
            station.set_parameter(stationParameters.zenith, zenith)
            station.set_parameter(stationParameters.azimuth, azimuth)

            # Only set reconstructed direction if it is not identical to the LORA direction
            if zenith == lora_zenith and azimuth == lora_azimuth:
                self.logger.info(
                    f"The fit for station CS{station_id:03d} seems to have failed. "
                    f"I will not set the cr_zenith and cr_azimuth station parameters, but you can "
                    f"still unfold the voltages to electric fields with the LORA direction as this "
                    f"is saved in the zenith and azimuth station parameters."
                )
            else:
                station.set_parameter(stationParameters.cr_zenith, zenith)
                station.set_parameter(stationParameters.cr_azimuth, azimuth)

            # flag channels that were not used in the fit
            used_channel_ids = set(good_antennas.flatten().tolist())
            station_flagged_channels = station.get_parameter(stationParameters.flagged_channels)
            for channel_id in good_channel_pair_ids.flatten().tolist():
                if channel_id not in used_channel_ids:
                    # TODO: this flag is not always correct, as channels excluded by SNR are also flagged
                    station_flagged_channels[channel_id].append("planewavefit_timing_outlier")
            station.set_parameter(stationParameters.flagged_channels, station_flagged_channels)

    @staticmethod
    def debug_plots(event, expected_delays, good_antennas, niter, position_array, residual_delays, station,
                    times):
        """
        Create debug plots for the plane wave fit.

        Two PNG files are written to the working directory per call: a mosaic with the measured,
        expected and residual delays on the antenna layout plus the traces, and a scatter plot of
        the residuals per antenna coloured by SNR.
        """
        planeWaveDirectionFitter.__debug_mosaic(
            event, expected_delays, good_antennas, niter, position_array, residual_delays, station, times
        )
        planeWaveDirectionFitter.__debug_residuals(
            event, good_antennas, residual_delays, station, niter
        )

    @staticmethod
    def __debug_residuals(event, good_antennas, residual_delays, station, niter):
        """
        Show the residuals per antenna and mark SNR
        """
        fig, ax = plt.subplots()

        antenna_snrs = np.array([
            station.get_channel(antenna).get_parameter(channelParameters.SNR)
            for antenna in good_antennas[:, 0]
        ])

        scatter = ax.scatter(np.arange(len(residual_delays)), residual_delays, marker='o', c=antenna_snrs)
        fig.colorbar(scatter, ax=ax, label='SNR')

        ax.set_xlabel('Antenna')
        ax.set_ylabel('Residual time [ns]')
        ax.set_title(f'Residuals for station {station.get_id()}')

        fig.savefig(
            f"pipeline_planewavefit_residuals_CS{station.get_id():03d}_iteration{niter}_{event.get_id()}.png",
            dpi=250, bbox_inches='tight'
        )
        plt.close(fig)

    @staticmethod
    def __debug_mosaic(event, expected_delays, good_antennas, niter, position_array, residual_delays, station,
                       times):
        """
        Plot the timings, as well as the residuals and the traces used for the fit.
        """
        import matplotlib as mpl

        inner = [['times'], ['expected']]
        outer = [[inner, 'residuals'], ['traces', 'traces']]

        fig, axd = plt.subplot_mosaic(outer, layout="constrained", figsize=(10, 12))

        for channel in station.iter_channels(use_channels=good_antennas[:, 0]):
            axd['traces'].plot(channel.get_trace() / units.mV)

        # mark the signal window:
        channel = station.get_channel(good_antennas[0, 0])  # should be all the same
        pulse_window_start, pulse_window_end = channel.get_parameter(channelParameters.signal_regions)
        axd['traces'].axvline(pulse_window_start, color='r')
        axd['traces'].axvline(pulse_window_end, color='r')
        axd['traces'].set_xlim(pulse_window_start - 500, pulse_window_end + 500)
        axd['traces'].set_xlabel('Sample index')
        axd['traces'].set_ylabel('Amplitude [mV]')
        axd['traces'].set_title(f'Good traces used in iteration {niter}')

        # Plot the timing residuals
        datasets = [times, expected_delays]
        norm1 = mpl.colors.Normalize(vmin=np.min(datasets), vmax=np.max(datasets))
        norm2 = mpl.colors.Normalize(vmin=np.min(residual_delays), vmax=np.max(residual_delays))
        cmap1 = mpl.colormaps.get_cmap('viridis')
        cmap2 = mpl.colormaps.get_cmap('seismic')

        axd['times'].scatter(position_array[:, 0], position_array[:, 1], c=times,
                             norm=norm1, cmap=cmap1, label='Measured')
        axd['times'].set_title("Time delays")
        axd['times'].legend()

        axd['expected'].scatter(position_array[:, 0], position_array[:, 1], c=expected_delays,
                                norm=norm1, cmap=cmap1, label='Expected')
        axd['expected'].legend()

        axd['residuals'].scatter(position_array[:, 0], position_array[:, 1], c=residual_delays,
                                 norm=norm2, cmap=cmap2)
        axd['residuals'].set_title("Residual time delays")

        fig.colorbar(mpl.cm.ScalarMappable(norm=norm1, cmap=cmap1), ax=axd['expected'],
                     orientation='horizontal', location='bottom', label='Time [ns]')
        fig.colorbar(mpl.cm.ScalarMappable(norm=norm2, cmap=cmap2), ax=axd['residuals'],
                     orientation='horizontal', location='bottom', label='Time [ns]')

        fig.suptitle(f"Station {station.get_id()}")

        for key in ['expected', 'residuals']:
            axd[key].set_xlabel("Easting [m]")

        axd['residuals'].yaxis.set_label_position("right")
        axd['residuals'].set_ylabel("Northing [m]")
        axd['residuals'].set_aspect('equal')

        fig.savefig(
            f"pipeline_planewavefit_debug_CS{station.get_id():03d}_iteration{niter}_{event.get_id()}.png",
            dpi=250, bbox_inches='tight'
        )
        plt.close(fig)

    def end(self):
        pass