# -*- coding: utf-8 -*-
"""
This module implements NLfp Class for NeuroChaT software.
@author: Md Nurul Islam; islammn at tcd dot ie
"""
import os
import re
import inspect
from functools import reduce
import logging
from collections import OrderedDict as oDict
from copy import deepcopy
from math import floor, ceil
from neurochat.nc_utils import window_rms
from neurochat.nc_utils import butter_filter
from neurochat.nc_utils import find_peaks
from neurochat.nc_utils import find_true_ranges
from neurochat.nc_utils import fft_psd
from neurochat.nc_utils import find
from neurochat.nc_circular import CircStat
from neurochat.nc_hdf import Nhdf
from neurochat.nc_base import NBase
import numpy as np
import scipy.stats as stats
import scipy.signal as sg
from scipy.fftpack import fft
[docs]class NLfp(NBase):
"""
This data class contains information about the neural LFP signal.
It decodes data from different formats and analyses LFP signal in
the recording.
Attributes
----------
_file_tag : str
The tag of the lfp data.
_channel_id : int
The id of the channel the lfp data came from.
_samples : np.ndarray
Array of lfp samples.
_timestamp : np.ndarray
Array of timestamps of the lfp samples.
"""
def __init__(self, **kwargs):
    """
    Create an empty LFP container.

    Parameters
    ----------
    **kwargs
        Keyword arguments forwarded unchanged to the parent initialiser

    Returns
    -------
    None

    """
    super().__init__(**kwargs)

    # LFP-specific state starts out blank until a loader fills it in.
    self._file_tag = ''
    self._channel_id = 0
    self._samples = None
    self._timestamp = None

    initial_counts = {'Total samples': 0, "No of samples": 0}
    self.set_record_info(initial_counts)

    self.__type = 'lfp'
[docs] def get_type(self):
"""
Return the type of object. For NLfp, this is always `lfp` type.
Parameters
----------
None
Returns
-------
str
"""
return self.__type
# For multi-unit analysis, {'SpikeName': cell_no} pairs should be used as
# function input
[docs] def set_channel_id(self, channel_id=''):
"""
Set the electrode channels ID.
Parameters
----------
channel_id : str
Channel ID for the LFP data
Returns
-------
None
"""
self._channel_id = channel_id
[docs] def get_channel_id(self):
"""
Return the electrode channels ID.
Parameters
----------
None
Returns
-------
str
LFP channel ID
"""
return self._channel_id
[docs] def set_file_tag(self, file_tag):
"""
Set the file tag or extension for the LFP dataset.
For example,
Axona recordings usually have file tags like 'eeg' or 'eeg8' etc.
Parameters
----------
file_tag : str
File tag or extension for the LFP dataset
Returns
-------
None
"""
self._file_tag = file_tag
[docs] def get_file_tag(self):
"""
Return the file tag or extension for the LFP dataset.
For example,
Axona recordings usually have file tags like 'eeg' or 'eeg8' etc.
Parameters
----------
None
Returns
-------
str
File tag or extension for the LFP dataset
"""
return self._file_tag
[docs] def get_timestamp(self):
"""
Return the timestamps of the LFP waveform.
Parameters
----------
None
Returns
-------
ndarray
Timestamps of the LFP signal
"""
return self._timestamp
def _set_timestamp(self, timestamp=None):
"""
Set the timestamps for LFP samples.
Parameters
----------
timestamp : list or ndarray
Timestamps of LFP samples
Returns
-------
None
"""
if timestamp is not None:
self._timestamp = timestamp
[docs] def get_samples(self):
"""
Return LFP waveform samples.
Parameters
----------
None
Returns
-------
ndarray
Samples of the LFP signal
"""
return self._samples
def _set_samples(self, samples=[]):
"""
Set LFP samples.
Parameters
----------
samples : list or ndarray
LFP samples
Returns
-------
None
"""
self._samples = samples
def _set_total_samples(self, tot_samples=0):
"""
Set the number of LFP samples.
This is performed as part of storing the recording information.
Parameters
----------
tot_samples : int
Total number of samples in the LFP signal
Returns
-------
None
"""
self._record_info['No of samples'] = tot_samples
def _set_total_channel(self, tot_channels):
"""
Set the value of the number of channels.
This is performed as part of storing the recording information.
Parameters
----------
tot_channels : int
Total number of channels
Returns
-------
None
"""
self._record_info['No of channels'] = tot_channels
def _set_timestamp_bytes(self, bytes_per_timestamp):
"""
Set `bytes per timestamp` value.
This is performed as part of storing the recording information.
Parameters
----------
bytes_per_timestamp : int
Total number of bytes to represent timestamp in the binary file
Returns
-------
None
"""
self._record_info['Bytes per timestamp'] = bytes_per_timestamp
def _set_sampling_rate(self, sampling_rate):
"""
Set the sampling rate of the LFP signal.
This is performed as part of storing the recording information.
Parameters
----------
sampling_rate : int
Sampling rate of the LFP waveform
Returns
-------
None
"""
self._record_info['Sampling rate'] = sampling_rate
def _set_bytes_per_sample(self, bytes_per_sample):
"""
Set `bytes per sample` value.
This is performed as part of storing the recording information.
Parameters
----------
bytes_per_sample : int
Total number of bytes to represent each sample in the binary file
Returns
-------
None
"""
self._record_info['Bytes per sample'] = bytes_per_sample
def _set_fullscale_mv(self, adc_fullscale_mv):
"""
Set fullscale value of ADC value in mV.
This is performed as part of storing the recording information.
Parameters
----------
adc_fullscale_mv : int
Fullscale voltage of ADC signal in mV
Returns
-------
None
"""
self._record_info['ADC Fullscale mv'] = adc_fullscale_mv
[docs] def get_total_samples(self):
"""
Return total number of LFP samples.
Parameters
----------
None
Returns
-------
ndarray
Total number of LFP samples
"""
return self._record_info['No of samples']
[docs] def get_total_channel(self):
"""
Return total number of electrode channels in the LFP data file.
Parameters
----------
None
Returns
-------
int
Total number of electrode channels
"""
return self._record_info['No of channels']
[docs] def get_timestamp_bytes(self):
"""
Return number of bytes to represent each timestamp in the binary file.
Parameters
----------
None
Returns
-------
int
Number of bytes to represent timestamps
"""
return self._record_info['Bytes per timestamp']
[docs] def get_sampling_rate(self):
"""
Return the sampling rate of spike waveforms.
Parameters
----------
None
Returns
-------
int
Sampling rate for spike waveforms
"""
return self._record_info['Sampling rate']
[docs] def get_bytes_per_sample(self):
"""
Return the number of bytes to represent each LFP waveform sample.
Parameters
----------
None
Returns
-------
int
Number of bytes to represent each sample of the LFP waveform
"""
return self._record_info['Bytes per sample']
[docs] def get_fullscale_mv(self):
"""
Return the fullscale value of the ADC in mV.
Parameters
----------
None
Returns
-------
int
Fullscale ADC value in mV
"""
return self._record_info['ADC Fullscale mv']
[docs] def get_recording_time(self):
"""
Return the recording time in seconds.
Parameters
----------
None
Returns
-------
int
Recording time in seconds
"""
return self.get_total_samples() / (self.get_sampling_rate())
def load(self, filename=None, system=None):
    """
    Load an LFP dataset with the loader matching the recording system.

    Arguments that are given are remembered on the object; arguments that
    are omitted fall back to the values already stored on it. The loader
    is looked up by name as ``load_lfp_<system>``.

    Parameters
    ----------
    filename : str
        Name of the LFP datafile
    system : str
        Recording system or format of the LFP data file

    Returns
    -------
    None

    See also
    --------
    load_lfp_Axona(), load_lfp_Neuralynx(), load_lfp_NWB()

    """
    if system is not None:
        self._system = system
    else:
        system = self._system

    if filename is not None:
        self._filename = filename
    else:
        filename = self._filename

    loader_name = 'load_lfp_' + system
    getattr(self, loader_name)(filename)
