#!/usr/bin/env python3
'''
Adapted from __author__ = 'Mykhaylo Shumko': __init__, waveletTransform, plotPower, lagNAutoCorr.
All other functions were written by Brian Mullen.

To find the global wavelet spectrum, run these two statements:
    wave = waveletAnalysis(data, fps)  # creates the object and runs the wavelet transform
    wave.globalWaveletSpectrum()  # computes the global wavelet spectrum

To find signal change:
    waveletAnalysis.tsSignal()

There seems to be very little difference between binned data and full data.
'''
import sys
import numpy as np
from seas.waveletFunctions import *
import matplotlib.pylab as plt
import matplotlib
import math
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.gridspec import GridSpec
from datetime import datetime
from seas.signalanalysis import local_max, linear_regression, abline, lag_n_autocorr
from typing import List
import operator
[docs]class waveletAnalysis:
def __init__(self, data: np.ndarray, fps: int, **kwargs):
    """
    A class for managing wavelet analysis.

    Initialize the wavelet parameters and run the wavelet transform.
    Only the Morlet wavelet is fully implemented in this code; update the
    parameters below if other wavelet types are used.

    Functions:
        waveletTransform: Wavelet transform
        inverseWaveletTransform: inverse wavelet transform
        plotPower: visualization function to look at power
            (function used in the GUI)
        waveletFilter: zero out unwanted periods / low-significance power
        nanCOI: creates cone of influence (COI);
            values outside the cone are assigned np.nan
        nanSig: creates np.nan values outside of significance
            (used to visualize significant waveforms)
        familySig: creates several cutoff values based on the red noise
            autoregression, based on percent cutoff value
        sumAcrossPeriod: sum across defined frequencies
        globalWaveletSpectrum: average across time to find significant
            frequencies
        averageWaveletPower: average across frequency to find significant
            times when waveforms are present
        noiseFilter: filters the time series with high and/or low pass
            filters and a significance filter to remove low power noise.
            Used to filter the global mean.

    Arguments:
        data: 1-dimensional time course data. The untouched input is kept
            in self.dataCopy; self.data holds the standardized series
            (mean removed, divided by the sample standard deviation).
        fps: frames per second, rate of data acquisition (required).
        verbose: Boolean on printed outputs (default=False)
        plot: plots outcomes of several functions (default=False)
        mother: mother wavelet (default='MORLET')
        param: defining parameter for mother wavelet (default=4)
        j1: parameter for wavelet transform (see wavelet functions;
            default=80)
        pad: parameter for wavelet transform (see wavelet functions;
            default=1)
        dj: parameter for wavelet transform (see wavelet functions;
            default=0.125)
        s0: parameter for wavelet transform (see wavelet functions;
            default=2 / fps)
        siglvl: significance level (default=0.95)

    Derived attributes (not read from kwargs):
        cdelta: reconstruction constant selected from the mother/param
            combination.
        psi0: constant factor to ensure a total energy of unity
            (see Torrence and Compo; page 65). None for DOG wavelets.
        lag1: defines the autoregression red-noise model; mean of the
            lag-1 and lag-2 values returned by lag_n_autocorr on the raw
            data.

    Raises:
        AssertionError: if data is not 1-dimensional, or if the
            mother/param combination has no known cdelta.
    """
    assert data.ndim == 1, 'Time series is the wrong shape. It should be a 1-dim vector'
    self.dataCopy = data
    self.data = (data - np.mean(data)) / np.std(data, ddof=1)
    self.n = len(self.data)
    self.cadence = 1 / fps
    self.time = np.arange(self.n) * self.cadence

    # Print/plot switches.
    self.verbose = kwargs.get('verbose', False)
    self.plot = kwargs.get('plot', False)

    # Wavelet parameters.
    self.mother = kwargs.get('mother', 'MORLET')
    self.param = kwargs.get('param', 4)
    self.j1 = kwargs.get('j1', 80)
    self.pad = kwargs.get('pad', 1)
    self.dj = kwargs.get('dj', 0.125)
    self.s0 = kwargs.get('s0', 2 * self.cadence)

    # (mother, param) -> (cdelta, psi0). psi0 is used in the wavelet
    # inverse, see (Torrence and Compo 1998); it is unknown for DOG.
    morlet_psi0 = np.pi**(-0.25) * 0.85
    known_wavelets = {
        ('MORLET', 6): (0.776, morlet_psi0),
        ('MORLET', 4): (1.151, morlet_psi0),
        ('DOG', 2): (3.541, None),
        ('DOG', 6): (1.966, None),
    }
    self.cdelta = None
    for (mother, param), (cdelta, psi0) in known_wavelets.items():
        if self.mother == mother and self.param == param:
            self.cdelta = cdelta
            self.psi0 = psi0
            break
    else:
        # Raised explicitly (rather than via `assert`) so the check still
        # runs under `python -O`; the exception type is unchanged.
        raise AssertionError('Unknown c value based on wavelet choice')

    if self.psi0 is None and self.verbose:
        print('Unknown Psi 0, must be input before inverse can be calculated')

    # Noise modeling parameter.
    self.siglvl = kwargs.get('siglvl', 0.95)
    self.lag1 = 0.5 * (lag_n_autocorr(data, 1) + lag_n_autocorr(data, 2))

    self.waveletTransform()
