Source code for neurochat.nc_spike

# -*- coding: utf-8 -*-
"""
This module implements NSpike Class for NeuroChaT software.

@author: Md Nurul Islam; islammn at tcd dot ie

"""

import os

import re

import logging
from collections import OrderedDict as oDict
from copy import deepcopy

import numpy as np

from neurochat.nc_utils import extrema, find, residual_stat

from neurochat.nc_hdf import Nhdf
from neurochat.nc_base import NBase

from scipy.optimize import curve_fit


class NSpike(NBase):
    """
    Data class holding the neural spikes of a recording.

    It decodes data from different formats and analyses single units in the
    recording.

    Attributes
    ----------
    _unit_no : int
        The current unit number being considered.
    _unit_stamp : np.ndarray
        The timestamps of the current unit.
    _timestamp : np.ndarray
        The timestamps of all units.
    _unit_list : list of int
        The list of all units available.
    _unit_Tags : list of int
        The tag of each timestamp, denoting which unit it represents.
        It must match the length of _timestamp.
    _waveform : OrderedDict
        Spike waveforms where each key represents one channel.
        The data in a channel is a numpy array.

    """

    def __init__(self, **kwargs):
        """
        Create an empty spike container.

        Parameters
        ----------
        **kwargs
            Passed unchanged to the parent initialiser. The optional
            ``unit_no`` entry (default 0) selects the initial unit.

        Returns
        -------
        None

        """
        super().__init__(**kwargs)
        self._unit_no = kwargs.get('unit_no', 0)
        self._unit_stamp = []
        self._timestamp = []
        self._unit_list = []
        self._unit_Tags = []
        # The waveform store is read with .items()/.keys() by
        # get_unit_waves() and subsample(), so it has to be a mapping even
        # before any data is loaded (it used to start as a plain list).
        self._waveform = oDict()

        self.set_record_info({'Timebase': 1,
                              'Samples per spike': 1,
                              'No of spikes': 0,
                              'Channel IDs': None})

        self.__type = 'spike'
def get_type(self):
    """
    Return the type tag of this data object.

    Parameters
    ----------
    None

    Returns
    -------
    str
        The value of the private ``__type`` attribute, which ``__init__``
        sets to ``'spike'``.

    """
    kind = self.__type
    return kind
def get_unit_tags(self):
    """
    Return the per-spike unit tags.

    Parameters
    ----------
    None

    Returns
    -------
    list or ndarray
        The stored ``_unit_Tags`` object itself (not a copy).

    """
    tags = self._unit_Tags
    return tags
def set_unit_tags(self, new_tags):
    """
    Replace the per-spike unit tags and refresh the unit list.

    Parameters
    ----------
    new_tags : list or ndarray
        One tag per spike; must have as many entries as ``_timestamp``.

    Returns
    -------
    None
        When the lengths differ, an error is logged and nothing is changed.

    """
    if len(new_tags) != len(self._timestamp):
        logging.error('No of tags spikes does not match with no of spikes')
        return

    self._unit_Tags = new_tags
    # The list of available units is derived from the tags.
    self._set_unit_list()
def get_unit_list(self):
    """
    Return the list of units found in the tags.

    Parameters
    ----------
    None

    Returns
    -------
    list
        The stored ``_unit_list`` object itself (not a copy).

    """
    units = self._unit_list
    return units
def _set_unit_list(self): """ Set the list of units from the list of unit tags. Parameters ---------- None Returns ------- None """ self._unit_list = list(map(int, set(self._unit_Tags))) if 0 in self._unit_list: self._unit_list.remove(0) try: self._unit_list = sorted(self._unit_list) except BaseException: self._unit_list = self._unit_list
def set_unit_no(self, unit_no=None, spike_name=None):
    """
    Set the unit number of the spike dataset to analyse.

    Parameters
    ----------
    unit_no : int or sequence of int
        A single unit number (Python or NumPy integer) for this object, or
        one unit number per child spike node.
    spike_name : str or list of str
        Names of the child spike nodes the sequence in ``unit_no`` applies
        to. If None, ``get_spike_names()`` is used.

    Returns
    -------
    None

    Notes
    -----
    If a single requested unit is not available, the first available unit
    is selected instead and an error is logged. If there are no units at
    all, or ``unit_no`` is None, only an error is logged.

    """
    # NumPy integers are accepted too; they used to fall through to the
    # sequence branch and fail on len().
    if isinstance(unit_no, (int, np.integer)):
        available = self.get_unit_list()
        if unit_no in available:
            self._unit_no = int(unit_no)
            self._set_unit_stamp()
        elif available:
            self._unit_no = available[0]
            self._set_unit_stamp()
            logging.error(
                f"Unit number {unit_no} "
                f"is not a valid unit on {self._filename} "
                f"with available units {self.get_unit_list()} "
                f"analysing unit {self._unit_no} instead.")
        else:
            # Nothing to fall back to; indexing [0] would raise IndexError.
            logging.error(
                f"Unit number {unit_no} cannot be set, "
                f"no units are available on {self._filename}.")
        return

    if unit_no is None:
        logging.error('No unit number specified!')
        return

    if spike_name is None:
        spike_name = self.get_spike_names()
    if len(unit_no) == len(spike_name):
        spikes = self.get_spike(spike_name)
        for child, num in zip(spikes, unit_no):
            # Units a child does not contain are silently skipped.
            if num in child.get_unit_list():
                child.set_unit_no(num)
    else:
        logging.error(
            'Unit no. to set are not as many as child spikes!')
def get_unit_no(self, spike_name=None):
    """
    Get the unit number currently set for analysis.

    Parameters
    ----------
    spike_name : str or list of str
        Names of child spike nodes, passed to ``get_spike()``. If None, the
        unit number of this object is returned.

    Returns
    -------
    int or list of int
        This object's unit number, or one unit number per child spike
        returned by ``get_spike(spike_name)``.

    """
    if spike_name is None:
        return self._unit_no
    return [child._unit_no for child in self.get_spike(spike_name)]
def get_timestamp(self, unit_no=None):
    """
    Return spike timestamps, optionally restricted to one unit.

    Parameters
    ----------
    unit_no : int
        Unit whose timestamps are wanted. If None, the timestamps of all
        spikes are returned.

    Returns
    -------
    ndarray or None
        All timestamps, or those whose tag equals ``unit_no``. None is
        returned (after a warning is logged) when ``unit_no`` is not in the
        unit list.

    """
    if unit_no is None:
        return self._timestamp

    if unit_no not in self._unit_list:
        logging.warning('Unit ' + str(unit_no) +
                        ' is not present in the spike data')
        return None

    belongs_to_unit = self._unit_Tags == unit_no
    return self._timestamp[belongs_to_unit]
def _set_timestamp(self, timestamp=None): """ Set the timestamps for all spiking waveforms in the recording. Parameters ---------- timestamp : list or ndarray Timestamps of all spiking waveforms Returns ------- None """ if timestamp is not None: self._timestamp = timestamp
def get_unit_stamp(self):
    """
    Get the timestamps of the unit currently set for analysis.

    Parameters
    ----------
    None

    Returns
    -------
    list or ndarray
        Whatever ``get_timestamp()`` returns for the current ``_unit_no``.

    """
    current_unit = self._unit_no
    return self.get_timestamp(current_unit)
def _set_unit_stamp(self): """ Set timestamps of the unit currently set to analyse. Parameters ---------- None Returns ------- int Unit or cell number set to analyse """ self._unit_stamp = self.get_unit_stamp()
def get_unit_spikes_count(self, unit_no=None):
    """
    Return the number of spikes of a unit.

    Parameters
    ----------
    unit_no : int
        Unit whose spikes are counted. If None, the currently set unit is
        used.

    Returns
    -------
    int or None
        Number of tags equal to ``unit_no``, or None when the unit is not
        in the unit list.

    """
    if unit_no is None:
        unit_no = self._unit_no
    if unit_no not in self._unit_list:
        return None
    # Count in NumPy rather than iterating the mask with the builtin sum().
    return int(np.count_nonzero(self._unit_Tags == unit_no))
def get_waveform(self):
    """
    Return the spike waveforms of all units.

    Parameters
    ----------
    None

    Returns
    -------
    OrderedDict
        The stored ``_waveform`` object itself (not a copy); keys represent
        the channels.

    """
    waveforms = self._waveform
    return waveforms
def _set_waveform(self, spike_waves=[]): """ Set spike waveform to the NSpike() object. Parameters ---------- spike_waves : OrderedDict Spike waveforms where each key represents one channel Returns ------- None """ if spike_waves: self._waveform = spike_waves
def get_unit_waves(self, unit_no=None):
    """
    Return the spike waveforms of one unit.

    Parameters
    ----------
    unit_no : int
        Unit whose waveforms are wanted. If None, the currently set unit is
        used.

    Returns
    -------
    OrderedDict
        For every channel, the rows of its waveform array whose tag equals
        the requested unit. Channel order is preserved.

    """
    target = self._unit_no if unit_no is None else unit_no
    # The same row selection applies to every channel.
    rows = self._unit_Tags == target
    return oDict(
        (chan, wave[rows, :]) for chan, wave in self._waveform.items())