def add_spike(self, spike=None, **kwargs):
    """
    Add new spike node to current NLfp() object.

    The node class is taken from the ``cls`` keyword argument when it is a
    class; otherwise it is inferred from `spike` if that object reports
    the type 'spike'.

    Parameters
    ----------
    spike : NSpike
        NSpike object. If None, the class must be supplied through ``cls``
    **kwargs
        Keyword arguments, forwarded to the node-creation helper.
        ``cls`` may name the class of the node to create.

    Returns
    -------
    `:obj:NSpike()`
        A new NSpike() object, or None if no usable class could be found

    """
    cls = kwargs.get('cls', None)
    if not inspect.isclass(cls):
        try:
            if spike.get_type() == 'spike':
                cls = spike.__class__
        except Exception:
            # Only ordinary errors are treated as "unknown type"; a
            # KeyboardInterrupt or SystemExit must still propagate.
            logging.error('Data type cannot be determined!')

    if not inspect.isclass(cls):
        logging.error('Cannot add the spike data!')
        return None

    return self._add_node(cls, spike, 'spike', **kwargs)
[docs] def load_spike(self, names='all'):
"""
Load datasets of the spike nodes.
The name of each node is used for obtaining the filenames.
Parameters
----------
names : list of str
Names of the nodes to load. If None, current NSpike() object is loaded
Returns
-------
None
"""
if names == 'all':
for spike in self._spikes:
spike.load()
else:
logging.error("Spikes by name has yet to be implemented")
# for name in names:
# spike = self.get_spikes_by_name(name)
# spike.load()
def add_lfp(self, lfp=None, **kwargs):
    """
    Add new LFP node to current NLfp() object.

    The node is created with the same class as this object.

    Parameters
    ----------
    lfp : NLfp
        NLfp object. If None, new object is created
    **kwargs
        Keyword arguments forwarded to the node-creation helper

    Returns
    -------
    `:obj:Nlfp`
        A new NLfp() object

    """
    node_class = self.__class__
    return self._add_node(node_class, lfp, 'lfp', **kwargs)
[docs] def load_lfp(self, names=None):
"""
Load datasets of the LFP nodes.
The name of each node is used for obtaining the filenames.
Parameters
----------
names : list of str
Names of the nodes to load. If `all`, all LFP nodes are loaded
Returns
-------
None
"""
if names is None:
self.load()
elif names == 'all':
for lfp in self._lfp:
lfp.load()
else:
logging.error("Lfp by name has yet to be implemented")
def spectrum(self, **kwargs):
    """
    Analyse frequency spectrum of the LFP signal.

    Depending on ``tr`` either a spectrogram (time-resolved) or a Welch
    periodogram is computed with scipy.signal.

    Parameters
    ----------
    **kwargs
        Keyword arguments

        slice : slice, optional
            Selection applied to the samples before analysis
        window : float, int or ndarray
            Window length in seconds (a Hann window is built), or the
            window samples themselves. Default 1.0
        noverlap : float
            Overlap in seconds; reset to half the window length when it
            is not smaller than the window length
        nfft : int
            FFT length, rounded up to the next power of two. Default 2*Fs
        ptype : str
            'power' selects 'spectrum' scaling, anything else 'density'
        prefilt : bool
            Whether to filter the signal first. Default True
        filtset : list
            Arguments passed to ``butter_filter`` after the sampling rate
        fmax : float
            Highest frequency kept in the output. Default Fs/2
        tr : bool
            Time-resolved analysis (spectrogram). Default False
        db : bool
            Express power in dB relative to the maximum, floored at -40

    Returns
    -------
    dict
        Graphical data of the analysis: 'f' plus 'Pxx', or 't', 'f'
        and 'Sxx' for the time-resolved case

    """
    opt = kwargs.get
    graph_data = oDict()
    Fs = self.get_sampling_rate()

    slc = opt('slice', None)
    lfp = self.get_samples()[slc] if slc else self.get_samples()

    # A numeric window is a duration in seconds; turn it into a Hann window.
    window = opt('window', 1.0)
    if isinstance(window, (float, int)):
        window = sg.get_window('hann', int(window * Fs))
    win_sec = np.ceil(window.size / Fs)

    half_window = 0.5 * win_sec
    noverlap = opt('noverlap', half_window)
    if not noverlap < win_sec:
        noverlap = half_window
    noverlap = np.ceil(noverlap * Fs)

    requested_nfft = opt('nfft', 2 * Fs)
    nfft = np.power(2, int(np.ceil(np.log2(requested_nfft))))

    scaling = 'spectrum' if opt('ptype', 'psd') == 'power' else 'density'

    prefilt = opt('prefilt', True)
    filter_args = opt('filtset', [10, 1.5, 40, 'bandpass'])
    fmax = opt('fmax', Fs / 2)
    if prefilt:
        lfp = butter_filter(lfp, Fs, *filter_args)

    time_resolved = opt('tr', False)
    in_db = opt('db', False)

    estimator_args = dict(
        fs=Fs, window=window, nperseg=window.size, noverlap=noverlap,
        nfft=nfft, detrend='constant', return_onesided=True,
        scaling=scaling)

    if time_resolved:
        f, t, Sxx = sg.spectrogram(lfp, **estimator_args)
        graph_data['t'] = t + self.get_timestamp()[0]
        graph_data['f'] = f[find(f <= fmax)]
        if in_db:
            Sxx = 10 * np.log10(Sxx / np.amax(Sxx))
            # Floor on the flattened copy, then restore the 2-D layout.
            Sxx = Sxx.flatten()
            Sxx[find(Sxx < -40)] = -40
            Sxx = np.reshape(Sxx, [f.size, t.size])
        graph_data['Sxx'] = Sxx[find(f <= fmax), :]
    else:
        f, Pxx = sg.welch(lfp, **estimator_args)
        graph_data['f'] = f[find(f <= fmax)]
        if in_db:
            Pxx = 10 * np.log10(Pxx / Pxx.max())
            Pxx[find(Pxx < -40)] = -40
        graph_data['Pxx'] = Pxx[find(f <= fmax)]

    return graph_data
