Source code for NuRadioReco.modules.correlationDirectionFitter

from __future__ import absolute_import, division, print_function, unicode_literals
from NuRadioReco.modules.base.module import register_run
from scipy import signal, fftpack
import matplotlib.pyplot as plt
import numpy as np
from NuRadioReco.utilities import geometryUtilities as geo_utl
from NuRadioReco.utilities import units
from NuRadioReco.framework.parameters import stationParameters as stnp
from NuRadioReco.framework.parameters import electricFieldParameters as efp
import scipy.optimize as opt
from radiotools import helper as hp
import logging


class correlationDirectionFitter:
    """
    Fits the direction using correlation of parallel channels.

    The traces of two channel pairs are cross-correlated and the arrival
    direction (zenith, azimuth) is chosen such that the expected time delays
    between the channels of each pair hit the maxima of both correlation
    functions simultaneously.
    """

    def __init__(self):
        # Per-event bookkeeping of the Monte Carlo truth and of the
        # reconstruction residuals; filled in run() and plotted in end().
        self.__zenith = []
        self.__azimuth = []
        self.__delta_zenith = []
        self.__delta_azimuth = []
        self.logger = logging.getLogger('NuRadioReco.correlationDirectionFitter')
        self.__debug = None
        self.begin()

    def begin(self, debug=False, log_level=logging.NOTSET):
        """
        Configure the module.

        Parameters
        ----------
        debug: bool (default False)
            if True, run() produces additional debug plots and log output
        log_level: int (default logging.NOTSET)
            level that is set on the module logger
        """
        self.logger.setLevel(log_level)
        self.__debug = debug

    @register_run()
    def run(self, evt, station, det, n_index=None, ZenLim=None,
            AziLim=None,
            channel_pairs=((0, 2), (1, 3)),
            use_envelope=False):
        """
        reconstruct signal arrival direction for all events

        The fitted zenith (clamped to ZenLim) and azimuth are written to
        ``station[stnp.zenith]`` and ``station[stnp.azimuth]``. If the station
        has a sim station, the deviation from the simulated direction is
        logged and stored for the summary plots made in end().

        Parameters
        ----------
        evt: Event
            The event to run the module on (not accessed by this method)
        station: Station
            The station to run the module on
        det: Detector
            The detector description
        n_index: float
            the index of refraction
            NOTE(review): the default None is forwarded unchanged to
            geo_utl.get_time_delay_from_direction -- confirm that helper
            accepts None, otherwise callers must always pass a value.
        ZenLim: 2-dim array/list of floats (default: [0 * units.deg, 90 * units.deg])
            the zenith angle limits for the fit
        AziLim: 2-dim array/list of floats (default: [0 * units.deg, 360 * units.deg])
            the azimuth angle limits for the fit
        channel_pairs: pair of pair of integers
            specify the two channel pairs to use, default ((0, 2), (1, 3))
        use_envelope: bool (default False)
            if True, the hilbert envelope of the traces is used
        """
        if ZenLim is None:
            ZenLim = [0 * units.deg, 90 * units.deg]
        if AziLim is None:
            AziLim = [0 * units.deg, 360 * units.deg]

        # Hard-coded switch between the time-domain correlation (True) and the
        # FFT based circular correlation (False).
        use_correlation = True

        def ll_regular_station(angles, corr_02, corr_13, sampling_rate, positions, trace_start_times):
            """
            Likelihood function for a four antenna ARIANNA station, using correlation.
            Using correlation, has no built in wrap around, pulse needs to be in the middle
            """
            zenith = angles[0]
            azimuth = angles[1]
            times = []
            for pos in positions:
                tmp = [geo_utl.get_time_delay_from_direction(zenith, azimuth, pos[0], n=n_index),
                       geo_utl.get_time_delay_from_direction(zenith, azimuth, pos[1], n=n_index)]
                times.append(tmp)

            delta_t_02 = times[0][1] - times[0][0]
            delta_t_13 = times[1][1] - times[1][0]
            # take different trace start times into account
            delta_t_02 -= (trace_start_times[0][1] - trace_start_times[0][0])
            delta_t_13 -= (trace_start_times[1][1] - trace_start_times[1][0])
            # convert the time delays into a number of samples
            delta_t_02 *= sampling_rate
            delta_t_13 *= sampling_rate
            # zero delay corresponds to the center of the correlation array
            pos_02 = int(corr_02.shape[0] / 2 - delta_t_02)
            pos_13 = int(corr_13.shape[0] / 2 - delta_t_13)

            # The correlation value itself (not its square) is used on purpose:
            # squaring would rate anti-correlating pulses as good matches,
            # which is wrong because the signal is not a continuous waveform.
            weight_02 = np.sum(np.abs(corr_02))  # Normalize crosscorrelation
            weight_13 = np.sum(np.abs(corr_13))

            likelihood = -1 * (corr_02[pos_02] / weight_02 + corr_13[pos_13] / weight_13)

            return likelihood

        def ll_regular_station_fft(angles, corr_02_fft, corr_13_fft, sampling_rate, positions, trace_start_times):
            """
            Likelihood function for a four antenna ARIANNA station, using FFT convolution
            Using FFT convolution, has built-in wrap around, but ARIANNA signals are too short for it to be accurate
            will show problems at zero time delay
            """
            zenith = angles[0]
            azimuth = angles[1]
            times = []
            for pos in positions:
                tmp = [geo_utl.get_time_delay_from_direction(zenith, azimuth, pos[0], n=n_index) * sampling_rate,
                       geo_utl.get_time_delay_from_direction(zenith, azimuth, pos[1], n=n_index) * sampling_rate]
                times.append(tmp)

            delta_t_02 = ((times[0][1] + trace_start_times[0][1] * sampling_rate)
                          - (times[0][0] + trace_start_times[0][0] * sampling_rate))
            delta_t_13 = ((times[1][1] + trace_start_times[1][1] * sampling_rate)
                          - (times[1][0] + trace_start_times[1][0] * sampling_rate))

            # negative delays wrap around to the end of the circular correlation
            if delta_t_02 < 0:
                pos_02 = int(delta_t_02 + corr_02_fft.shape[0])
            else:
                pos_02 = int(delta_t_02)

            if delta_t_13 < 0:
                pos_13 = int(delta_t_13 + corr_13_fft.shape[0])
            else:
                pos_13 = int(delta_t_13)

            weight_02 = np.sum(np.abs(corr_02_fft))  # Normalize crosscorrelation
            weight_13 = np.sum(np.abs(corr_13_fft))

            # Both terms must index the arrays passed in as arguments; the
            # second term previously read the undefined name `corr_13`.
            likelihood = -1 * (np.abs(corr_02_fft[pos_02]) ** 2 / weight_02
                               + np.abs(corr_13_fft[pos_13]) ** 2 / weight_13)
            return likelihood

        station_id = station.get_id()
        # Only the first two entries of channel_pairs are used for the fit.
        pairs = (channel_pairs[0], channel_pairs[1])
        positions_pairs = [[det.get_relative_position(station_id, pair[0]),
                            det.get_relative_position(station_id, pair[1])]
                           for pair in pairs]
        channels = [[station.get_channel(pair[0]), station.get_channel(pair[1])] for pair in pairs]
        sampling_rate = channels[0][0].get_sampling_rate()  # assume that channels have the same sampling rate
        trace_start_time_pairs = [[pair_channels[0].get_trace_start_time(),
                                   pair_channels[1].get_trace_start_time()]
                                  for pair_channels in channels]

        # determine automatically if one channel has an inverted waveform with respect to the other
        signs = [1., 1.]
        for iPair, pair in enumerate(channel_pairs):
            antenna_type = det.get_antenna_type(station_id, pair[0])
            if "LPDA" in antenna_type:
                otheta, ophi, rot_theta, rot_azimuth = det.get_antenna_orientation(station_id, pair[0])
                otheta2, ophi2, rot_theta2, rot_azimuth2 = det.get_antenna_orientation(station_id, pair[1])
                if np.isclose(np.abs(rot_azimuth - rot_azimuth2), 180 * units.deg, atol=1 * units.deg):
                    signs[iPair] = -1

        if use_correlation:
            # Correlation
            if not use_envelope:
                corr_02 = signal.correlate(channels[0][0].get_trace(),
                                           signs[0] * channels[0][1].get_trace())
                corr_13 = signal.correlate(channels[1][0].get_trace(),
                                           signs[1] * channels[1][1].get_trace())
            else:
                # the envelope is positive by construction, so no sign flip is applied
                corr_02 = signal.correlate(np.abs(signal.hilbert(channels[0][0].get_trace())),
                                           np.abs(signal.hilbert(channels[0][1].get_trace())))
                corr_13 = signal.correlate(np.abs(signal.hilbert(channels[1][0].get_trace())),
                                           np.abs(signal.hilbert(channels[1][1].get_trace())))
        else:
            # FFT convolution
            corr_02_fft = fftpack.ifft(-1 * fftpack.fft(channels[0][0].get_trace()).conjugate()
                                       * fftpack.fft(channels[0][1].get_trace()))
            corr_13_fft = fftpack.ifft(-1 * fftpack.fft(channels[1][0].get_trace()).conjugate()
                                       * fftpack.fft(channels[1][1].get_trace()))

        if use_correlation:
            # Using correlation
            ll = opt.brute(
                ll_regular_station,
                ranges=(slice(ZenLim[0], ZenLim[1], 0.01), slice(AziLim[0], AziLim[1], 0.01)),
                args=(corr_02, corr_13, sampling_rate, positions_pairs, trace_start_time_pairs),
                full_output=True, finish=opt.fmin)  # slow but does the trick
        else:
            ll = opt.brute(
                ll_regular_station_fft,
                ranges=(slice(ZenLim[0], ZenLim[1], 0.05), slice(AziLim[0], AziLim[1], 0.05)),
                args=(corr_02_fft, corr_13_fft, sampling_rate, positions_pairs, trace_start_time_pairs),
                full_output=True, finish=opt.fmin)  # slow but does the trick

        if self.__debug:
            # Imported lazily so that peakutils is only required in debug mode.
            # The name stays bound for the second debug section further down.
            import peakutils
            zenith = ll[0][0]
            azimuth = ll[0][1]
            times = []
            for pos in positions_pairs:
                tmp = [geo_utl.get_time_delay_from_direction(zenith, azimuth, pos[0], n=n_index),
                       geo_utl.get_time_delay_from_direction(zenith, azimuth, pos[1], n=n_index)]
                times.append(tmp)

            delta_t_02 = times[0][1] - times[0][0]
            delta_t_13 = times[1][1] - times[1][0]
            # take different trace start times into account
            delta_t_02 -= (trace_start_time_pairs[0][1] - trace_start_time_pairs[0][0])
            delta_t_13 -= (trace_start_time_pairs[1][1] - trace_start_time_pairs[1][0])
            delta_t_02 *= sampling_rate
            delta_t_13 *= sampling_rate

            toffset = -(np.arange(0, corr_02.shape[0]) - corr_02.shape[0] / 2) / sampling_rate

            fig, (ax, ax2) = plt.subplots(2, 1, sharex=True)
            ax.plot(toffset, corr_02)
            ax.axvline(delta_t_02 / sampling_rate, label='time', c='k')
            indices = peakutils.indexes(corr_02, thres=0.8, min_dist=5)
            ax.plot(toffset[indices], corr_02[indices], 'o')
            imax = np.argmax(corr_02[indices])
            self.logger.debug("offset 02= {:.3f}".format(toffset[indices[imax]] - (delta_t_02 / sampling_rate)))

            ax2.plot(toffset, corr_13)
            indices = peakutils.indexes(corr_13, thres=0.8, min_dist=5)
            ax2.plot(toffset[indices], corr_13[indices], 'o')
            ax2.axvline(delta_t_13 / sampling_rate, label='time', c='k')

            ax2.set_xlabel("time")
            ax2.set_ylabel("Correlation Ch 1/ Ch3", fontsize='small')
            ax.set_ylabel("Correlation Ch 0/ Ch2", fontsize='small')
            plt.tight_layout()
            # plt.close("all")

        # the simplex refinement of opt.brute may leave the allowed zenith range -> clamp
        station[stnp.zenith] = max(ZenLim[0], min(ZenLim[1], ll[0][0]))
        station[stnp.azimuth] = ll[0][1]
        output_str = "reconstucted angles theta = {:.1f}, phi = {:.1f}".format(
            station[stnp.zenith] / units.deg, station[stnp.azimuth] / units.deg)

        if station.has_sim_station():
            sim_station = station.get_sim_station()
            sim_zen = None
            sim_az = None
            if sim_station.is_cosmic_ray():
                sim_zen = sim_station[stnp.zenith]
                sim_az = sim_station[stnp.azimuth]
            elif sim_station.is_neutrino():
                # in case of a neutrino simulation, each channel has a slightly different
                # arrival direction -> compute the average
                sim_zen = []
                sim_az = []
                for efield in sim_station.get_electric_fields_for_channels(ray_path_type='direct'):
                    sim_zen.append(efield[efp.zenith])
                    sim_az.append(efield[efp.azimuth])
                sim_zen = np.array(sim_zen)
                sim_az = hp.get_normalized_angle(np.array(sim_az))

                ops = "average incident zenith {:.1f} +- {:.1f}".format(
                    np.mean(sim_zen) / units.deg, np.std(sim_zen) / units.deg)
                ops += " (individual: "
                ops += "".join("{:.1f}, ".format(x / units.deg) for x in sim_zen)
                ops += ")"
                self.logger.debug(ops)

                ops = "average incident azimuth {:.1f} +- {:.1f}".format(
                    np.mean(sim_az) / units.deg, np.std(sim_az) / units.deg)
                ops += " (individual: "
                ops += "".join("{:.1f}, ".format(x / units.deg) for x in sim_az)
                ops += ")"
                self.logger.debug(ops)

                sim_zen = np.mean(sim_zen)
                sim_az = np.mean(sim_az)

            if sim_zen is not None:
                sim_az_normalized = hp.get_normalized_angle(sim_az)
                dOmega = hp.get_angle(hp.spherical_to_cartesian(sim_zen, sim_az),
                                      hp.spherical_to_cartesian(station[stnp.zenith], station[stnp.azimuth]))
                output_str += " MC theta = {:.2f}, phi = {:.2f},  dOmega = {:.2f}, dZen = {:.1f}, dAz = {:.1f}".format(
                    sim_zen / units.deg,
                    sim_az_normalized / units.deg,
                    dOmega / units.deg,
                    (station[stnp.zenith] - sim_zen) / units.deg,
                    (station[stnp.azimuth] - sim_az_normalized) / units.deg)
                self.__zenith.append(sim_zen)
                self.__azimuth.append(sim_az)
                self.__delta_zenith.append(station[stnp.zenith] - sim_zen)
                self.__delta_azimuth.append(station[stnp.azimuth] - sim_az_normalized)

        self.logger.info(output_str)
        # Still have to add fit quality parameter to output

        if self.__debug:
            # access simulated efield and high level parameters
            sim_present = False
            if station.has_sim_station():
                if station.get_sim_station().has_parameter(stnp.zenith):
                    sim_station = station.get_sim_station()
                    azimuth_orig = sim_station[stnp.azimuth]
                    zenith_orig = sim_station[stnp.zenith]
                    sim_present = True
                    self.logger.debug("True CoREAS zenith {0}, azimuth {1}".format(zenith_orig, azimuth_orig))
            self.logger.debug("Result of direction fitting: [zenith, azimuth] {}".format(np.rad2deg(ll[0])))

            # Show fit space
            zen = np.arange(ZenLim[0], ZenLim[1], 1 * units.deg)
            az = np.arange(AziLim[0], AziLim[1], 2 * units.deg)

            x_plot = np.zeros(zen.shape[0] * az.shape[0])
            y_plot = np.zeros(zen.shape[0] * az.shape[0])
            z_plot = np.zeros(zen.shape[0] * az.shape[0])
            i = 0
            for a in az:
                for z in zen:
                    # Evaluate fit function for grid
                    if use_correlation:
                        z_plot[i] = ll_regular_station([z, a], corr_02, corr_13, sampling_rate,
                                                       positions_pairs, trace_start_time_pairs)
                    else:
                        z_plot[i] = ll_regular_station_fft([z, a], corr_02_fft, corr_13_fft, sampling_rate,
                                                           positions_pairs, trace_start_time_pairs)
                    x_plot[i] = a
                    y_plot[i] = z
                    i += 1

            fig, ax = plt.subplots(1, 1)
            ax.scatter(np.rad2deg(x_plot), np.rad2deg(y_plot), c=z_plot, cmap='gnuplot2_r', lw=0)
            # ax.imshow(z_plot, cmap='gnuplot2_r', extent=(0, 360, 90, 180))
            if sim_present:
                ax.plot(np.rad2deg(hp.get_normalized_angle(azimuth_orig)), np.rad2deg(zenith_orig),
                        marker='d', c='g', label="True")
            ax.scatter(np.rad2deg(ll[0][1]), np.rad2deg(ll[0][0]), marker='o', c='k', label='Fit')
            # ax.colorbar(label='Fit parameter')
            # everything drawn on this axis is converted to degrees
            ax.set_ylabel('Zenith [deg]')
            ax.set_xlabel('Azimuth [deg]')
            plt.tight_layout()

            # plot allowed solution separately for each pair of channels
            # (candidate delays = correlation peaks, strongest peak first)
            toffset = -(np.arange(0, corr_02.shape[0]) - corr_02.shape[0] / 2.) / sampling_rate
            indices = peakutils.indexes(corr_02, thres=0.8, min_dist=5)
            t02s = (toffset[indices][np.argsort(corr_02[indices])[::-1]]
                    + (trace_start_time_pairs[0][1] - trace_start_time_pairs[0][0]))
            toffset = -(np.arange(0, corr_13.shape[0]) - corr_13.shape[0] / 2.) / sampling_rate
            indices = peakutils.indexes(corr_13, thres=0.8, min_dist=5)
            t13s = (toffset[indices][np.argsort(corr_13[indices])[::-1]]
                    + (trace_start_time_pairs[1][1] - trace_start_time_pairs[1][0]))

            from scipy import constants
            c = constants.c * units.m / units.s

            def getDeltaTCone(r, dt):
                """
                Return azimuth and zenith angles of the directions that produce
                the time delay dt along the baseline vector r (a cone around r).
                """
                dist = np.linalg.norm(r)
                t0 = -dist * n_index / c
                Phic = np.arccos(dt / t0)  # cone angle for allowable solutions
                self.logger.debug('dist = {}, dt = {}, t0 = {}, phic = {}'.format(dist, dt, t0, Phic))
                nr = r / dist  # normalize
                p = np.cross([0, 0, 1], nr)  # create a perpendicular normal vector to r
                p = p / np.linalg.norm(p)
                q = np.cross(nr, p)  # nr, p, and q form an orthonormal basis
                self.logger.debug('nr = {}\np = {}\nq = {}\n'.format(nr, p, q))
                ThetaC = np.linspace(0, 2 * np.pi, 1000)
                Phis = np.zeros(len(ThetaC))
                Thetas = np.zeros(len(ThetaC))
                for i_cone, thetac in enumerate(ThetaC):
                    # create a set of vectors that point along the cone defined by r and PhiC
                    rc = nr + np.tan(Phic) * (np.sin(thetac) * p + np.cos(thetac) * q)
                    nrc = rc / np.linalg.norm(rc)
                    Thetas[i_cone] = np.arccos(nrc[2])
                    Phis[i_cone] = np.arctan2(nrc[1], nrc[0])
                return Phis, Thetas

            linestyles = ['-', '--', ':', '-.']

            def plot_cone_solutions(baseline, delays, color, label_template):
                """Draw, for each candidate delay, the cone of allowed directions on `ax`."""
                for i_delay, delay in enumerate(delays):
                    phis, thetas = getDeltaTCone(baseline, delay)
                    thetas[thetas < 0] += np.pi
                    phis[phis < 0] += 2 * np.pi
                    # insert NaNs where the azimuth wraps around so that matplotlib
                    # does not connect both ends with a line across the plot
                    jump_positions = np.where(np.abs(np.diff(phis)) >= 5.0)[0]
                    phis = np.insert(phis, jump_positions + 1, np.nan)
                    thetas = np.insert(thetas, jump_positions + 1, np.nan)
                    ax.plot(np.rad2deg(phis), np.rad2deg(thetas),
                            '{}{}'.format(linestyles[i_delay % len(linestyles)], color),
                            label=label_template.format(delay))

            # baseline vectors: position of the second minus position of the first channel of each pair
            r0_2 = positions_pairs[0][1] - positions_pairs[0][0]
            r1_3 = positions_pairs[1][1] - positions_pairs[1][0]
            self.logger.debug('r02 {}\nr13 {}'.format(r0_2, r1_3))

            plot_cone_solutions(r0_2, t02s, 'C3', 'c 0+2 dt = {}')
            plot_cone_solutions(r1_3, t13s, 'C2', 'c 1+3 dt = {}')

            ax.legend(fontsize='small')
            ax.set_ylim(ZenLim[0] / units.deg, ZenLim[1] / units.deg)
            ax.set_xlim(AziLim[0] / units.deg, AziLim[1] / units.deg)

    def end(self):
        """
        Produce summary plots of the reconstruction performance.

        Saves "zenith_bias.png" (zenith residual vs. simulated zenith for
        events with an azimuth residual below 1 degree) and "azimuth.png"
        (histogram of the azimuth residuals), then shows the figures.
        """
        fig, ax = plt.subplots(1, 1)
        # only keep events whose azimuth was reconstructed to better than 1 degree
        mask = np.abs(self.__delta_azimuth) < (1 * units.deg)
        ax.scatter(np.array(self.__zenith)[mask] / units.deg,
                   np.array(self.__delta_zenith)[mask] / units.deg, s=20)
        ax.set_xlabel("zenith angle (MC) [deg]")
        ax.set_ylabel("(zenith_rec - zenith_MC) [deg]")
        fig.tight_layout()
        fig.savefig("zenith_bias.png")

        from radiotools import plthelpers as php
        bins = np.arange(-10, 10, .1)
        fig, ax = php.get_histogram(np.array(self.__delta_azimuth) / units.deg, bins=bins,
                                    xlabel="delta azimuth [deg]")
        fig.savefig("azimuth.png")
        plt.show()