def get_unit_stamps_in_ranges(self, ranges):
    """
    Return the current unit's timestamps that fall inside given ranges.

    Parameters
    ----------
    ranges : list
        A list of (lower, upper) tuples indicating the time ranges to keep.
        Should be specified in the same unit as the timestamps. Both bounds
        are inclusive.

    Returns
    -------
    list
        The timestamps lying in at least one range, in their original
        order.

    """
    def _inside_any(value):
        # Inclusive on both ends; the first matching range is enough.
        for lower, upper in ranges:
            if lower <= value <= upper:
                return True
        return False

    return [stamp for stamp in self.get_unit_stamp() if _inside_any(stamp)]
def subsample(self, sample_range=None):
    """
    Extract a time range from the spikes.

    Parameters
    ----------
    sample_range : tuple
        (lower, upper) time range to keep, in the unit of the timestamps.
        Both bounds are inclusive. If None, this object itself is returned.

    Returns
    -------
    NSpike
        A deep copy of this object restricted to the spikes in the range,
        with its duration set to ``upper - lower``.

    """
    if sample_range is None:
        return self

    new_spike = deepcopy(self)
    stamps = self.get_timestamp()
    lower, upper = sample_range
    in_range = (stamps <= upper) & (stamps >= lower)

    new_waveform = new_spike.get_waveform()
    for ch in new_waveform.keys():
        # Index rows with the boolean mask so each channel stays 2-D even
        # when exactly one spike is kept. The earlier tuple index followed
        # by squeeze() collapsed that case to a 1-D array.
        new_waveform[ch] = new_waveform[ch][in_range, :]

    new_spike._set_timestamp(stamps[in_range])
    new_spike.set_unit_tags(self.get_unit_tags()[in_range])
    new_spike._set_waveform(new_waveform)
    new_spike._set_duration(upper - lower)
    return new_spike
def shift_spike_times(self, n_shuffles, limit=None):
    """
    Randomly shift the spike times of the currently set unit.

    Every shuffle adds one random offset to all spike times. Times pushed
    above the recording duration or below zero are wrapped once by the
    duration, and each row is sorted.

    Parameters
    ----------
    n_shuffles : int
        The number of shifted copies to generate.
    limit : int
        Offsets are drawn uniformly from [-limit, limit]. If None, the
        recording duration is used as the limit.

    Returns
    -------
    np.ndarray
        The shifted spike times, shape (n_shuffles, n_spikes)

    """
    dur = self.get_duration()
    bound = dur if limit is None else limit
    offsets = np.random.uniform(low=-bound, high=bound, size=n_shuffles)
    ftimes = self.get_unit_stamp()

    shifted = np.zeros(shape=(n_shuffles, len(ftimes)), dtype=np.float64)
    # Each row is a view into `shifted`, so the in-place edits below land
    # in the output array.
    for row, offset in zip(shifted, offsets):
        row[:] = ftimes + offset
        # Wrapping up the time
        row[row > dur] -= dur
        row[row < 0] += dur
        row.sort()

    return shifted
def shuffle_spike_times(self, n_shuffles, limit=None):
    """
    Randomly jitter the spike times of the currently set unit.

    Unlike ``shift_spike_times``, every spike gets its own random offset.
    Times pushed above the recording duration or below zero are wrapped
    once by the duration, and each row is sorted.

    Parameters
    ----------
    n_shuffles : int
        The number of shuffled copies to generate.
    limit : int
        Offsets are drawn uniformly from [-limit, limit]. If None, the
        recording duration is used as the limit.

    Returns
    -------
    np.ndarray
        The shuffled spike times, shape (n_shuffles, n_spikes)

    Raises
    ------
    ValueError
        If, after wrapping, a row still starts below 0 or ends at or above
        the recording duration.

    """
    dur = self.get_duration()
    bound = dur if limit is None else limit
    ftimes = self.get_unit_stamp()
    n_spikes = len(ftimes)
    offsets = np.random.uniform(
        low=-bound, high=bound, size=(n_shuffles, n_spikes))

    shuffled = np.zeros(shape=(n_shuffles, n_spikes), dtype=np.float64)
    # Each row is a view into `shuffled`, so the in-place edits below land
    # in the output array.
    for row, jitter in zip(shuffled, offsets):
        row[:] = ftimes + jitter
        # Wrapping up the time
        row[row > dur] -= dur
        row[row < 0] += dur
        row.sort()
        if row[0] < 0:
            raise ValueError("Shuffled below 0")
        elif row[-1] >= dur:
            raise ValueError("Shuffled above recording length")

    return shuffled
def load(self, filename=None, system=None):
    """
    Load a spike dataset.

    The loader is looked up on this object by name, as
    ``'load_spike_' + system``, and called with the filename.

    Parameters
    ----------
    filename : str
        Name of the spike datafile. If given it is remembered in
        ``_filename``; otherwise the remembered filename is used.
    system : str
        Recording system or format of the spike data file. If given it is
        remembered in ``_system``; otherwise the remembered system is used.

    Returns
    -------
    None

    """
    if system is not None:
        self._system = system
    else:
        system = self._system

    if filename is not None:
        self._filename = filename
    else:
        filename = self._filename

    loader_name = 'load_spike_' + system
    getattr(self, loader_name)(filename)
def add_spike(self, spike=None, **kwargs):
    """
    Add a new spike node to the current NSpike() object.

    Delegates to ``_add_node`` with this object's own class and the node
    type ``'spike'``.

    Parameters
    ----------
    spike : NSpike
        NSpike object to add. Passed through to ``_add_node`` as is.
    **kwargs
        Keyword arguments forwarded to ``_add_node``.

    Returns
    -------
    `:obj:NSpike`
        Whatever ``_add_node`` returns.

    """
    node_cls = self.__class__
    return self._add_node(node_cls, spike, 'spike', **kwargs)
def load_spike(self, names=None):
    """
    Load datasets of the spike nodes.

    Parameters
    ----------
    names : list of str
        If None, the current NSpike() object is loaded. If ``'all'``, every
        node in ``_spikes`` is loaded. Any other value only logs an error,
        because loading by name is not implemented.

    Returns
    -------
    None

    """
    if names is None:
        self.load()
        return

    if names == 'all':
        for child in self._spikes:
            child.load()
        return

    # TODO: load the individual nodes once a lookup by name exists.
    logging.error("Spikes by name has yet to be implemented")
def add_lfp(self, lfp=None, **kwargs):
    """
    Add a new LFP node to the current NSpike() object.

    Parameters
    ----------
    lfp : NLfp
        NLfp object to add. Its class is passed to ``_add_node`` only when
        ``lfp.get_type()`` returns ``'lfp'``; otherwise None is passed as
        the class.
    **kwargs
        Keyword arguments forwarded to ``_add_node``.

    Returns
    -------
    `:obj:Nlfp`
        Whatever ``_add_node`` returns.

    """
    data_type = None
    if lfp is not None:
        try:
            data_type = lfp.get_type()
        except Exception:
            # Keep going with an unknown type instead of crashing on the
            # unbound name later.
            logging.error(
                'The data type of the added object cannot be determined!')

    # `__class__` was previously misspelt with three leading underscores,
    # which raised AttributeError for every valid LFP object.
    cls = lfp.__class__ if data_type == 'lfp' else None
    # NOTE(review): what _add_node does with cls=None is not visible here.
    return self._add_node(cls, lfp, 'lfp', **kwargs)
def load_lfp(self, names='all'):
    """
    Load datasets of the LFP nodes.

    Parameters
    ----------
    names : list of str
        If ``'all'``, every node in ``_lfp`` is loaded. Any other value
        only logs an error, because loading by name is not implemented.

    Returns
    -------
    None

    """
    if names != 'all':
        # TODO: load the individual nodes once a lookup by name exists.
        logging.error("Lfp by name has yet to be implemented")
        return

    for node in self._lfp:
        node.load()