[docs] def phase_dist(self, event_stamp, **kwargs):
"""
Analysis of spike to LFP phase distribution.
The LFP is band-limited to `fwin`, its instantaneous phase is taken from
the Hilbert transform, and the phase at each event is obtained by linear
interpolation. Events are only counted inside segments of the signal that
pass an amplitude test (`aratio`) and a band-power test (`pratio`).
Circular statistics of the pooled phase histogram are stored through
`update_result`.
Parameters
----------
event_stamp : ndarray
Timestamps of the events of spiking activities for
measuring the phase distribution.
**kwargs
Keyword arguments
binsize : number
Bin width in degrees of the pooled phase histogram. Default 5
rbinsize : number
Bin width in degrees used for the per-segment raster. Default 2
fwin : list
[low, high] frequency band used for the phase. Default [6, 12]
pratio : float
Minimum fraction of segment power that must lie inside `fwin`.
Default 0.2
aratio : float
Minimum segment peak-to-peak amplitude as a fraction of the
peak-to-peak amplitude of the whole signal. Default 0.15
filtset : list
Arguments passed to `butter_filter` after the sampling rate for
the broad pre-filter. Default [10, 1.5, 40, 'bandpass']
Returns
-------
dict
Graphical data of the analysis with keys 'meanTheta', 'phCount',
'phBins', 'raster' and 'rasterbins'
"""
_results = oDict()
graph_data = oDict()
cs = CircStat()
# Samples are scaled by 1000 before any processing.
lfp = self.get_samples() * 1000
Fs = self.get_sampling_rate()
time = self.get_timestamp()
# Input parameters
bins = int(360 / kwargs.get('binsize', 5))
rbinsize = kwargs.get('rbinsize', 2) # raster binsize
rbins = int(360 / rbinsize)
fwin = kwargs.get('fwin', [6, 12])
pratio = kwargs.get('pratio', 0.2)
aratio = kwargs.get('aratio', 0.15)
# Filter: b_lfp is limited to the analysis band, lfp to the broad
# pre-filter band (the leading 5 / 10 are presumably filter orders —
# confirm against butter_filter).
fmax = fwin[1]
fmin = fwin[0]
_filter = [5, fmin, fmax, 'bandpass']
_prefilt = kwargs.get('filtset', [10, 1.5, 40, 'bandpass'])
b_lfp = butter_filter(lfp, Fs, *_filter) # band LFP
lfp = butter_filter(lfp, Fs, *_prefilt)
# Measure phase: angle of the analytic signal in degrees, wrapped to
# [0, 360), and its magnitude as the envelope.
hilb = sg.hilbert(b_lfp)
phase = np.angle(hilb, deg=True)
phase[phase < 0] = phase[phase < 0] + 360
mag = np.abs(hilb)
# Phase at each event by linear interpolation over the sample times.
ephase = np.interp(event_stamp, time, phase)
p2p = np.abs(np.max(lfp) - np.min(lfp))
xline = 0.5 * np.mean(mag) # cross line
# Detection algo
# Upward crossings of the envelope through xline (despite the original
# "zero cross" label the threshold is xline, not zero): either a direct
# step from below to above, or below -> exactly on the line -> above.
mag1 = mag[0:-3]
mag2 = mag[1:-2]
mag3 = mag[2:-1]
xind = np.union1d(
find(np.logical_and(mag1 < xline, mag2 > xline)),
find(np.logical_and(
np.logical_and(mag1 < xline, mag2 == xline),
mag3 > xline))
)
# Ignore segments <1/fmax
# NOTE(review): the loop below compares segment duration against
# 1 / fmin, not 1 / fmax — confirm which bound is intended.
i = 0
rcount = np.empty([0, ])
bcount = np.empty([0, 0])
phBins = np.arange(0, 360, 360 / bins)
# rbins is rebound here from a bin count to the array of bin starts.
rbins = np.arange(0, 360, 360 / rbins)
seg_count = 0
while i < len(xind) - 1:
# Extend the segment end k until the segment spans at least 1 / fmin
# in time or the last crossing is reached.
k = i + 1
while (
((time[xind[k]] - time[xind[i]]) < (1 / fmin))
and (k < len(xind) - 1)
):
k += 1
s_lfp = lfp[xind[i]: xind[k]]
s_p2p = np.abs(np.max(s_lfp) - np.min(s_lfp))
# Amplitude test: segment must be large relative to the whole signal.
if s_p2p >= aratio * p2p:
s_psd, f = fft_psd(s_lfp, Fs)
# Power test: enough of the segment power must lie inside fwin.
if np.sum(s_psd[np.logical_and(f >= fmin, f <= fmax)]
) > pratio * np.sum(s_psd):
# Phase distribution
s_phase = ephase[np.logical_and(
event_stamp > time[xind[i]], event_stamp <= time[xind[k]])]
if not s_phase.shape[0]:
pass
else:
seg_count += 1
cs.set_theta(s_phase)
temp_count = cs.circ_histogram(bins=rbinsize)
temp_count = temp_count[0]
# Raster counts are accumulated flat and reshaped after the loop.
if not rcount.size:
rcount = temp_count
else:
rcount = np.append(rcount, temp_count)
temp_count = np.histogram(
s_phase, bins=bins, range=[0, 360])
temp_count = np.resize(temp_count[0], [1, bins])
# One histogram row per accepted segment.
if not len(bcount):
bcount = temp_count
else:
bcount = np.append(bcount, temp_count, axis=0)
# The next segment starts where this one ended.
i = k
rcount = rcount.reshape([seg_count, rbins.size])
# Pool the per-segment histograms and compute circular statistics on
# the pooled counts.
phCount = np.sum(bcount, axis=0)
cs.set_rho(phCount)
cs.set_theta(phBins)
cs.calc_stat()
result = cs.get_result()
# Converted from degrees to radians for the graph data only.
meanTheta = result['meanTheta'] * np.pi / 180
_results['LFP Spike Mean Phase'] = result['meanTheta']
_results['LFP Spike Mean Phase Count'] = result['meanRho']
_results['LFP Spike Phase Res Vect'] = result['resultant']
graph_data['meanTheta'] = meanTheta
graph_data['phCount'] = phCount
graph_data['phBins'] = phBins
graph_data['raster'] = rcount
graph_data['rasterbins'] = rbins
self.update_result(_results)
return graph_data
def phase_at_events(self, event_stamps, **kwargs):
    """
    Phase of the band-limited LFP at the given event times.

    The samples are scaled by 1000, band-pass filtered to `fwin`, and the
    phase of the Hilbert transform (degrees, wrapped to [0, 360)) is
    linearly interpolated at each event time.

    Parameters
    ----------
    event_stamps : array
        an array of event times
    **kwargs:
        keyword arguments

        fwin : list
            [low, high] frequency band. Default [6, 12]

    Returns
    -------
    (array)
        Phase values for each event time

    """
    scaled_lfp = self.get_samples() * 1000
    rate = self.get_sampling_rate()
    sample_times = self.get_timestamp()

    band = kwargs.get('fwin', [6, 12])
    high_cut = band[1]
    low_cut = band[0]

    band_lfp = butter_filter(scaled_lfp, rate, 5, low_cut, high_cut, 'bandpass')

    analytic = sg.hilbert(band_lfp)
    phase = np.angle(analytic, deg=True)
    # np.angle yields (-180, 180]; shift the negative half up by a full turn.
    phase[phase < 0] = phase[phase < 0] + 360

    return np.interp(event_stamps, sample_times, phase)