def plotPower(self, ax=None):
    '''
    Plot the log2 power spectrogram, with cone of influence.

    Arguments:
        ax: axis to plot on (see matplotlib.pyplot). An array of axes is
            also accepted, in which case the first one is used. When
            omitted, a new figure with a single axis is created.

    Returns:
        None

    Side effects:
        Stores the contour levels in self.levels.
    '''
    self.levels = [0.0625, 0.125, 0.25, 0.5, 1, 2, 4, 8, 16]
    colors = [
        'navy', 'slateblue', 'c', 'g', 'gold', 'orange', 'tomato', 'crimson'
    ]

    # `is None` (not `== None`): comparing an array of axes with `==`
    # yields an array whose truth value is ambiguous.
    if ax is None:
        _, ax = plt.subplots(1)
    else:
        ax = np.ravel(ax)[0]

    # Max period is fourier_factor*S0*2^(j1*dj), fourier_factor = 3.97383530632.
    CS = ax.contourf(self.time,
                     self.period,
                     np.log2(self.power),
                     len(self.levels),
                     colors=colors)
    # Re-contour the same data at the fixed log2 levels; this is the set
    # that the colorbar below describes.
    im = ax.contourf(CS, levels=np.log2(self.levels), colors=colors)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Period (s)')
    ax.set_title('Wavelet Power Spectrum')

    # 95% significance contour, levels at -99 (fake) and 1 (95% signif).
    ax.contour(self.time, self.period, self.sig95, [-99, 1], colors='k')

    # Cone-of-influence, anything "below" is dubious.
    ax.fill_between(self.time,
                    np.max(self.period),
                    self.coi,
                    alpha=0.5,
                    facecolor='white',
                    zorder=3)
    ax.plot(self.time, self.coi, 'k')

    # Format y-scale.
    # Different matplotlib versions available for python < 3.8: older
    # releases expect the axis-specific keyword names.
    try:
        ax.set_yscale('log', base=2, subs=None)
    except ValueError:
        ax.set_yscale('log', basey=2, subsy=None)
    ax.set_ylim([np.min(self.period), np.max(self.period)])
    ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.ticklabel_format(
        axis='y', style='plain')  ## causes issues with tkinter mpl canvas
    ax.invert_yaxis()

    # Set up the size and location of the colorbar.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("bottom", size="5%", pad=0.5)
    plt.colorbar(im, cax=cax, orientation='horizontal')