def wave_property(self):
    """
    Calculate different waveform properties for currently set unit.

    For every channel the mean and standard-deviation waveform are
    computed, and for every spike a peak, the preceding trough, the width
    around the peak, the amplitude (peak minus trough) and the height
    (peak minus waveform minimum). Per-spike values are reported for the
    channel with the largest mean amplitude. Summary statistics are passed
    to ``update_result``.

    Parameters
    ----------
    None

    Returns
    -------
    dict
        Graphical data of the analysis, with keys 'Mean wave', 'Std wave',
        'Amplitude', 'Width', 'Height' and 'Max channel'.

    """
    # NOTE(review): the indentation of this method was lost in the source
    # this file was recovered from; the nesting below is reconstructed from
    # the logic (defaults set before the `if`, names defined inside it) and
    # should be confirmed against the upstream file.
    _result = oDict()
    graph_data = {}

    def argpeak(data):
        # First index j >= 7 where `data` goes from positive at j - 1 to
        # non-positive at j; 0 when there is none. It is called on the
        # waveform slope, so this is where the slope stops being positive.
        # NOTE(review): the reason for starting at sample 7 is not visible
        # here.
        data = np.array(data)
        peak_loc = [j for j in range(7, len(data))
                    if data[j] <= 0 and data[j - 1] > 0]
        return peak_loc[0] if peak_loc else 0

    def argtrough1(data, peak_loc):
        # Walk backwards from the peak and return the first index where
        # `data` is non-negative while the sample before it is
        # non-positive; 0 when there is none.
        data = data.tolist()
        trough_loc = [peak_loc - j for j in range(peak_loc - 2)
                      if data[peak_loc - j] >= 0
                      and data[peak_loc - j - 1] <= 0]
        return trough_loc[0] if trough_loc else 0

    def wave_width(wave, peak, thresh=0.25):
        # Number of samples between the last point before the peak and the
        # first point from the peak onwards where the waveform is at or
        # below `thresh` times the peak value.
        p_loc, p_val = peak
        Len = wave.size
        if p_loc:
            w_start = find(wave[:p_loc] <= thresh * p_val, 1, 'last')
            w_start = w_start[0] if w_start.size else 0
            w_end = find(wave[p_loc:] <= thresh * p_val, 1, 'first')
            w_end = p_loc + w_end[0] if w_end.size else Len
        else:
            # A peak location of 0 is what argpeak returns when it finds
            # no peak; the width then spans almost the whole waveform.
            w_start = 1
            w_end = Len

        return w_end - w_start

    num_spikes = self.get_unit_spikes_count()
    _result['Number of Spikes'] = num_spikes
    _result['Mean Spiking Freq'] = num_spikes / self.get_duration()

    _waves = self.get_unit_waves()
    samples_per_spike = self.get_samples_per_spike()
    tot_chans = self.get_total_channels()
    # np.empty leaves these uninitialised; every column is expected to be
    # filled in the channel loop below.
    meanWave = np.empty([samples_per_spike, tot_chans])
    stdWave = np.empty([samples_per_spike, tot_chans])

    width = np.empty([num_spikes, tot_chans])
    amp = np.empty([num_spikes, tot_chans])
    height = np.empty([num_spikes, tot_chans])
    for i, (chan, wave) in enumerate(_waves.items()):
        # A single spike is differentiated on its own and wrapped back
        # into a 2-D array so `slope` always has one row per spike.
        if (wave.shape[0] == 1):
            slope = np.array([(np.gradient(wave[0]))])
        else:
            slope = np.gradient(wave, axis=1)
        meanWave[:, i] = np.mean(wave, 0)
        stdWave[:, i] = np.std(wave, 0)

        max_val = wave.max(1)

        # Defaults used when no spike on this channel rises above zero.
        peak_val, trough1_val = 0, 0
        if max_val.max() > 0:
            peak_loc = [argpeak(slope[I, :]) for I in range(num_spikes)]
            peak_val = [wave[I, peak_loc[I]] for I in range(num_spikes)]
            trough1_loc = [argtrough1(slope[I, :], peak_loc[I])
                           for I in range(num_spikes)]
            trough1_val = [wave[I, trough1_loc[I]]
                           for I in range(num_spikes)]

            peak_loc = np.array(peak_loc)
            peak_val = np.array(peak_val)
            trough1_loc = np.array(trough1_loc)
            trough1_val = np.array(trough1_val)
            width[:, i] = np.array([wave_width(wave[I, :], (peak_loc[I],
                                                            peak_val[I]), 0.25) for I in range(num_spikes)])

        amp[:, i] = peak_val - trough1_val
        height[:, i] = peak_val - wave.min(1)

    # Report per-spike values for the channel with the largest mean
    # amplitude.
    max_chan = amp.mean(0).argmax()
    # Width is converted from samples with 10**6 / sampling rate
    # (presumably microseconds if the sampling rate is in Hz - confirm).
    width = width[:, max_chan] * 10**6 / self.get_sampling_rate()
    amp = amp[:, max_chan]
    height = height[:, max_chan]

    graph_data = {'Mean wave': meanWave, 'Std wave': stdWave,
                  'Amplitude': amp, 'Width': width, 'Height': height,
                  'Max channel': max_chan}

    _result.update({'Mean amplitude': amp.mean(), 'Std amplitude': amp.std(),
                    'Mean height': height.mean(), 'Std height': height.std(),
                    'Mean width': width.mean(), 'Std width': width.std()})

    self.update_result(_result)
    return graph_data
def isi(self, bins='auto', bound=None, density=False,
        refractory_threshold=2):
    """
    Calculate the inter-spike-interval (ISI) histogram of the spike train.

    Intervals are the successive differences of the current unit's
    timestamps multiplied by 1000 (milliseconds, assuming the timestamps
    are in seconds). Summary statistics are passed to ``update_result``.

    Parameters
    ----------
    bins : str or int
        Passed to ``np.histogram`` as ``bins``.
    bound : int
        Passed to ``np.histogram`` as ``range``.
    density : bool
        If true, normalized histogram is calculated
    refractory_threshold : int
        Intervals strictly below this value count as refractory violations.

    Returns
    -------
    dict
        Graphical data of the analysis, with keys 'isiHist', 'isiBins'
        (left bin edges), 'isi', 'maxCount', 'isiBefore' and 'isiAfter'.

    """
    stamps = self.get_unit_stamp()
    intervals = 1000 * np.diff(stamps)
    n_violations = intervals[intervals < refractory_threshold].size

    counts, edges = np.histogram(
        intervals, bins=bins, range=bound, density=density)
    graph_data = oDict([
        ('isiHist', counts),
        ('isiBins', edges[:-1]),
        ('isi', intervals),
        ('maxCount', counts.max()),
        # Each interval paired with its successor.
        ('isiBefore', intervals[:-1]),
        ('isiAfter', intervals[1:]),
    ])

    mean_isi = intervals.mean()
    std_isi = intervals.std()
    log_intervals = np.log10(intervals)
    _results = oDict([
        ("Mean ISI", mean_isi),
        ("Median ISI", np.median(intervals)),
        ("Std ISI", std_isi),
        ("CV ISI", std_isi / mean_isi),
        ("Mean Log ISI", log_intervals.mean()),
        ("Std Log ISI", log_intervals.std()),
        # Violations are normalised by the number of spikes, not intervals.
        ("Refractory violation", n_violations / stamps.size),
    ])

    self.update_result(_results)
    return graph_data
def isi_corr(self, spike=None, **kwargs):
    """
    Calculate the correlation of ISI histogram.

    The current unit's spike train is histogrammed around reference
    timestamps with ``psth()``. The count in the zero-lag bin is then
    reduced by the size of the smaller of the two spike trains.

    Parameters
    ----------
    spike : NSpike, int or str
        Source of the reference timestamps. If None, the current unit
        itself is used. An int selects another unit of this object. A str
        is resolved with ``get_spike()`` first; an object of this class
        contributes its currently set unit.
    **kwargs
        Keyword arguments forwarded to ``psth()``.

    Returns
    -------
    dict
        Graphical data of the analysis, with keys 'isiCorrBins',
        'isiAllCorrBins' and 'isiCorr'.

    Raises
    ------
    ValueError
        If ``spike`` does not identify a usable spike train. This case
        previously failed with an unbound-variable error.

    """
    graph_data = oDict()
    _unit_stamp = None

    if spike is None:
        _unit_stamp = np.copy(self.get_unit_stamp())
    elif isinstance(spike, (int, np.integer)):
        if spike in self.get_unit_list():
            _unit_stamp = self.get_timestamp(spike)
    else:
        if isinstance(spike, str):
            spike = self.get_spike(spike)
        if isinstance(spike, self.__class__):
            _unit_stamp = spike.get_unit_stamp()

    if _unit_stamp is None:
        logging.error('No valid spike specified')
        raise ValueError('No valid spike specified')

    _corr = self.psth(_unit_stamp, **kwargs)
    graph_data['isiCorrBins'] = _corr['bins']
    graph_data['isiAllCorrBins'] = _corr['all_bins']
    center = find(_corr['bins'] == 0, 1, 'first')[0]
    graph_data['isiCorr'] = _corr['psth']
    # Reduce the zero-lag bin by the size of the smaller spike train
    # (presumably to discount every spike matching itself - confirm).
    graph_data['isiCorr'][center] = graph_data['isiCorr'][center] \
        - np.min([self.get_unit_stamp().size, _unit_stamp.size])

    return graph_data
def psth(self, event_stamp, **kwargs):
    """
    Calculate peri-stimulus time histogram (PSTH).

    For every event, the current unit's spike times relative to that event
    are histogrammed; the counts are summed over all events.

    Parameters
    ----------
    event_stamp : ndarray
        Event timestamps, in the same unit as the spike timestamps.
    **kwargs
        bins : int or array-like
            Bin width in msec (default 1), or explicit bin edges in msec.
        bound : sequence of two numbers
            Lower and upper limit of the histogram in msec (default
            [-500, 500]). Only used when ``bins`` is an int.

    Returns
    -------
    dict
        Graphical data of the analysis: 'psth' (summed counts), 'bins'
        (left bin edges in msec) and 'all_bins' (all bin edges in msec).

    """
    graph_data = oDict()
    bins = kwargs.get('bins', 1)
    if isinstance(bins, int):
        bound = np.array(kwargs.get('bound', [-500, 500]))
        # Build the negative and positive halves separately so 0 is always
        # an edge.
        bins = np.hstack(
            (np.arange(bound[0], 0, bins),
             np.arange(0, bound[1] + bins, bins)))
    bins = np.asarray(bins) / 1000  # converted to sec
    n_bins = len(bins) - 1

    hist_count = np.zeros([n_bins, ])
    # With explicit edges np.histogram hands back those same edges, so they
    # are known even when there are no events to loop over.
    edges = bins
    unitStamp = self.get_unit_stamp()
    for event in np.atleast_1d(event_stamp):
        tmp_count, edges = np.histogram(unitStamp - event, bins=bins)
        hist_count = hist_count + tmp_count

    graph_data['psth'] = hist_count
    graph_data['bins'] = 1000 * edges[:-1]
    # Included in case the last point is needed
    graph_data['all_bins'] = 1000 * edges

    return graph_data