[docs] def plv(self, event_stamp, **kwargs):
"""
Calculate phase-locking value of the spike train to underlying LFP signal.
When 'mode'= None in the input kwargs, it calculates the PLV and SFC over
the entire spike-train.
If 'mode'= 'bs', it bootstraps the spike-timestamps
and calculates the locking values for each set of new spike timestamps.
If 'mode'= 'tr', a time-resolved phase-locking analysis is performed where
the LFP signal is split into overlapped segments for each calculation.
Parameters
----------
evnet_stamp : ndarray
Timestamps of the events or the spiking activities for measuring the phase
locking
**kwargs
Keyword arguments
Returns
-------
dict
Graphical data of the analysis
"""
graph_data = oDict()
lfp = self.get_samples() * 1000
Fs = self.get_sampling_rate()
time = self.get_timestamp()
window = np.array(kwargs.get('window', [-0.5, 0.5]))
win = np.ceil(window * Fs).astype(int)
win = np.arange(win[0], win[1])
slep_win = sg.hann(win.size, False)
nfft = kwargs.get('nfft', 1024)
# None, 'bs', 'tr' bs=bootstrp, tr=time-resolved
mode = kwargs.get('mode', None)
fwin = kwargs.get('fwin', [])
xf = np.arange(0, Fs, Fs / nfft)
f = xf[0: int(nfft / 2) + 1]
ind = np.arange(f.size) if len(fwin) == 0 else find(
np.logical_and(f >= fwin[0], f <= fwin[1]))
if mode == 'bs':
nsample = kwargs.get('nsample', 50)
nrep = kwargs.get('nrep', 500)
STA = np.empty([nrep, win.size])
fSTA = np.empty([nrep, ind.size])
STP = np.empty([nrep, ind.size])
SFC = np.empty([nrep, ind.size])
PLV = np.empty([nrep, ind.size])
for i in np.arange(nrep):
data = self.plv(np.random.choice(event_stamp, nsample, False),
window=window, nfft=nfft, mode=None, fwin=fwin)
t = data['t']
STA[i, :] = data['STA']
fSTA[i, :] = data['fSTA']
STP[i, :] = data['STP']
SFC[i, :] = data['SFC']
PLV[i, :] = data['PLV']
graph_data['t'] = t
graph_data['f'] = f[ind]
graph_data['STAm'] = STA.mean(0)
graph_data['fSTAm'] = fSTA.mean(0)
graph_data['STPm'] = STP.mean(0)
graph_data['SFCm'] = SFC.mean(0)
graph_data['PLVm'] = PLV.mean(0)
graph_data['STAe'] = stats.sem(STA, 0)
graph_data['fSTAe'] = stats.sem(fSTA, 0)
graph_data['STPe'] = stats.sem(STP, 0)
graph_data['SFCe'] = stats.sem(SFC, 0)
graph_data['PLVe'] = stats.sem(PLV, 0)
elif mode == 'tr':
nsample = kwargs.get('nsample', None)
slide = kwargs.get('slide', 25) # in ms
slide = slide / 1000 # convert to sec
offset = np.arange(window[0], window[-1], slide)
nwin = offset.size
fSTA = np.empty([nwin, ind.size])
STP = np.empty([nwin, ind.size])
SFC = np.empty([nwin, ind.size])
PLV = np.empty([nwin, ind.size])
if nsample is None or nsample > event_stamp.size:
stamp = event_stamp
else:
stamp = np.random.choice(event_stamp, nsample, False)
for i in np.arange(nwin):
data = self.plv(stamp + offset[i],
nfft=nfft, mode=None, fwin=fwin, window=window)
t = data['t']
fSTA[i, :] = data['fSTA']
STP[i, :] = data['STP']
SFC[i, :] = data['SFC']
PLV[i, :] = data['PLV']
graph_data['offset'] = offset
graph_data['f'] = f[ind]
graph_data['fSTA'] = fSTA.transpose()
graph_data['STP'] = STP.transpose()
graph_data['SFC'] = SFC.transpose()
graph_data['PLV'] = PLV.transpose()
elif mode is None:
center = time.searchsorted(event_stamp)
# Keep windows within data
center = np.array([center[i] for i in range(0, len(event_stamp))
if center[i] + win[0] >= 0 and center[i] + win[-1] < time.size])
sta_data = self.event_trig_average(event_stamp, **kwargs)
STA = sta_data['ETA']
STA_detrended = STA - np.mean(STA)
detrended_slep_STA = np.multiply(STA_detrended, slep_win)
fSTA = fft(detrended_slep_STA, nfft)
fSTA = np.absolute(fSTA[0: int(nfft / 2) + 1])**2 / nfft**2
fSTA[1:-1] = 2 * fSTA[1:-1]
fLFP = []
for x in center:
lfp_sig = lfp[x + win]
lfp_sig_detrended = lfp_sig - np.mean(lfp_sig)
lfp_sig_slep = np.multiply(lfp_sig_detrended, slep_win)
fft_lfp = fft(lfp_sig_slep, nfft)
fLFP.append(fft_lfp)
fLFP = np.array(fLFP)
STP = np.absolute(fLFP[:, 0: int(nfft / 2) + 1])**2 / nfft**2
STP[:, 1:-1] = 2 * STP[:, 1:-1]
STP = STP.mean(0)
SFC = np.divide(fSTA, STP) * 100
PLV = np.copy(fLFP)
# Normalize
PLV = np.divide(PLV, np.absolute(PLV))
PLV[np.isnan(PLV)] = 0
PLV = np.absolute(PLV.mean(0))[0: int(nfft / 2) + 1]
PLV[1:-1] = 2 * PLV[1:-1]
graph_data['t'] = sta_data['t']
graph_data['f'] = f[ind]
graph_data['STA'] = STA
graph_data['fSTA'] = fSTA[ind]
graph_data['STP'] = STP[ind]
graph_data['SFC'] = SFC[ind]
graph_data['PLV'] = PLV[ind]
return graph_data
[docs] def event_trig_average(self, event_stamp=None, **kwargs):
"""
Averaging event-triggered LFP signals.
Parameters
----------
event_stamp : ndarray
Timestamps of the events or the spiking activities for measuring the
event triggered average of the LFP signal
**kwargs
Keywrod arguments
Returns
-------
dict
Graphical data of the analysis
"""
graph_data = oDict()
window = np.array(kwargs.get('window', [-0.5, 0.5]))
if event_stamp is None:
spike = kwargs.get('spike', None)
try:
data_type = spike.get_type()
except BaseException:
logging.error(
'The data type of the addes object cannot be determined!')
if data_type == 'spike':
event_stamp = spike.get_unit_stamp()
elif spike in self.get_spike_names():
event_stamp = self.get_spike(spike).get_unit_stamp()
if event_stamp is None:
logging.error('No valid event timestamp or spike is provided')
else:
lfp = self.get_samples() * 1000
Fs = self.get_sampling_rate()
time = self.get_timestamp()
center = time.searchsorted(event_stamp, side='left')
win = np.ceil(window * Fs).astype(int)
win = np.arange(win[0], win[1])
# Keep windows within data
center = np.array([center[i] for i in range(0, len(event_stamp))
if center[i] + win[0] >= 0 and center[i] + win[-1] < time.size])
eta = reduce(lambda y, x: y + lfp[x + win], center)
eta = eta / center.size
graph_data['t'] = win / Fs
graph_data['ETA'] = eta
graph_data['center'] = center
return graph_data
def spike_lfp_causality(self, spike=None, **kwargs):
    """
    Placeholder for the spike-to-LFP causality analysis (not implemented).

    Parameters
    ----------
    spike : NSpike
        Spike dataset which is used for the causality analysis
    **kwargs
        Keyword arguments

    Returns
    -------
    None
        The analysis is not implemented yet; once it is, graphical data
        of the analysis should be returned as a dict.

    """
    return None