def waveletFilter(self,
                  lowerPeriod: float = None,
                  upperPeriod: float = None,
                  sigLevel: float = 1):
    '''
    Zero out unwanted parts of the wavelet transform.

    Works on a copy of self.wave and stores the result in self.waveFlt;
    self.wave itself is left untouched.

    Arguments:
        lowerPeriod: High pass filter, specify lowest period and
            it will exclude all under that period. None disables it.
        upperPeriod: Low pass filter, specify highest period and
            it will exclude all over that period. None disables it.
        sigLevel: Significance filter, eliminate low power waveforms:
            every coefficient whose self.sig95 value is below sigLevel
            is zeroed (0.25 is used in noiseFilter).

    Returns:
        None
    '''
    self.waveFlt = self.wave.copy()

    # Band pass filter:
    # Zero out parts of the wavelet space that we don't want to
    # reconstruct. Boolean row masks remove *every* scale outside the
    # limits (slicing up to the last offending index left one behind).
    if lowerPeriod is not None:
        self.waveFlt[self.period < lowerPeriod, :] = 0
    if upperPeriod is not None:
        self.waveFlt[self.period > upperPeriod, :] = 0

    # Significance filter: only pass coefficients at or above sigLevel.
    self.waveFlt[self.sig95 < sigLevel] = 0
def nanCOI(self):
    '''
    Blank out the power that lies outside the cone of influence.

    A copy of self.power is stored in self.nanCOImat. For every time
    column, the rows whose period exceeds that column's self.coi value
    are set to np.nan. self.power is not modified.

    Arguments:
        None

    Returns:
        None
    '''
    masked = self.nanCOImat = self.power.copy()
    n_columns = self.power.shape[1]
    for column in range(n_columns):
        # Row indices of the periods longer than the COI at this time step.
        outside_rows = np.flatnonzero(self.coi[column] < self.period)
        masked[outside_rows, column] = np.nan
def nanSig(self):
    '''
    Get rid of all values that are not significant.

    Stores a copy of self.wave in self.nanSigmat in which every
    coefficient whose self.sig95 value is below 1 is set to np.nan.
    self.wave itself is left untouched.

    Arguments:
        None

    Returns:
        None
    '''
    # Copy first: masking in place would destroy the transform that the
    # other methods of the class still read.
    self.nanSigmat = self.wave.copy()
    self.nanSigmat[self.sig95 < 1] = np.nan
def familySig(self,
              sigList: List[float] = [0.9, 0.95, 0.99, 0.999],
              dof: int = -1,
              sigtest: float = 0):
    '''
    Compute a family of significance curves for visualization and analysis.

    Each requested level is passed to wave_signif together with the
    object's cadence, scale, lag1, mother and param.

    Arguments:
        sigList: a single significance level, or a sequence of levels
            (list, tuple or array). The default list is never mutated.
        dof: degrees of freedom (see waveletFunctions)
        sigtest: which significance test is used (see waveletFunctions for
            more information)
            0: regular chi-square test (for full wavelet transform; non-smoothed)
            1: time-average test (for globalWaveletSpectrum)
            2: scale-average test (for averageWaveletPower)

    Returns:
        fam_significance: significance cutoffs, one row per requested
            level, with singleton dimensions squeezed out
        sigList: the requested levels as a squeezed array
    '''

    def significance(level):
        # Significance cutoff(s) for one level under the red-noise model.
        return wave_signif([1.0],
                           dt=self.cadence,
                           scale=self.scale,
                           sigtest=sigtest,
                           lag1=self.lag1,
                           siglvl=level,
                           dof=dof,
                           mother=self.mother,
                           param=self.param)

    if np.ndim(sigList) == 0:
        # Single level (Python or numpy scalar).
        fam_signif = significance(sigList)
    else:
        # The scale-average test (2) gives one value per level; the other
        # tests give one value per scale.
        n_columns = 1 if sigtest == 2 else self.scale.shape[0]
        fam_signif = np.full((len(sigList), n_columns), np.nan)
        for row, level in enumerate(sigList):
            fam_signif[row] = significance(level)

    return np.squeeze(fam_signif), np.squeeze(sigList)