def burst(self, burst_thresh=5, ibi_thresh=50):
    """
    Analysis of bursting properties of the spiking train.

    Inter-spike intervals (ISI) are the successive differences of the
    current unit's timestamps multiplied by 1000 (milliseconds, assuming
    the timestamps are in seconds). A burst is a run of consecutive ISIs
    that are all at or below ``burst_thresh``. The statistics are passed
    to ``update_result``.

    Parameters
    ----------
    burst_thresh : int
        Largest ISI, in the unit of the scaled intervals, for two
        consecutive spikes to belong to the same burst.
    ibi_thresh : int
        Inter-burst interval threshold. NOTE(review): this parameter is
        not used anywhere in the method body.

    Returns
    -------
    None

    """
    _results = oDict()

    unitStamp = self.get_unit_stamp()
    isi = 1000 * np.diff(unitStamp)

    burst_start = []
    burst_end = []
    burst_duration = []
    spikesInBurst = []
    bursting_isi = []
    num_burst = 0
    ibi = []
    duty_cycle = []
    k = 0
    # isi[k] is the interval between spike k and spike k + 1, so ISI
    # indices double as spike indices below.
    while k < isi.size:
        if isi[k] <= burst_thresh:
            # A short ISI opens a burst made of spikes k and k + 1.
            burst_start.append(k)
            spikesInBurst.append(2)
            bursting_isi.append(isi[k])
            burst_duration.append(isi[k])
            m = k + 1
            # Every further short ISI adds one more spike to the burst.
            while m < isi.size and isi[m] <= burst_thresh:
                spikesInBurst[num_burst] += 1
                bursting_isi.append(isi[m])
                burst_duration[num_burst] += isi[m]
                m += 1
            # to compensate for the span of the last spike
            burst_duration[num_burst] += 1
            # m is the index of the last spike of the burst.
            burst_end.append(m)
            # isi[m] is either past the end or already known to exceed the
            # threshold, so the scan resumes after it.
            k = m + 1
            num_burst += 1
        else:
            k += 1
    if num_burst:
        # Inter-burst interval: from the last spike of one burst to the
        # first spike of the next.
        for j in range(0, num_burst - 1):
            ibi.append(unitStamp[burst_start[j + 1]] -
                       unitStamp[burst_end[j]])
        # ibi in sec, burst_duration in ms
        duty_cycle = np.divide(burst_duration[1:], ibi) / 1000
    else:
        logging.warning(
            'No burst detected in {}'.format(self.get_filename()))

    spikesInBurst = np.array(
        spikesInBurst) if spikesInBurst else np.array([])
    bursting_isi = np.array(bursting_isi) if bursting_isi else np.array([])
    # in sec unit, so converted to ms
    ibi = 1000 * np.array(ibi) if ibi else np.array([])
    burst_duration = np.array(
        burst_duration) if burst_duration else np.array([])
    duty_cycle = np.array(duty_cycle) if len(duty_cycle) else np.array([])

    # Each statistic below is None unless its array holds at least one
    # non-zero value (`.any()`), which also covers the empty case.
    _results['Total burst'] = num_burst
    _results['Total bursting spikes'] = spikesInBurst.sum()
    _results['Mean bursting ISI ms'] = bursting_isi.mean(
    ) if bursting_isi.any() else None
    _results['Std bursting ISI ms'] = bursting_isi.std(
    ) if bursting_isi.any() else None
    _results['Mean spikes per burst'] = spikesInBurst.mean(
    ) if spikesInBurst.any() else None
    _results['Std spikes per burst'] = spikesInBurst.std(
    ) if spikesInBurst.any() else None
    _results['Mean burst duration ms'] = burst_duration.mean(
    ) if burst_duration.any() else None
    _results['Std burst duration'] = burst_duration.std(
    ) if burst_duration.any() else None
    _results['Mean duty cycle'] = duty_cycle.mean(
    ) if duty_cycle.any() else None
    _results['Std duty cycle'] = duty_cycle.std(
    ) if duty_cycle.any() else None
    _results['Mean IBI'] = ibi.mean() if ibi.any() else None
    _results['Std IBI'] = ibi.std() if ibi.any() else None
    # Fraction of all spikes of the unit that lie inside a burst.
    _results['Propensity to burst'] = spikesInBurst.sum() / unitStamp.size

    self.update_result(_results)
def theta_index(self, **kwargs):
    """
    Analysis of theta-modulation of a unit.

    The non-negative-lag half of the ISI autocorrelation (from
    ``isi_corr``) is fitted with a damped cosine plus a constant and a
    Gaussian term (single-exponent dip model). The fit is mirrored onto
    the negative lags. Results are passed to ``update_result``.

    Parameters
    ----------
    **kwargs
        start : list
            Initial values for frequency, tau1 and tau2
            (default [6, 0.1, 0.05]).
        lower : list
            Lower bounds for the same three parameters (default [4, 0, 0]).
        upper : list
            Upper bounds for the same three parameters
            (default [14, 5, 0.1]).
        All keyword arguments are also forwarded to ``isi_corr()``.

    Returns
    -------
    dict or None
        Graphical data of the analysis, with the fitted curve added under
        'corrFit'. None if the curve fit fails; the result entries are
        then set to None.

    """
    p_0 = kwargs.get('start', [6, 0.1, 0.05])
    lb = kwargs.get('lower', [4, 0, 0])
    ub = kwargs.get('upper', [14, 5, 0.1])

    _results = oDict()
    graph_data = self.isi_corr(**kwargs)
    corrBins = graph_data['isiCorrBins']
    corrCount = graph_data['isiCorr']
    m = corrCount.max()
    center = find(corrBins == 0, 1, 'first')[0]
    # Fit only the lags from zero upwards; bins are in msec, the model
    # works on the bins divided by 1000.
    x = corrBins[center:] / 1000
    y = corrCount[center:]
    y_fit = np.empty([corrBins.size, ])

    # Single-exponent dip model. (A double-exponent dip variant exists as
    # an alternative model; it is not used here.)
    def fit_func(x, a, f, tau1, b, c, tau2):
        return (
            a * np.cos(2 * np.pi * f * x) * np.exp(-np.abs(x) / tau1) +
            b + c * np.exp(-(x / tau2)**2))

    initial_guess = [m, p_0[0], p_0[1], m, m, p_0[2]]
    lower_bounds = [0, lb[0], lb[1], 0, -m, lb[2]]
    upper_bounds = [m, ub[0], ub[1], m, m, ub[2]]
    try:
        popt, pcov = curve_fit(
            fit_func, x, y,
            p0=initial_guess,
            bounds=(lower_bounds, upper_bounds),
            max_nfev=100000)
    except Exception as e:
        logging.error("Failed curve_fit in theta_index: {} ".format(e))
        for key in ('Theta Index', 'TI fit freq Hz', 'TI fit tau1 sec',
                    'TI adj Rsq', 'TI Pearse R', 'TI Pearse P'):
            _results[key] = None
        self.update_result(_results)
        return None

    a, f, tau1, b, c, tau2 = popt

    right_half = fit_func(x, *popt)
    y_fit[center:] = right_half
    # Mirror the fitted half onto the negative lags.
    y_fit[:center] = np.flipud(y_fit[center:])

    gof = residual_stat(y, y_fit[center:], 6)

    graph_data['corrFit'] = y_fit
    _results['Theta Index'] = a / b
    _results['TI fit freq Hz'] = f
    _results['TI fit tau1 sec'] = tau1
    _results['TI adj Rsq'] = gof['adj Rsq']
    _results['TI Pearse R'] = gof['Pearson R']
    _results['TI Pearse P'] = gof['Pearson P']

    self.update_result(_results)
    return graph_data