def subsample(self, sample_range=None):
    """
    Extract a time range from the lfp.

    The range is converted to sample indices by multiplying with the
    sampling rate, so it is interpreted relative to the first sample.

    Parameters
    ----------
    sample_range : tuple
        the time in seconds to extract from the lfp as (lower, upper).
        If None, the object itself is returned (not a copy).

    Returns
    -------
    NLfp
        subsampled version of initial lfp object

    """
    in_range = sample_range
    sample_rate = self.get_sampling_rate()

    if in_range is None:
        length = int(self.get_duration() * sample_rate)
        recorded = self.get_total_samples()
        if length != recorded:
            # The two sentences were previously joined without any
            # separator ("...samplesCalculated ...").
            logging.warning(
                "Unequal calculated and recorded total lfp samples. "
                "Calculated {} and recorded {}".format(length, recorded))
        return self

    start = int(sample_rate * in_range[0])
    stop = int(sample_rate * in_range[1])

    new_lfp = deepcopy(self)
    lfp_samples = self.get_samples()[start:stop]
    lfp_times = self.get_timestamp()[start:stop]
    new_lfp._set_samples(lfp_samples)
    new_lfp._set_timestamp(lfp_times)
    new_lfp._set_total_samples(len(lfp_samples))
    new_lfp._set_duration(in_range[1] - in_range[0])
    return new_lfp
def sharp_wave_ripples(self, in_range=None, **kwargs):
    """
    Detect SWR events in the lfp, optionally in a given range.

    The signal is band-pass filtered (100 to 250 Hz by default), the
    root mean square envelope of the filtered signal is computed over a
    sliding window, and peaks of that envelope above the chosen
    percentile of the envelope are reported.

    Parameters
    ----------
    in_range : tuple
        A range in seconds

    kwargs
    ------
    swr_lower : float
        Lower band in hz
    swr_upper : float
        Upper band in hz
    rms_window_size_ms : int
        Size of the rms window in ms
    peak_percentile : float
        The percentile threshold for a peak

    Returns
    -------
    dict
        lfp times, lfp samples, swr times, lfp sample rate

    """
    band_low = kwargs.get("swr_lower", 100)
    band_high = kwargs.get("swr_upper", 250)
    window_ms = kwargs.get("rms_window_size_ms", 7)
    peak_percentile = kwargs.get("peak_percentile", 99.5)

    lfp = self.subsample(in_range)
    sample_rate = lfp.get_sampling_rate()

    # Estimate SWR events
    filtered_lfp = butter_filter(
        lfp.get_samples(), sample_rate, 10, band_low, band_high, 'bandpass')
    window_samples = floor((window_ms / 1000) * sample_rate)
    envelope = window_rms(filtered_lfp, window_samples, mode="same")
    threshold = np.percentile(envelope, peak_percentile)
    _, peaks = find_peaks(envelope, thresh=threshold)

    # Turn the peak positions into times by dividing by the sampling rate
    # and offsetting by the first timestamp.
    peaks = lfp.get_timestamp()[0] + (peaks / sample_rate)

    # Alternative way to get SWR
    # rms_envelope = distinct_window_rms(filtered_lfp, rms_window_size)
    # peaks = (
    #     longest_sleep_period[0] + peaks * rms_window_size) / sample_rate

    result = {}
    result["lfp times"] = lfp.get_timestamp()
    result["lfp samples"] = filtered_lfp
    result["swr times"] = peaks
    result["lfp sample rate"] = sample_rate
    return result