def sumAcrossPeriod(self, perLim: List[float] = [0, 100]):
    '''
    Sum wavelet power across select periods.

    Uses the cone-of-influence masked power (self.nanCOImat), computing it
    with self.nanCOI() first if it does not exist yet. NaN entries are
    ignored in the sum.

    Arguments:
        perLim: pair of lower and upper periods. Power is summed over the
            rows from the first period above perLim[0] up to (not
            including) the first period above perLim[1]. A limit at or
            above the largest period selects through the last row.

    Returns:
        period_sum: sum across the power of defined frequencies, one value
            per time step
    '''
    if self.verbose:
        print('Summing across {0} to {1} periods on wavelet run with mother {2} at paramter {3}'.format(
            perLim[0], perLim[1], self.mother, self.param))

    def first_index_above(limit):
        # Index of the first period strictly greater than `limit`; falls
        # back to the number of periods when there is none, so slicing
        # simply runs to the end instead of raising on an empty search.
        above = np.flatnonzero(limit < self.period)
        return above[0] if above.size else len(self.period)

    l_per_lim = first_index_above(perLim[0])
    u_per_lim = first_index_above(perLim[1])

    if not hasattr(self, 'nanCOImat'):
        self.nanCOI()

    # NOTE(review): nanCOImat is a copy of self.power and is squared again
    # here and divided by period; Torrence and Compo's scale average uses
    # power / scale -- confirm this weighting is intended.
    weighted = (self.nanCOImat[l_per_lim:u_per_lim]**2 /
                self.period[l_per_lim:u_per_lim, None])
    period_sum = (self.dj * self.cadence) / self.cdelta * np.nansum(
        weighted, axis=0)
    return np.squeeze(period_sum)