def theta_skip_index(self, **kwargs):
    """
    Analysis of theta-skipping of a unit.

    The non-negative-lag half of the ISI autocorrelation (from
    ``isi_corr``) is fitted with two damped cosines plus a constant and
    two exponential terms (double-exponent dip model). The second cosine's
    start value and bounds are half those of the first. The skip index
    compares the first two peaks of the fitted curve at lags of 50 msec
    or more. Results are passed to ``update_result``.

    Parameters
    ----------
    **kwargs
        start : list
            Initial values for frequency, tau1 and tau2
            (default [6, 0.1, 0.05]).
        lower : list
            Lower bounds for the same three parameters (default [4, 0, 0]).
        upper : list
            Upper bounds for the same three parameters
            (default [14, 5, 0.1]).
        All keyword arguments are also forwarded to ``isi_corr()``.

    Returns
    -------
    dict or None
        Graphical data of the analysis, with the fitted curve added under
        'corrFit'. None if the curve fit fails; the result entries are
        then set to None, as in ``theta_index()``.

    """
    p_0 = kwargs.get('start', [6, 0.1, 0.05])
    lb = kwargs.get('lower', [4, 0, 0])
    ub = kwargs.get('upper', [14, 5, 0.1])

    _results = oDict()
    graph_data = self.isi_corr(**kwargs)
    corrBins = graph_data['isiCorrBins']
    corrCount = graph_data['isiCorr']
    m = corrCount.max()
    center = find(corrBins == 0, 1, 'first')[0]
    # Fit only the lags from zero upwards; bins are in msec, the model
    # works on the bins divided by 1000.
    x = corrBins[center:] / 1000
    y = corrCount[center:]
    y_fit = np.empty([corrBins.size, ])

    # Double-exponent dip model. (A single-exponent dip variant exists as
    # an alternative model; it is not used here.)
    def fit_func(x, a1, f1, a2, f2, tau1, b, c1, tau2, c2, tau3):
        return (
            (a1 * np.cos(2 * np.pi * f1 * x) +
             a2 * np.cos(2 * np.pi * f2 * x)) *
            np.exp(-np.abs(x) / tau1) + b +
            c1 * np.exp(-np.abs(x) / tau2) -
            c2 * np.exp(-np.abs(x) / tau3))

    try:
        popt, pcov = curve_fit(
            fit_func, x, y,
            p0=[m, p_0[0], m, p_0[0] / 2, p_0[1], m, m, p_0[2], m, 0.005],
            bounds=(
                [0, lb[0], 0, lb[0] / 2, lb[1], 0, 0, lb[2], 0, 0],
                [m, ub[0], m, ub[0] / 2, ub[1], m, m, ub[2], m, 0.01]),
            max_nfev=100000)
    except Exception as e:
        # Same failure handling as theta_index(): log, blank the results
        # and return None instead of propagating the optimiser error.
        logging.error(
            "Failed curve_fit in theta_skip_index: {} ".format(e))
        for key in ('Theta Skip Index', 'TS jump factor', 'TS f1 freq Hz',
                    'TS f2 freq Hz', 'TS freq ratio', 'TS tau1 sec',
                    'TS adj Rsq', 'TS Pearse R', 'TS Pearse P'):
            _results[key] = None
        self.update_result(_results)
        return None

    a1, f1, a2, f2, tau1, b, c1, tau2, c2, tau3 = popt

    temp_fit = fit_func(x, *popt)
    y_fit[center:] = temp_fit
    # Mirror the fitted half onto the negative lags.
    y_fit[:center] = np.flipud(temp_fit)

    # Peaks of the fitted curve at lags of 50 msec and above.
    peak_val, peak_loc = extrema(temp_fit[find(x >= 50 / 1000)])[0:2]

    if len(peak_val) >= 2:
        skipIndex = (peak_val[1] - peak_val[0]) / \
            np.max(np.array([peak_val[1], peak_val[0]]))
    else:
        skipIndex = None
    # Test against None explicitly: a skip index of exactly 0 is a valid
    # result and must not blank the fitted parameters.
    has_skip = skipIndex is not None

    # NOTE(review): 6 is passed here although this model has 10 fitted
    # parameters - confirm against residual_stat().
    gof = residual_stat(y, temp_fit, 6)

    graph_data['corrFit'] = y_fit
    _results['Theta Skip Index'] = skipIndex
    _results['TS jump factor'] = a2 / (a1 + a2) if has_skip else None
    _results['TS f1 freq Hz'] = f1 if has_skip else None
    _results['TS f2 freq Hz'] = f2 if has_skip else None
    _results['TS freq ratio'] = f1 / f2 if has_skip else None
    _results['TS tau1 sec'] = tau1 if has_skip else None
    _results['TS adj Rsq'] = gof['adj Rsq']
    _results['TS Pearse R'] = gof['Pearson R']
    _results['TS Pearse P'] = gof['Pearson P']

    self.update_result(_results)
    return graph_data
def phase_dist(self, lfp=None, **kwargs):
    """
    Analysis of spike to LFP phase distribution.

    Delegates to ``lfp.phase_dist()`` with the timestamps of the currently
    set unit.

    Parameters
    ----------
    lfp : NLfp
        LFP object which contains the LFP data
    **kwargs
        Keyword arguments forwarded to ``lfp.phase_dist()``.

    Returns
    -------
    dict or None
        Whatever ``lfp.phase_dist()`` returns. None, after an error is
        logged, when no LFP is given or the delegated call fails.

    See also
    --------
    nc_lfp.NLfp().phase_dist()

    """
    if lfp is None:
        logging.error('LFP data not specified!')
        return None

    try:
        # Hand the delegate's result back; it used to be discarded.
        return lfp.phase_dist(self.get_unit_stamp(), **kwargs)
    except AttributeError:
        logging.error('No phase_dist() method in lfp data specified!')
    except Exception as err:
        logging.error(
            'phase_dist() failed on the lfp data specified: {}'.format(err))
    return None
def plv(self, lfp=None, **kwargs):
    """
    Calculate phase-locking value of spike train to underlying LFP signal.

    Delegates to ``lfp.plv()`` with the timestamps of the currently set
    unit.

    Parameters
    ----------
    lfp : NLfp
        LFP object which contains the LFP data
    **kwargs
        Keyword arguments forwarded to ``lfp.plv()``.

    Returns
    -------
    dict or None
        Whatever ``lfp.plv()`` returns. None, after an error is logged,
        when no LFP is given or the delegated call fails.

    See also
    --------
    nc_lfp.NLfp().plv()

    """
    if lfp is None:
        logging.error('LFP data not specified!')
        return None

    try:
        # Hand the delegate's result back; it used to be discarded.
        return lfp.plv(self.get_unit_stamp(), **kwargs)
    except AttributeError:
        logging.error('No plv() method in lfp data specified!')
    except Exception as err:
        logging.error(
            'plv() failed on the lfp data specified: {}'.format(err))
    return None
def spike_lfp_causality(self, lfp=None, **kwargs):
    """
    Analyse spike to underlying LFP causality.

    Delegates to ``lfp.spike_lfp_causality()`` with the timestamps of the
    currently set unit.

    Parameters
    ----------
    lfp : NLfp
        LFP object which contains the LFP data
    **kwargs
        Keyword arguments forwarded to ``lfp.spike_lfp_causality()``.

    Returns
    -------
    dict or None
        Whatever ``lfp.spike_lfp_causality()`` returns. None, after an
        error is logged, when no LFP is given or the delegated call fails.

    See also
    --------
    nc_lfp.NLfp().spike_lfp_causality()

    """
    if lfp is None:
        logging.error('LFP data not specified!')
        return None

    try:
        # Hand the delegate's result back; it used to be discarded.
        return lfp.spike_lfp_causality(self.get_unit_stamp(), **kwargs)
    except AttributeError:
        # The message now names the method that is actually looked up.
        logging.error(
            'No spike_lfp_causality() method in lfp data specified!')
    except Exception as err:
        logging.error(
            'spike_lfp_causality() failed on the lfp data specified: {}'
            .format(err))
    return None
def _set_total_spikes(self, spike_count=1):
    """
    Record the total number of spikes.

    The value is written to the 'No of spikes' entry of the recording
    information and mirrored on the `spike_count` attribute.

    Parameters
    ----------
    spike_count : int
        Total number of spikes

    Returns
    -------
    None

    """
    record = self._record_info
    record['No of spikes'] = spike_count
    self.spike_count = spike_count


def _set_total_channels(self, tot_channels=1):
    """
    Record the number of channels in the recording information.

    Parameters
    ----------
    tot_channels : int
        Total number of channels

    Returns
    -------
    None

    """
    record = self._record_info
    record['No of channels'] = tot_channels


def _set_channel_ids(self, channel_ids):
    """
    Record the identities of the channels in the recording information.

    Parameters
    ----------
    channel_ids : list or ndarray
        Identities of the individual channels

    Returns
    -------
    None

    """
    record = self._record_info
    record['Channel IDs'] = channel_ids


def _set_timestamp_bytes(self, bytes_per_timestamp):
    """
    Record the `bytes per timestamp` value in the recording information.

    Parameters
    ----------
    bytes_per_timestamp : int
        Total number of bytes to represent timestamp in the binary file

    Returns
    -------
    None

    """
    record = self._record_info
    record['Bytes per timestamp'] = bytes_per_timestamp


def _set_timebase(self, timebase=1):
    """
    Record the timebase of the spike event timestamps.

    Parameters
    ----------
    timebase : int
        Timebase for the spike event timestamps

    Returns
    -------
    None

    """
    record = self._record_info
    record['Timebase'] = timebase


def _set_sampling_rate(self, sampling_rate=1):
    """
    Record the sampling rate of the spike waveforms.

    Parameters
    ----------
    sampling_rate : int
        Sampling rate of the spike waveforms

    Returns
    -------
    None

    """
    record = self._record_info
    record['Sampling rate'] = sampling_rate


def _set_bytes_per_sample(self, bytes_per_sample=1):
    """
    Record the `bytes per sample` value in the recording information.

    Parameters
    ----------
    bytes_per_sample : int
        Total number of bytes to represent each waveform sample in the
        binary file

    Returns
    -------
    None

    """
    record = self._record_info
    record['Bytes per sample'] = bytes_per_sample


def _set_samples_per_spike(self, samples_per_spike=1):
    """
    Record the `samples per spike` value in the recording information.

    Parameters
    ----------
    samples_per_spike : int
        Total number of samples to represent a spike waveform

    Returns
    -------
    None

    """
    record = self._record_info
    record['Samples per spike'] = samples_per_spike