[docs] def bandpower(self, band, **kwargs):
"""
Compute the average power of the signal x in a specific frequency band.
Modified from excellent article at
https://raphaelvallat.com/bandpower.html
Parameters
----------
band : list
Lower and upper frequencies of the band of interest. [lower, upper].
Keyword Arguments
-----------------
method : string
Periodogram method.
Only 'welch' is currently supported.
window_sec : float
Length of each window in seconds.
If None, window_sec = (1 / min(band)) * 2.
band_total : bool
Whether to band the total power
Default False
total_band: List
low and high frequency values for the filter on total band.
Default [1.5, 90]
unit : str
Currently support micro and milli.
The scale to return Volts in.
Defaults to micro.
Returns
------
bp : Dict
"bandpower", "total_power" and "relative_power".
"""
from scipy.signal import welch
from scipy.integrate import simps
band = np.asarray(band)
low, high = band
method = kwargs.get("method", "welch")
window_sec = kwargs.get("window_sec", 2 / (low + 0.000001))
unit = kwargs.get("unit", "micro")
scale = 1000 if unit == "micro" else 1
sf = self.get_sampling_rate()
lfp_samples = self.get_samples() * scale
band_total = kwargs.get('band_total', False)
_filter = kwargs.get('total_band', [1.5, 90])
# if prefilt:
# lfp_samples = butter_filter(lfp_samples, sf, *_filter)
# Compute the modified periodogram (Welch)
if method == 'welch':
nperseg = int(window_sec * sf)
freqs, psd = welch(lfp_samples, sf, nperseg=nperseg)
# The multaper method is more accurate but we will not use it
# Welch's method is still very good
# See MNE for the multitaper method
# from mne.time_frequency import psd_array_multitaper
# elif method == 'multitaper':
# psd, freqs = psd_array_multitaper(lfp_samples, sf, adaptive=True,
# normalization='full', verbose=0)
# Frequency resolution
freq_res = freqs[1] - freqs[0]
# Find index of band in frequency vector
idx_band = np.logical_and(freqs >= low, freqs <= high)
# Integral approximation of the spectrum using parabola (Simpson's
# rule)
bp = simps(psd[idx_band], dx=freq_res)
if band_total:
idx_band = np.logical_and(
freqs >= _filter[0],
freqs <= _filter[1])
tp = simps(psd[idx_band], dx=freq_res)
else:
tp = simps(psd, dx=freq_res)
output = {
"bandpower": bp,
"total_power": tp,
"relative_power": bp / tp}
return output
def bandpower_ratio(self, first_band, second_band, win_sec, **kwargs):
    """
    Calculate the ratio of the power in two frequency bands.

    Both band powers are computed with `bandpower`; the individual,
    relative and total powers and the ratio are also stored through
    `update_result`.

    Note that common ranges are:
    delta (1.5–4 Hz), theta (5-11 Hz)

    Parameters
    ----------
    first_band : 1d array
        lower and upper bands
    second_band : 1d array
        lower and upper bands
    win_sec : float
        length of the windows to bin lfp into in seconds.
        recommend 4 for eg. Ignored when `window_sec` is passed as a
        keyword argument.

    Keyword Arguments
    -----------------
    first_name : str
        name of band 1, default "Band 1"
    second_name : str
        name of band 2, default "Band 2"

    Returns
    -------
    float
        the ratio between the power signals (first over second).

    See also
    --------
    nc_lfp.NLfp().bandpower()

    """
    first_label = kwargs.get("first_name", "Band 1")
    second_label = kwargs.get("second_name", "Band 2")

    # An explicit window_sec keyword takes precedence over win_sec.
    kwargs.setdefault("window_sec", win_sec)

    first = self.bandpower(first_band, **kwargs)
    second = self.bandpower(second_band, **kwargs)

    if first["total_power"] != second["total_power"]:
        logging.error(
            "Differing total power in lfp bandpower ratio calculations")

    ratio = first["bandpower"] / second["bandpower"]

    first_key = first_label + " Power"
    second_key = second_label + " Power"
    ratio_key = first_label + " " + second_label + " Power Ratio"

    _results = oDict([
        (first_key, first["bandpower"]),
        (first_key + " (Relative)", first["relative_power"]),
        (second_key, second["bandpower"]),
        (second_key + " (Relative)", second["relative_power"]),
        (ratio_key, ratio),
        ("Total Power", first["total_power"]),
    ])
    self.update_result(_results)
    return ratio
def save_to_hdf5(self, file_name=None, system=None):
    """
    Store NLfp() object to HDF5 file.

    When both `file_name` and `system` are given, the data is first
    (re)loaded from that file; a missing file is logged as an error.
    The object is saved afterwards in either case.

    Parameters
    ----------
    file_name : str
        Full file directory for the lfp data
    system : str
        Recording system or data format

    Returns
    -------
    None

    Also see
    --------
    nc_hdf.Nhdf().save_lfp()

    """
    hdf = Nhdf()

    reload_requested = bool(file_name and system)
    if reload_requested and os.path.exists(file_name):
        self.set_filename(file_name)
        self.set_system(system)
        self.load()
    elif reload_requested:
        logging.error('Specified file cannot be found!')

    hdf.save_lfp(lfp=self)
    hdf.close()
def load_lfp_NWB(self, file_name):
    """
    Decode LFP data from NWB (HDF5) file format.

    Parameters
    ----------
    file_name : str
        File location and HDF5 group path joined by '+', i.e.
        '<file>+<path>'. The path is tried as given and, failing that,
        below '/processing/Neural Continuous/LFP/'.

    Returns
    -------
    None

    """
    file_name, path = file_name.split('+')

    if not os.path.exists(file_name):
        logging.error(file_name + ' does not exist!')
        return

    hdf = Nhdf()
    hdf.set_filename(file_name)
    try:
        lfp_root = '/processing/Neural Continuous/LFP/'
        if path in hdf.f:
            g = hdf.f[path]
        elif lfp_root + path in hdf.f:
            path = lfp_root + path
            g = hdf.f[path]
        else:
            # Nothing to read: stop here instead of carrying on with an
            # undefined group.
            logging.error('Specified path does not exist!')
            return

        _record_info = dict(g.attrs.items())
        self.set_record_info(_record_info)

        self._set_samples(hdf.get_dataset(group=g, name='data'))
        self._set_timestamp(hdf.get_dataset(group=g, name='timestamps'))
        self._set_total_samples(
            hdf.get_dataset(group=g, name='num_samples'))
    finally:
        # Release the file handle on every exit path, errors included.
        hdf.close()