def globalWaveletSpectrum(self):
    '''
    Global Wavelet Spectrum: average across time to find significant
    prominent frequencies.

    Arguments:
        None

    Returns:
        None

    Attributes set:
        period_size: number of samples inside the cone of influence for
            each period row
        gws: time-averaged power per period (np.nan where a row has no
            samples inside the cone of influence)
        gws_sig, gws_sigList: 0.95 cutoff curve and level from familySig
            (time-average test, dof=period_size)
        flambda: per-period value computed from self.period with the
            mother-specific formula below; used as the frequency axis
        gws_localmax_power, gws_localmax_freq: power and flambda value of
            the local maxima that exceed the significance cutoff
        gws_lo_high_freq: [(lowav, loval), (hiwav, hival)], the
            (flambda, power) pairs at the last and first significant
            rows; np.nan when nothing is significant
    '''
    if self.verbose:
        print('Assessing wavelet mother {0} at paramter {1}'.format(
            self.mother, self.param))

    # Calculate the global wavelet spectrum from the power inside the
    # cone of influence only.
    self.nanCOI()
    self.period_size = np.sum(~np.isnan(self.nanCOImat), axis=1)
    has_data = self.period_size > 0
    # Rows without a single sample inside the cone. Emptiness is tested
    # through the mask: `.any()` on the index array would miss index 0.
    nan_ind = np.where(~has_data)[0]

    # Empty rows are held at 0 for the comparisons below and switched
    # back to NaN once the peaks have been found.
    self.gws = np.zeros(self.period_size.shape[0])
    self.gws[has_data] = np.nanmean(self.nanCOImat[has_data, :], axis=1)

    if self.period_size.shape[0] != self.period.shape[0]:
        dif = self.period_size.shape[0] - self.period.shape[0]
        if dif < 0:
            self.period_size = np.append(self.period_size,
                                         np.zeros(np.abs(dif)))
        else:
            self.period_size = self.period_size[:self.period.shape[0]]

    # Calculate the average significance.
    self.gws_sig, self.gws_sigList = self.familySig(sigList=[0.95],
                                                    dof=self.period_size,
                                                    sigtest=1)
    if self.verbose:
        print('Auto-correlation value: {0:.4g}'.format(self.lag1))

    # Determine fourier wavelength.
    if self.mother == 'DOG':
        self.flambda = (2 * np.pi * 1 / self.period) / np.sqrt(self.param +
                                                               .5)
    elif self.mother == 'MORLET':
        self.flambda = (4 * np.pi * 1 / self.period) / (
            self.param + np.sqrt(2 + np.square(self.param)))

    # local_max is an external helper; presumably it returns the
    # coordinates of the local maxima of gws on the given axis together
    # with the matching cutoff values -- verify against seas.signalanalysis.
    mx_wav, mx_gws, mx_sig = local_max(self.period, self.gws, self.gws_sig)
    fl_wav, mx_gws, mx_sig = local_max(self.flambda, self.gws, self.gws_sig)

    # Keep only the peaks above the significance cutoff.
    lwav = []
    lgws = []
    lfl = []
    for peak_period, peak_power, cutoff, peak_freq in zip(
            mx_wav, mx_gws, mx_sig, fl_wav):
        if peak_power > cutoff:
            lwav.append(peak_period)
            lgws.append(peak_power)
            lfl.append(peak_freq)
    self.gws_localmax_power = lgws
    self.gws_localmax_freq = lfl

    # Find the lowest and highest frequencies that are still significant.
    hiwav = np.nan
    hival = np.nan
    lowav = np.nan
    loval = np.nan
    significant = np.where(self.gws > self.gws_sig)
    if significant[0].shape[0] > 0:
        sig_power = self.gws[significant]
        sig_freq = self.flambda[significant]
        hival, hiwav = sig_power[0], sig_freq[0]
        loval, lowav = sig_power[-1], sig_freq[-1]
    self.gws_lo_high_freq = [(lowav, loval), (hiwav, hival)]

    # Restore NaN for the rows that had no samples inside the cone.
    self.gws[nan_ind] = np.nan

    if self.verbose:
        print('Low frequency: ', lowav)
        print('High freqency: ', hiwav)

    if self.plot:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)
        nyq = 0.5 * 1 / self.cadence
        linetype = ['-', '-.', '--', ':']
        fam_signif, sigList = self.familySig(dof=self.period_size,
                                             sigtest=1)

        def draw_significance(axis, x_values):
            # One black curve per significance level; levels beyond the
            # available line styles reuse the last style.
            for i in range(len(sigList)):
                axis.plot(x_values,
                          fam_signif[i],
                          label=sigList[i],
                          ls=linetype[min(i, len(linetype) - 1)],
                          color='k')

        y_top = np.ceil(max(lgws)) if lgws else 10

        # Period-space graph.
        ax1.plot(self.period, self.gws)
        ax1.plot(lwav, lgws, 'ro')
        draw_significance(ax1, self.period)
        ax1.set_ylim([0, y_top])
        if not lwav or np.ceil(max(lwav)) < 10:
            ax1.set_xlim([0, 10])
        else:
            ax1.set_xlim([0, np.ceil(max(lwav))])
        ax1.set_xlabel(' Wavelet {0} {1} period(s)'.format(
            self.mother, self.param))
        ax1.set_ylabel('normalized power (to variance)')
        ax1.set_title('Power spectrum by period')

        # Fourier space lambda.
        ax2.plot(self.flambda, self.gws)
        draw_significance(ax2, self.flambda)
        ax2.set_ylim([0, y_top])
        if not lfl or np.ceil(max(lfl)) < nyq:
            ax2.set_xlim([0, nyq])
        else:
            ax2.set_xlim([0, np.ceil(max(lfl))])
        ax2.plot(lowav, loval, 'go', label='lowSigFreq')
        ax2.plot(hiwav, hival, 'bo', label='highSigFreq')
        ax2.plot(lfl, lgws, 'ro', label='localMax')
        ax2.legend()
        ax2.set_xlabel('Fourier frquency')
        ax2.set_title('Power spectrum by frequency')

        plt.tight_layout()
        plt.show()