def _set_fullscale_mv(self, adc_fullscale_mv=1):
    """
    Record the ADC fullscale value (in mV) in the recording information.

    Parameters
    ----------
    adc_fullscale_mv : int
        Fullscale voltage of ADC signal in mV

    Returns
    -------
    None

    """
    record = self._record_info
    record['ADC Fullscale mv'] = adc_fullscale_mv
def get_total_spikes(self):
    """
    Report the total number of spikes in the recording.

    Reads the 'No of spikes' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Total number of spikes

    """
    record = self._record_info
    return record['No of spikes']
def get_total_channels(self):
    """
    Report the number of electrode channels in the spike data file.

    Reads the 'No of channels' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Total number of electrode channels

    """
    record = self._record_info
    return record['No of channels']
def get_channel_ids(self):
    """
    Report the identities of the individual channels.

    Reads the 'Channel IDs' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    list
        Identities of individual channels

    """
    record = self._record_info
    return record['Channel IDs']
def get_timestamp_bytes(self):
    """
    Report how many bytes represent one timestamp in the binary file.

    Reads the 'Bytes per timestamp' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Number of bytes to represent timestamps

    """
    record = self._record_info
    return record['Bytes per timestamp']
def get_timebase(self):
    """
    Report the timebase of the spike event timestamps.

    Reads the 'Timebase' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Timebase for spike event timestamps

    """
    record = self._record_info
    return record['Timebase']
def get_sampling_rate(self):
    """
    Report the sampling rate of the spike waveforms.

    Reads the 'Sampling rate' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Sampling rate for spike waveforms

    """
    record = self._record_info
    return record['Sampling rate']
def get_bytes_per_sample(self):
    """
    Report how many bytes represent one spike waveform sample.

    Reads the 'Bytes per sample' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Number of bytes to represent each sample of the spike waveforms

    """
    record = self._record_info
    return record['Bytes per sample']
def get_samples_per_spike(self):
    """
    Report how many samples make up one spike waveform.

    Reads the 'Samples per spike' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Number of samples representing each spike waveform

    """
    record = self._record_info
    return record['Samples per spike']
def get_fullscale_mv(self):
    """
    Report the fullscale value of the ADC in mV.

    Reads the 'ADC Fullscale mv' entry of the recording information.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Fullscale ADC value in mV

    """
    record = self._record_info
    return record['ADC Fullscale mv']
def save_to_hdf5(self, file_name=None, system=None):
    """
    Store NSpike() object to HDF5 file.

    If both `file_name` and `system` are given and the file exists, the
    object is first re-pointed at that file and reloaded; otherwise the
    data currently held by the object is stored.

    Parameters
    ----------
    file_name : str
        Full file directory for the spike data
    system : str
        Recording system or data format

    Returns
    -------
    None

    Also see
    --------
    nc_hdf.Nhdf().save_spike()

    """
    if file_name and system:
        if os.path.exists(file_name):
            self.set_filename(file_name)
            self.set_system(system)
            self.load()
        else:
            # NOTE(review): as before, the object's current content is
            # still saved after this error is logged -- confirm intended.
            logging.error('Specified file cannot be found!')

    hdf = Nhdf()
    try:
        hdf.save_spike(spike=self)
    finally:
        # Release the HDF5 handle even when saving fails part-way.
        hdf.close()
def load_spike_NWB(self, file_name):
    """
    Decode spike data from NWB (HDF5) file format.

    Parameters
    ----------
    file_name : str
        Location of the data written as '<hdf5 file>+<path in file>'.
        The in-file path is looked up as given and, failing that, below
        '/processing/Shank/'.

    Returns
    -------
    None

    """
    # Split on the last '+' so the separator is always the one directly in
    # front of the in-file path; a missing separator is reported instead of
    # raising a bare ValueError from tuple unpacking.
    file_name, separator, path = file_name.rpartition('+')
    if not separator:
        logging.error(
            "NWB spike location must be given as '<file>+<path>'!")
        return

    if not os.path.exists(file_name):
        logging.error(file_name + ' does not exist!')
        return

    hdf = Nhdf()
    hdf.set_filename(file_name)
    try:
        if path in hdf.f:
            g = hdf.f[path]
        elif '/processing/Shank/' + path in hdf.f:
            path = '/processing/Shank/' + path
            g = hdf.f[path]
        else:
            logging.error('Specified shank datapath does not exist!')
            return  # the finally clause below still closes the file

        _record_info = {}
        for key, value in g.attrs.items():
            _record_info[key] = value
        self.set_record_info(_record_info)

        path_clust = 'Clustering'
        path_wave = 'EventWaveForm/WaveForm'

        if path_clust in g:
            g_clust = g[path_clust]
            self._set_timestamp(
                hdf.get_dataset(group=g_clust, name='times'))
            self.set_unit_tags(
                hdf.get_dataset(group=g_clust, name='num'))
            self._set_unit_list()
        else:
            logging.error('There is no /Clustering in the :' + path)

        if path_wave not in g:
            logging.error(
                'There is no /EventWaveForm/WaveForm in the :' + path)
            return

        g_wave = g[path_wave]
        stored_events = hdf.get_dataset(group=g_wave, name='num_events')
        self._set_total_spikes(stored_events)
        chanIDs = hdf.get_dataset(group=g_wave, name='electrode_idx')
        self._set_channel_ids(chanIDs)

        data = hdf.get_dataset(group=g_wave, name='data')
        n_dims = len(data.shape)
        if n_dims == 2:
            # Single-channel data stored without a channel axis: add one
            # so that the per-channel slicing below is valid.
            data = np.asarray(data)[:, :, np.newaxis]
        elif n_dims != 3:
            # Nothing sensible can be decoded from this layout; stop here
            # instead of continuing with undefined dimensions.
            logging.error(
                path_wave + '/data contains for more than 3 dimensions!')
            return
        num_events, num_samples, tot_chans = data.shape

        if num_events != stored_events:
            logging.error(
                'Mismatch between num_events and 1st dimension of ' +
                path_wave + '/data')
        if num_samples != hdf.get_dataset(
                group=g_wave, name='num_samples'):
            logging.error(
                'Mismatch between num_samples and 2nd dimension of ' +
                path_wave + '/data')

        spike_wave = oDict()
        for i in np.arange(tot_chans):
            spike_wave['ch' + str(i + 1)] = data[:, :, i]
        self._set_waveform(spike_wave)
    finally:
        hdf.close()