def load_lfp_Axona(self, file_name):
    """
    Decode LFP data from Axona file format.

    The text header of the LFP file is parsed for the recording
    information, channel and gain are read from the '.set' file that
    shares the base name of the LFP file, and the binary samples that
    follow the 'data_start' marker are decoded and scaled.

    Parameters
    ----------
    file_name : str
        Full file directory for the lfp data

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        If no 'data_start' marker is found within the first 5000 bytes.

    """
    file_directory, file_basename = os.path.split(file_name)
    file_tag, file_extension = os.path.splitext(file_basename)
    file_extension = file_extension[1:]
    set_file = os.path.join(file_directory, file_tag + '.set')

    self._set_data_source(file_name)
    self._set_source_format('Axona')

    if not os.path.isfile(file_name):
        logging.error(
            "No lfp file found for file {}".format(file_name))
        return

    with open(file_name, 'rb') as f:
        # ---- Text header -------------------------------------------
        while True:
            line = f.readline()
            try:
                line = line.decode('latin-1')
            except Exception:
                break

            if line == '':
                break
            if line.startswith('trial_date'):
                # Blank eeg file
                if line.strip() == "trial_date":
                    self._set_total_samples(0)
                    return
                self._set_date(
                    ' '.join(line.replace(',', ' ').split()[1:]))
            if line.startswith('trial_time'):
                self._set_time(line.split()[1])
            if line.startswith('experimenter'):
                self._set_experimenter(' '.join(line.split()[1:]))
            if line.startswith('comments'):
                self._set_comments(' '.join(line.split()[1:]))
            if line.startswith('duration'):
                self._set_duration(float(''.join(line.split()[1:])))
            if line.startswith('sw_version'):
                self._set_file_version(line.split()[1])
            if line.startswith('num_chans'):
                self._set_total_channel(int(''.join(line.split()[1:])))
            if line.startswith('sample_rate'):
                self._set_sampling_rate(
                    float(''.join(re.findall(r'\d+.\d+|\d+', line))))
            if line.startswith('bytes_per_sample'):
                self._set_bytes_per_sample(
                    int(''.join(line.split()[1:])))
            if line.startswith(
                    'num_' + file_extension[:3].upper() + '_samples'):
                self._set_total_samples(int(''.join(line.split()[1:])))
            if line.startswith("data_start"):
                break

        num_samples = self.get_total_samples()
        bytes_per_sample = self.get_bytes_per_sample()

        # ---- Locate the start of the binary data ---------------------
        # Slide a 10-byte window forward one byte at a time until it
        # reads 'data_start'.
        f.seek(0, 0)
        header_offset = None
        num_iters = 0
        while True:
            try:
                buff = f.read(10).decode('UTF-8')
                num_iters += 1
            except Exception:
                break
            if buff == 'data_start':
                header_offset = f.tell()
                break
            else:
                f.seek(-9, 1)
            if num_iters > 5000:
                raise RuntimeError(
                    "Failed load lfp from {} - no data_start.".format(
                        file_name))

        # The digits of the extension select the EEG entry (eeg -> 1).
        eeg_ID = re.findall(r'\d+', file_extension)
        self.set_file_tag(1 if not eeg_ID else int(eeg_ID[0]))

        max_ADC_count = 2**(8 * bytes_per_sample - 1) - 1
        max_byte_value = 2**(8 * bytes_per_sample)

        # ---- Channel, gain and ADC range from the .set file ----------
        with open(set_file, 'r', encoding='latin-1') as f_set:
            lines = f_set.readlines()

            # NOTE: the unescaped '.' in these two patterns matches the
            # blank between the index embedded in the key and the value,
            # which is what makes split() yield an (index, value) pair.
            channel_lines = dict(
                [tuple(map(int, re.findall(r'\d+.\d+|\d+', line)[0].split()))
                 for line in lines if line.startswith('EEG_ch_')]
            )
            channel_id = channel_lines[self.get_file_tag()]
            self.set_channel_id(channel_id)

            gain_lines = dict(
                [tuple(map(int, re.findall(r'\d+.\d+|\d+', line)[0].split()))
                 for line in lines if 'gain_ch_' in line]
            )
            gain = gain_lines[channel_id - 1]

            for line in lines:
                if line.startswith('ADC_fullscale_mv'):
                    # The fallback alternative was written 'd+' (a
                    # literal letter d), so short values never matched.
                    self._set_fullscale_mv(
                        int(re.findall(r'\d+.\d+|\d+', line)[0]))
                    break

            AD_bit_uvolt = 2 * self.get_fullscale_mv() / \
                (gain * np.power(2, 8 * bytes_per_sample))

        record_size = bytes_per_sample
        # Weight of each byte position when combining bytes into a sample.
        sample_le = 256**(np.arange(0, bytes_per_sample, 1))

        if header_offset is None:
            logging.error('Error: data_start marker not found!')
            return

        # ---- Decode the binary samples ------------------------------
        f.seek(header_offset, 0)
        byte_buffer = np.fromfile(f, dtype='uint8')
        len_bytebuffer = len(byte_buffer)
        end_offset = len('\r\ndata_end\r')

        lfp_wave = np.zeros([num_samples, ], dtype=np.float64)
        for k in np.arange(0, bytes_per_sample, 1):
            byte_offset = k
            sample_value = (
                sample_le[k] * byte_buffer[
                    byte_offset:
                    byte_offset + len_bytebuffer - end_offset - record_size:
                    record_size])
            # Pad with zeros if the file holds fewer samples than the
            # header announced.
            if sample_value.size < num_samples:
                sample_value = np.append(sample_value, np.zeros(
                    [num_samples - sample_value.size, ]))
            sample_value = sample_value.astype(
                np.float64, casting='unsafe', copy=False)
            np.add(lfp_wave, sample_value, out=lfp_wave)

        # Values above the positive ADC range are shifted down by the
        # full byte range (unsigned to signed).
        np.putmask(lfp_wave, lfp_wave > max_ADC_count,
                   lfp_wave - max_byte_value)

        self._set_samples(lfp_wave * AD_bit_uvolt)
        self._set_timestamp(
            np.arange(0, num_samples, 1) / self.get_sampling_rate())