def averageWaveletPower(self, perLim: List[float] = [0.25, 8]):
    '''
    Average wavelet power: average across periods to find significant
    prominent times of activity.

    Arguments:
        perLim: pair of lower and upper periods. Power is summed
            between these defined periods.

    Returns:
        None

    Attributes set:
        period_sum: output of sumAcrossPeriod(perLim), padded with zeros
            or truncated to the length of self.time
        sig_period_sum: 0.95 cutoff from familySig (scale-average test,
            dof=perLim)
        events, events_value: time and value of the local maxima of
            period_sum that exceed sig_period_sum

    Raises:
        AssertionError: if perLim does not hold exactly two values.
    '''
    assert len(perLim) == 2, 'Period limit list must only include 2 values'
    if self.verbose:
        print(('Creating scaled average of the timeseries, '
               'created with wavelet mother {0} at paramter {1}').format(
                   self.mother, self.param))

    self.period_sum = self.sumAcrossPeriod(perLim=perLim)

    # Align the summed series with the time axis (it is compared against
    # self.time, so the time length -- not the period length -- applies).
    n_times = self.time.shape[0]
    dif = self.period_sum.shape[0] - n_times
    if dif < 0:
        self.period_sum = np.append(self.period_sum, np.zeros(np.abs(dif)))
    elif dif > 0:
        self.period_sum = self.period_sum[:n_times]

    # Calculate the average significance.
    self.sig_period_sum, _ = self.familySig(sigList=[0.95],
                                            dof=perLim,
                                            sigtest=2)

    # Find coordinates of local max values (local_max is an external
    # helper -- presumably returns peak positions and values; verify).
    mx_wav, mx_gws = local_max(self.time, self.period_sum)

    # Return only those above significance threshold.
    ltime = []
    lgws = []
    for peak_time, peak_value in zip(mx_wav, mx_gws):
        if peak_value > self.sig_period_sum:
            ltime.append(peak_time)
            lgws.append(peak_value)

    if self.plot:
        sigline = np.zeros(self.period_sum.shape[0]) + self.sig_period_sum
        plt.plot(self.time, self.period_sum)
        plt.plot(self.time, np.squeeze(sigline), 'k--')
        plt.plot(ltime, lgws, 'ro')
        plt.show()

    self.events = ltime
    self.events_value = lgws
def noiseFilter(self,
                lowerPeriod: float = None,
                upperPeriod: float = 10,
                sigLevel: float = 0.25):
    '''
    Filter the time series in wavelet space and rebuild it.

    Runs waveletFilter with the given limits, then
    inverseWaveletTransform, and rescales self.dataFlt back to the units
    of the original data using the mean and sample standard deviation of
    self.dataCopy. When self.plot is set, the original, filtered and
    residual series are shown.

    Arguments:
        lowerPeriod: High pass filter, specify lowest period and
            it will exclude all under that period. Defaults to twice
            the sampling interval (2 * self.cadence).
        upperPeriod: Low pass filter, specify highest period and
            it will exclude all over that period
        sigLevel: Significance filter, eliminate low power waveforms

    Returns:
        filtData: inverse wavelet transform based on removing periods/power
            defined in the parameters
    '''
    # Nyquist sampling rate is the fallback for the lower limit.
    low_cut = lowerPeriod
    if low_cut is None:
        low_cut = 2 * self.cadence

    self.waveletFilter(lowerPeriod=low_cut,
                       upperPeriod=upperPeriod,
                       sigLevel=sigLevel)
    self.inverseWaveletTransform()

    # Undo the standardization applied to the input data.
    original = self.dataCopy
    spread = np.std(original, ddof=1)
    center = np.mean(original)
    filtData = (self.dataFlt * spread) + center

    if not self.plot:
        return filtData

    plt.figure(num=None,
               figsize=(20, 3),
               dpi=80,
               facecolor='w',
               edgecolor='k')

    top = plt.subplot(211)
    top.plot(self.time, original, color='k', label='origional')
    top.set_ylabel('dfof')
    top.legend()

    bottom = plt.subplot(212)
    bottom.plot(self.time, filtData, color='blue', label='filtered')
    residual = original - filtData
    bottom.plot(self.time, residual, color='orange', label='residual')
    bottom.set_ylabel('dfof')
    bottom.set_xlabel('time(s)')
    bottom.legend()

    plt.show()
    return filtData