def load_spike_Axona(self, file_name, return_raw=False):
    """
    Decode spike data from Axona file format.

    The tetrode file header is parsed for the recording information, the
    matching '.set' file is read for channel gains and ADC fullscale, and
    the binary records after the 'data_start' marker are decoded into
    timestamps and per-channel waveforms. Unit tags are then taken from
    '<tag>_<tet>.cut' or, if that is absent, '<tag>.clu.<tet>'.

    Parameters
    ----------
    file_name : str
        Full file directory for the spike data
    return_raw : bool
        If True, return the decoded timestamps and waveforms without
        reading a cluster file or storing anything on the object.

    Returns
    -------
    None or tuple
        (spike_time, spike_wave) if `return_raw` is True, otherwise None.

    """
    file_directory, file_basename = os.path.split(file_name)
    file_tag, tet_no = os.path.splitext(file_basename)
    tet_no = tet_no[1:]
    set_file = os.path.join(file_directory, file_tag + '.set')
    cut_file = os.path.join(
        file_directory, file_tag + '_' + tet_no + '.cut')
    clu_file = os.path.join(
        file_directory, file_tag + '.clu.' + tet_no)

    self._set_data_source(file_name)
    self._set_source_format('Axona')

    # Decimal or integer literal; the dot is escaped so that it cannot
    # match an arbitrary character.
    number = re.compile(r'\d+\.\d+|\d+')

    with open(file_name, 'rb') as f:
        # readline-based iteration stops at EOF; latin-1 maps every byte
        # so decoding cannot fail.
        for raw_line in iter(f.readline, b''):
            line = raw_line.decode('latin-1')
            if line.startswith('trial_date'):
                self._set_date(
                    ' '.join(line.replace(',', ' ').split()[1:]))
            if line.startswith('trial_time'):
                self._set_time(line.split()[1])
            if line.startswith('experimenter'):
                self._set_experimenter(' '.join(line.split()[1:]))
            if line.startswith('comments'):
                self._set_comments(' '.join(line.split()[1:]))
            if line.startswith('duration'):
                self._set_duration(float(''.join(line.split()[1:])))
            if line.startswith('sw_version'):
                self._set_file_version(line.split()[1])
            if line.startswith('num_chans'):
                self._set_total_channels(int(''.join(line.split()[1:])))
            if line.startswith('timebase'):
                self._set_timebase(int(''.join(number.findall(line))))
            if line.startswith('bytes_per_timestamp'):
                self._set_timestamp_bytes(int(''.join(line.split()[1:])))
            if line.startswith('samples_per_spike'):
                self._set_samples_per_spike(
                    int(''.join(line.split()[1:])))
            if line.startswith('sample_rate'):
                self._set_sampling_rate(
                    int(''.join(number.findall(line))))
            if line.startswith('bytes_per_sample'):
                self._set_bytes_per_sample(int(''.join(line.split()[1:])))
            if line.startswith('num_spikes'):
                self._set_total_spikes(int(''.join(line.split()[1:])))
            if line.startswith('data_start'):
                break

        f.seek(0, 0)
        contents = f.read()

    # Binary records start right after the first 'data_start' marker.
    # Searching the raw bytes avoids giving up on a non-UTF-8 header byte.
    marker = b'data_start'
    marker_at = contents.find(marker)
    if marker_at < 0:
        # Nothing can be decoded without the marker; previously this was
        # only printed and the method then failed on undefined names.
        logging.error('data_start marker not found!')
        return
    byte_buffer = np.frombuffer(
        memoryview(contents)[marker_at + len(marker):], dtype=np.uint8)

    num_spikes = self.get_total_spikes()
    bytes_per_timestamp = self.get_timestamp_bytes()
    bytes_per_sample = self.get_bytes_per_sample()
    samples_per_spike = self.get_samples_per_spike()
    tot_channels = self.get_total_channels()

    self._set_channel_ids(
        [(int(tet_no) - 1) * tot_channels + x
         for x in range(tot_channels)])
    # Two's-complement limits for a sample of bytes_per_sample bytes.
    max_ADC_count = 2**(8 * bytes_per_sample - 1) - 1
    max_byte_value = 2**(8 * bytes_per_sample)

    with open(set_file, 'r', encoding='latin-1') as f_set:
        lines = f_set.readlines()

    # 'gain_ch_<channel> <gain>' entries of the .set file.
    gain_pattern = re.compile(r'gain_ch_(\d+)\s+(\d+)')
    gain_lines = {}
    for line in lines:
        found = gain_pattern.search(line)
        if found:
            gain_lines[int(found.group(1))] = int(found.group(2))
    gains = np.array(
        [gain_lines[ch_id] for ch_id in self.get_channel_ids()])

    for line in lines:
        if line.startswith('ADC_fullscale_mv'):
            # The pattern used here previously ended in a literal 'd+'
            # instead of '\d+', so short values could not be parsed.
            self._set_fullscale_mv(int(number.findall(line)[0]))
            break

    AD_bit_uvolts = 2 * self.get_fullscale_mv() * 10**3 / \
        (gains * max_byte_value)

    # One record holds, for every channel, a timestamp followed by the
    # samples of that channel.
    record_size = tot_channels * (
        bytes_per_timestamp + bytes_per_sample * samples_per_spike)
    # Byte weights: timestamps are combined most-significant byte first,
    # samples least-significant byte first.
    time_be = 256**(np.arange(bytes_per_timestamp, 0, -1) - 1)
    sample_le = 256**(np.arange(0, bytes_per_sample, 1))

    spike_time = np.zeros([num_spikes, ], dtype='uint32')
    for i in range(bytes_per_timestamp):
        byte = byte_buffer[i::record_size][:num_spikes]
        spike_time = spike_time + time_be[i] * byte
    spike_time = spike_time / self.get_timebase()
    spike_time = spike_time.reshape((num_spikes, ))

    spike_wave = oDict()
    for i in np.arange(tot_channels):
        chan_offset = (i + 1) * bytes_per_timestamp + \
            i * bytes_per_sample * samples_per_spike
        chan_wave = np.zeros(
            [num_spikes, samples_per_spike], dtype=np.float64)
        for j in np.arange(0, samples_per_spike, 1):
            sample_offset = j * bytes_per_sample + chan_offset
            for k in np.arange(0, bytes_per_sample, 1):
                byte_offset = k + sample_offset
                # Take one byte per record and cap at num_spikes, so the
                # result no longer depends on how many trailing bytes
                # follow the last record.
                sample_value = sample_le[k] * byte_buffer[
                    byte_offset::record_size][:num_spikes]
                sample_value = sample_value.astype(
                    np.float64, casting='unsafe', copy=False)
                np.add(chan_wave[:, j], sample_value,
                       out=chan_wave[:, j])
            # Map unsigned sums above the positive limit to negatives.
            np.putmask(chan_wave[:, j],
                       chan_wave[:, j] > max_ADC_count,
                       chan_wave[:, j] - max_byte_value)
        spike_wave['ch' + str(i + 1)] = chan_wave * AD_bit_uvolts[i]

    unit_ID = None
    if return_raw:
        return spike_time, spike_wave

    if os.path.isfile(cut_file):
        used = cut_file
        with open(cut_file, 'r') as f_cut:
            # Keep readline() here: np.fromfile needs the file position,
            # which plain iteration over the file object would disable.
            while True:
                line = f_cut.readline()
                if line == '':
                    break
                if line.startswith('Exact_cut'):
                    unit_ID = np.fromfile(
                        f_cut, dtype='uint8', sep=' ')
    elif os.path.isfile(clu_file):
        used = clu_file
        data = np.loadtxt(clu_file)
        # First entry is skipped and the remaining tags shifted down by 1.
        unit_ID = data[1:].flatten() - 1
    else:
        logging.error(
            "No cluster file found for spike file {} please make one at {} or {}".format(
                file_name, cut_file, clu_file))
        return

    if unit_ID is None:
        string = "Unable to parse clusters from {}".format(used)
        logging.error(string)
        return

    self._set_timestamp(spike_time)
    self._set_waveform(spike_wave)
    self.set_unit_tags(unit_ID)