def load_lfp_Neuralynx(self, file_name):
    """
    Decode LFP data from Neuralynx file format.

    The text header (the first 16 kB of the file) is scanned for the
    acquisition settings. Every fixed-size record after the header is then
    decoded, converted to signed sample values, scaled, and linearly
    interpolated onto a 250 Hz time base. The decoded signal, its
    timestamps (shifted to start at zero), the sample count and the
    resampled rate are stored through the object's setters.

    Parameters
    ----------
    file_name : str
        Full file directory for the lfp data

    Returns
    -------
    None

    Notes
    -----
    Records that report no valid samples, more samples than a record can
    hold, or a zero sampling frequency are skipped; their count is logged
    in a single warning. A file without any usable record yields empty
    sample and timestamp arrays.

    """
    self._set_data_source(file_name)
    self._set_source_format('Neuralynx')

    # Format description for the NLX file:
    # NeuroChaT subsamples the original recording from 32000 to 250
    resamp_freq = 250
    header_offset = 16 * 1024  # fixed for NLX files
    bytes_per_timestamp = 8
    bytes_chan_no = 4
    bytes_sample_freq = 4
    bytes_num_valid_samples = 4
    bytes_per_sample = 2
    samples_per_record = 512
    max_byte_value = np.power(2, bytes_per_sample * 8)
    max_ADC_count = np.power(2, bytes_per_sample * 8 - 1) - 1
    AD_bit_uvolt = 10**-6
    self._set_bytes_per_sample(bytes_per_sample)
    record_size = None

    # Signed decimal with optional fraction and exponent. The dot is
    # escaped so that it only ever matches a literal decimal point.
    number = re.compile(r'[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?')

    def header_number(text, key):
        """Return the first number that follows `key` in `text`, or None."""
        found = number.search(text.split(key, 1)[1])
        return float(found.group()) if found else None

    with open(file_name, 'rb') as f:
        # Only the header region is parsed as text; without the bound the
        # loop could keep reading "lines" out of the binary records.
        while f.tell() < header_offset:
            raw_line = f.readline()
            if not raw_line:
                break
            try:
                line = raw_line.decode('UTF-8')
            except UnicodeDecodeError:
                # Reached bytes that are not text: stop header parsing.
                break

            if 'SamplingFrequency' in line:
                # We are subsampling from the blocks of 512 samples per
                # record
                value = header_number(line, 'SamplingFrequency')
                if value is not None:
                    self._set_sampling_rate(value)
            if 'RecordSize' in line:
                value = header_number(line, 'RecordSize')
                if value is not None:
                    record_size = int(value)
            if 'Time Opened' in line:
                date_found = re.search(r'\d+/\d+/\d+', line)
                time_found = re.search(r'\d+:\d+:\d+', line)
                if date_found:
                    self._set_date(date_found.group())
                if time_found:
                    self._set_time(time_found.group())
            if 'FileVersion' in line:
                fields = line.split()
                if len(fields) > 1:
                    self._set_file_version(fields[1])
            if 'ADMaxValue' in line:
                value = header_number(line, 'ADMaxValue')
                if value is not None:
                    max_ADC_count = value
            if 'ADBitVolts' in line:
                value = header_number(line, 'ADBitVolts')
                if value is not None:
                    # Header value is multiplied by 10**6 (volts to
                    # microvolts, going by the variable name).
                    AD_bit_uvolt = value * (10**6)

        # gain = 1 assumed to keep in similarity to Axona
        self._set_fullscale_mv(max_byte_value * AD_bit_uvolt / 2)

        if not record_size:
            # Header gave no record size: derive it from the field layout.
            record_size = bytes_per_timestamp + \
                bytes_chan_no + \
                bytes_sample_freq + \
                bytes_num_valid_samples + \
                bytes_per_sample * samples_per_record

        # Byte offsets of the fields inside one record.
        time_offset = 0
        sample_freq_offset = bytes_per_timestamp + bytes_chan_no
        num_valid_samples_offset = sample_freq_offset + bytes_sample_freq
        sample_offset = num_valid_samples_offset + bytes_num_valid_samples
        # Largest number of samples that physically fits in a record.
        record_capacity = (record_size - sample_offset) // bytes_per_sample

        f.seek(0, 2)
        num_records = max(0, (f.tell() - header_offset) // record_size)
        f.seek(header_offset, 0)

        # Per-record results are gathered in lists and joined once at the
        # end; growing an array with np.append on every record re-copies
        # the whole signal each time.
        time_blocks = []
        wave_blocks = []
        skipped_records = 0
        # Little-endian place values used to assemble multi-byte samples.
        sample_le = 256**(np.arange(0, bytes_per_sample, 1))

        for _ in range(num_records):
            sample_bytes = np.fromfile(f, dtype='uint8', count=record_size)
            if sample_bytes.size < record_size:
                # Truncated final record.
                break

            # Stored timestamp is divided by 10**6 (presumably
            # microseconds to seconds).
            block_start = int.from_bytes(
                sample_bytes[time_offset:
                             time_offset + bytes_per_timestamp].tobytes(),
                byteorder='little', signed=False) / 10**6
            valid_samples = int.from_bytes(
                sample_bytes[num_valid_samples_offset:
                             num_valid_samples_offset +
                             bytes_num_valid_samples].tobytes(),
                byteorder='little', signed=False)
            sampling_freq = int.from_bytes(
                sample_bytes[sample_freq_offset:
                             sample_freq_offset +
                             bytes_sample_freq].tobytes(),
                byteorder='little', signed=False)

            if (valid_samples < 1 or valid_samples > record_capacity
                    or sampling_freq == 0):
                # Nothing usable here; decoding it would index past the
                # record or divide by zero.
                skipped_records += 1
                continue

            wave_bytes = sample_bytes[
                sample_offset:
                sample_offset + valid_samples * bytes_per_sample
            ].reshape([valid_samples, bytes_per_sample])
            block_wave = np.dot(wave_bytes, sample_le)
            # Values above the maximum ADC count wrap to negative numbers.
            np.putmask(block_wave, block_wave > max_ADC_count,
                       block_wave - max_byte_value)
            block_wave = block_wave * AD_bit_uvolt

            block_time = block_start + \
                np.arange(valid_samples) / sampling_freq
            interp_time = np.arange(
                block_start, block_time[-1], 1 / resamp_freq)
            interp_wave = np.interp(interp_time, block_time, block_wave)

            time_blocks.append(interp_time)
            wave_blocks.append(interp_wave)

    if skipped_records:
        logging.warning(
            "Skipped {} invalid record(s) while loading lfp from {}".format(
                skipped_records, file_name))

    if time_blocks:
        timestamps = np.concatenate(time_blocks)
        lfp_wave = np.concatenate(wave_blocks)
    else:
        timestamps = np.array([])
        lfp_wave = np.array([])
    if timestamps.size:
        # Make the timestamps relative to the earliest sample.
        timestamps -= timestamps.min()

    self._set_samples(lfp_wave)
    self._set_total_samples(lfp_wave.size)
    self._set_timestamp(timestamps)
    self._set_sampling_rate(resamp_freq)
def find_artf(self, sd_thresh=3, min_artf_freq=8):
    """
    Obtain locations of signal above threshold in windows.

    A sample counts as over threshold when it lies at or beyond
    ``mean +/- sd_thresh * std`` of the whole signal. The over-threshold
    mask is handed to ``find_true_ranges`` and the index locations it
    reports are post-processed: whenever two consecutive locations are at
    most ``ceil(Fs / min_artf_freq)`` samples apart, every index between
    them is marked as well.

    NOTE
    ----
    This function is still a work in progress and may see future changes.

    Parameters
    ----------
    sd_thresh: float
        threshold to exclude artf, in standard deviations from the mean
    min_artf_freq: float
        minimum artf freq - Used to locate artf blocks
        eg. 250 Hz sampling rate, 8Hz min_artf_freq: locations at most
        ceil(250 / 8) = 32 samples apart are joined into one block

    Returns
    -------
    tuple
        ``(mean, std, thr_locs, thr_vals, thr_time, per_removed)`` where
        ``mean`` and ``std`` describe the whole signal, ``thr_locs`` holds
        the marked sample indices, ``thr_vals`` and ``thr_time`` are the
        samples and timestamps at those indices (empty lists when nothing
        was found), and ``per_removed`` is ``len(thr_locs)`` as a
        percentage of the number of samples.

    """
    samples = self.get_samples()
    Fs = self.get_sampling_rate()
    std = np.std(samples)
    mean = np.mean(samples)

    upper_bound = mean + sd_thresh * std
    lower_bound = mean - sd_thresh * std
    over_thresh = np.logical_or(
        samples >= upper_bound, samples <= lower_bound)

    # use np.where on the logical or if not using find_true_ranges
    sample_indices = list(range(len(samples)))
    _, thr_locs = find_true_ranges(
        sample_indices, over_thresh, min_range=1, return_idxs=True)

    if len(thr_locs) > 0:
        marked = []
        for current, following in zip(thr_locs[:-1], thr_locs[1:]):
            # Set based on sampling freq/max freq of interest.
            if following - current <= ceil(Fs / min_artf_freq):
                # Close neighbours: mark the whole stretch between them.
                marked.extend(range(current, following))
            else:
                marked.append(current)
        # The last location is never a "current" above, so add it here.
        marked.append(thr_locs[-1])

        thr_locs = np.array(marked)
        thr_vals = self.get_samples()[thr_locs]
        thr_time = self.get_timestamp()[thr_locs]
    else:
        print("No artefacts found.")
        thr_vals = []
        thr_time = []

    per_removed = len(thr_locs) / len(samples) * 100
    return mean, std, thr_locs, thr_vals, thr_time, per_removed
def __str__(self):
    """
    Return a friendly string representation of the object.

    The text reports the file tag, the channel id, the sampling rate and
    the total number of samples currently held by the object.

    Returns
    -------
    str

    """
    template = "{} object with tag {} from channel {} at {}Hz with {} samples"
    fields = (
        "NeuroChaT NLfp",
        self._file_tag,
        self._channel_id,
        self.get_sampling_rate(),
        self.get_total_samples(),
    )
    return template.format(*fields)