def load_spike_Neuralynx(self, file_name):
    """
    Decode spike data from Neuralynx file format.

    The fixed-size text header is parsed for the recording information
    and every following record is decoded into a timestamp, a unit tag
    and one waveform per channel.

    Parameters
    ----------
    file_name : str
        Full file directory for the spike data. The extension selects the
        default channel count ('ntt': 4, 'nst': 2, 'nse': 1); a
        'NumADChannels' header entry overrides it.

    Returns
    -------
    None

    """
    self._set_data_source(file_name)
    self._set_source_format('Neuralynx')

    # Format description for the NLX file:
    channels_by_ext = {'ntt': 4, 'nst': 2, 'nse': 1}
    # None for an unknown extension; the header may still supply the count.
    tot_channels = channels_by_ext.get(file_name[-3:].lower())

    header_offset = 16 * 1024  # fixed for NLX files
    bytes_per_timestamp = 8
    bytes_chan_no = 4
    bytes_cell_no = 4
    bytes_per_feature = 4
    num_features = 8
    bytes_features = bytes_per_feature * num_features
    bytes_per_sample = 2
    samples_per_record = 32

    max_byte_value = np.power(2, bytes_per_sample * 8)
    max_ADC_count = np.power(2, bytes_per_sample * 8 - 1) - 1
    AD_bit_uvolts = None  # default is filled in once channels are known
    record_size = None

    # Decimal or integer literal; the dot is escaped so that it cannot
    # match an arbitrary character.
    number = re.compile(r'\d+\.\d+|\d+')

    with open(file_name, 'rb') as f:
        # Parse only the header block. latin-1 maps every byte, so a
        # non-ASCII character no longer cuts the parsing short.
        header_text = f.read(header_offset).decode('latin-1')
        for line in header_text.split('\n'):
            if 'SamplingFrequency' in line:
                self._set_sampling_rate(
                    float(''.join(number.findall(line))))
            if 'RecordSize' in line:
                record_size = int(''.join(number.findall(line)))
            if 'Time Opened' in line:
                date_found = re.search(r'\d+/\d+/\d+', line)
                time_found = re.search(r'\d+:\d+:\d+', line)
                if date_found:
                    self._set_date(date_found.group())
                if time_found:
                    self._set_time(time_found.group())
            if 'FileVersion' in line:
                self._set_file_version(line.split()[1])
            if 'ADMaxValue' in line:
                max_ADC_count = float(''.join(number.findall(line)))
            if 'ADBitVolts' in line:
                AD_bit_uvolts = np.array(
                    [float(x) * (10**6) for x in number.findall(line)])
            # 'ADChannel' is a substring of 'NumADChannels', so test the
            # longer key first to keep the count out of the channel IDs.
            if 'NumADChannels' in line:
                tot_channels = int(''.join(re.findall(r'\d+', line)))
            elif 'ADChannel' in line:
                self._set_channel_ids(
                    np.array([int(x) for x in re.findall(r'\d+', line)]))

        if tot_channels is None:
            logging.error(
                'Cannot determine the number of channels for ' + file_name)
            return

        # Derived only now so that a header-supplied channel count is
        # honoured. ch1|ch2|ch3|ch4 each with 2 bytes.
        channel_pack_size = bytes_per_sample * tot_channels
        if AD_bit_uvolts is None:
            AD_bit_uvolts = np.ones([tot_channels, ]) * 10**-6

        # gain = 1 assumed to keep in similarity to Axona
        self._set_fullscale_mv((max_byte_value / 2) * AD_bit_uvolts)
        self._set_bytes_per_sample(bytes_per_sample)
        self._set_samples_per_spike(samples_per_record)
        self._set_timestamp_bytes(bytes_per_timestamp)
        self._set_total_channels(tot_channels)

        if not record_size:
            record_size = bytes_per_timestamp + \
                bytes_chan_no + \
                bytes_cell_no + \
                bytes_features + \
                bytes_per_sample * samples_per_record * tot_channels

        # Byte offsets of the fields inside one record.
        time_offset = 0
        unitID_offset = bytes_per_timestamp + bytes_chan_no
        sample_offset = bytes_per_timestamp + \
            bytes_chan_no + \
            bytes_cell_no + \
            bytes_features

        f.seek(0, 2)
        # A file shorter than its header holds no records.
        num_spikes = max((f.tell() - header_offset) // record_size, 0)
        self._set_total_spikes(num_spikes)
        f.seek(header_offset, 0)

        spike_time = np.zeros([num_spikes, ])
        unit_ID = np.zeros([num_spikes, ], dtype=np.int64)
        spike_wave = oDict()
        sample_le = 256**(np.arange(bytes_per_sample))
        for i in np.arange(tot_channels):
            spike_wave['ch' + str(i + 1)] = np.zeros(
                [num_spikes, samples_per_record])

        for i in np.arange(num_spikes):
            sample_bytes = np.fromfile(f, dtype='uint8', count=record_size)
            # Stored as a little-endian integer; scaled by 10**6 here.
            spike_time[i] = int.from_bytes(
                sample_bytes[time_offset:
                             time_offset + bytes_per_timestamp].tobytes(),
                byteorder='little', signed=False) / 10**6
            unit_ID[i] = int.from_bytes(
                sample_bytes[unitID_offset:
                             unitID_offset + bytes_cell_no].tobytes(),
                byteorder='little', signed=False)

            for j in range(tot_channels):
                # Samples of the channels are interleaved sample by sample.
                sample_value = np.zeros(
                    [samples_per_record, bytes_per_sample])
                ind = sample_offset + j * bytes_per_sample + \
                    np.arange(samples_per_record) * channel_pack_size
                for k in np.arange(bytes_per_sample):
                    sample_value[:, k] = sample_bytes[ind + k]
                sample_value = sample_value.dot(sample_le)
                # Map unsigned values above the positive limit to
                # negatives.
                np.putmask(sample_value,
                           sample_value > max_ADC_count,
                           sample_value - max_byte_value)
                spike_wave['ch' + str(j + 1)][i, :] = \
                    sample_value * AD_bit_uvolts[j]

    if num_spikes:
        # Express timestamps relative to the first spike.
        spike_time -= spike_time.min()
        self._set_duration(spike_time.max())
    else:
        # min()/max() are undefined for an empty recording.
        logging.warning('No spike records found in ' + file_name)
        self._set_duration(0)

    self._set_timestamp(spike_time)
    self._set_waveform(spike_wave)
    self.set_unit_tags(unit_ID)
def load_spike_spikeinterface(
        self, sorting, channel_scaling=None, group=None):
    """
    Load spike information from any Sorting Extractor object.

    This extracts timestamps, tags, and waveforms from a sorting object.
    Then stores them into NeuroChaT NSpike attributes.

    Parameters
    ----------
    sorting : spikeinterface.extractors.SortingExtractor
        The sorting extractor object to load from.
    channel_scaling : list | np.ndarray
        This is used to apply gains.
        There should be one entry for each channel if provided.
        Applied as waveform_channel_i * channel_scaling[i].
        Defaults to unit gain for each channel.
    group : int | str
        The group in the sorting extractor to consider.
        This can be used to split the data into tetrodes or probes.

    Returns
    -------
    None

    Note
    ----
    NC assumes that unit 0 is not present in the sorting.
    However, other sorters don't assume this.
    As such, if your sorter contains 0 as a unit,
    ALL unit numbers will be incremented by 1.

    If `group` matches no unit, an error is logged and the object is
    left unchanged.

    Example
    -------
    .. highlight:: python
    .. code-block:: python

        # This would convert from Phy to NeuroChaT native NWB
        import spikeinterface.extractors as se
        to_exclude = ["mua", "noise"]
        sorting = se.PhySortingExtractor(
            "phy_folder",
            exclude_cluster_groups=to_exclude,
            load_waveforms=True,
            verbose=False)
        spike = NSpike()
        hdf_path = "test.hdf5"
        nhdf = Nhdf(filename=hdf_path)
        groups = []
        unit_ids = sorting.get_unit_ids()
        for unit in unit_ids:
            try:
                tetrode = sorting.get_unit_property(unit, "group")
            except BaseException:
                try:
                    tetrode = sorting.get_unit_property(unit, "ch_group")
                except BaseException:
                    tetrode = None
            if tetrode is not None:
                if tetrode not in groups:
                    groups.append(tetrode)
        for g in groups:
            spike.load_spike_spikeinterface(sorting, group=g)
            spike.set_unit_no(spike.get_unit_list()[0])
            print(spike)
            nhdf.save_spike(spike=spike)

    """
    unit_ids = sorting.get_unit_ids()
    should_increment = (0 in unit_ids)

    if group is not None:
        units_to_use = []
        for unit in unit_ids:
            # The group is looked up under "group", then "ch_group".
            try:
                tetrode = sorting.get_unit_property(unit, "group")
            except Exception:
                try:
                    tetrode = sorting.get_unit_property(unit, "ch_group")
                except Exception:
                    # No grouping information at all: use every unit.
                    logging.warning(
                        "Did not find any channel groups in sorting")
                    units_to_use = unit_ids
                    group = None
                    tetrode = None
                    break
            if tetrode is not None and group == tetrode:
                units_to_use.append(unit)
    else:
        units_to_use = unit_ids

    if len(units_to_use) == 0:
        # Nothing to concatenate or to take an example waveform from;
        # report it rather than failing inside numpy or on units_to_use[0].
        logging.error(
            "No units found in sorting for group {}".format(group))
        return

    sample_rate = sorting.params['sample_rate']
    all_unit_trains = [
        sorting.get_unit_spike_train(uid) for uid in units_to_use]
    timestamps = np.concatenate(all_unit_trains) / float(sample_rate)
    total_spikes = len(timestamps)

    # Set up the empty numpy arrays to hold the data
    unit_tags = np.zeros(total_spikes)
    waveform_eg = sorting.get_unit_spike_features(
        units_to_use[0], "waveforms")
    # Waveform features are indexed as (spike, channel, sample).
    samples_per_spike = waveform_eg.shape[2]
    total_channels = waveform_eg.shape[1]
    if channel_scaling is None:
        channel_scaling = np.ones([total_channels, ])
    out_waveforms = oDict()
    for j in range(total_channels):
        out_waveforms["ch{}".format(j + 1)] = np.zeros(
            shape=(total_spikes, samples_per_spike), dtype=np.float64)

    # Establish the tag of each and waveform
    start = 0
    for u_i, u in enumerate(units_to_use):
        end = start + all_unit_trains[u_i].size
        unit_tags[start:end] = u
        wf = sorting.get_unit_spike_features(u, "waveforms")
        for j in range(total_channels):
            try:
                wave = wf[:, j, :]
            except IndexError:
                # Fallback for a waveform array without the spike axis.
                wave = wf[j, :]
            out_waveforms["ch{}".format(j + 1)][start:end] = (
                wave * channel_scaling[j])
        start = end

    # NC assumes unit 0 not in data
    if should_increment:
        unit_tags = unit_tags + 1

    # Order spikes based on time
    ordering = timestamps.argsort()
    timestamps = timestamps[ordering]
    unit_tags = unit_tags[ordering].astype(np.uint64)
    for j in range(total_channels):
        out_waveforms["ch{}".format(j + 1)] = (
            out_waveforms["ch{}".format(j + 1)][ordering])

    self._set_total_channels(total_channels)
    self._set_total_spikes(total_spikes)
    self._set_duration(timestamps.max())
    self._set_samples_per_spike(samples_per_spike)
    self._set_timestamp(timestamps)
    self._set_sampling_rate(sorting._sampling_frequency)
    self.set_unit_tags(unit_tags)
    self._set_waveform(out_waveforms)
    self.set_filename(sorting.params["dat_path"])
    self.set_system("SpikeInterface")
    self._set_source_format(type(sorting).__name__)
    self._set_channel_ids([i for i in range(total_channels)])
    self._spikeinterface_group = group
def __str__(self):
    """
    Describe this object in a short, human-readable form.

    The text lists the available units and the current unit, followed by
    the spike count of the current unit or a note that the current unit
    is not in the unit list.
    """
    if self.get_unit_no() in self._unit_list:
        detail = "with {} spikes".format(self.get_unit_spikes_count())
    else:
        detail = "not in unit list"

    template = "{} object with units {} and current unit {} {}"
    return template.format(
        "NeuroChaT NSpike",
        self.get_unit_list(),
        self.get_unit_no(),
        detail)
# def sfc(self, lfp=None, **kwargs):
#     """
#     Calculates spike-field coherence of spike train with underlying LFP
#     signal.
#
#     Delegates to NLfp().sfc()
#
#     Parameters
#     ----------
#     lfp : NLfp
#         LFP object which contains the LFP data
#     **kwargs
#         Keyword arguments
#
#     Returns
#     -------
#     dict
#         Graphical data of the analysis
#
#     See also
#     --------
#     nc_lfp.NLfp().sfc()
#
#     """
#     if lfp is None:
#         logging.error('LFP data not specified!')
#     else:
#         try:
#             lfp.sfc(self.get_unit_stamp(), **kwargs)
#         except:
#             logging.error('No sfc() method in lfp data specified!')