# -*- coding: utf-8 -*-
"""
This module implements a container for the NData class.
This module is intended to simplify multi experiment analyses.
@author: Sean Martin; martins7 at tcd dot ie
"""
from enum import Enum
import copy
import logging
import os
import pprint
import re
import pandas as pd
import numpy as np
from neurochat.nc_data import NData
from neurochat.nc_utils import get_all_files_in_dir, make_dir_if_not_exists
from neurochat.nc_utils import log_exception, remove_extension
class NDataContainerIterator():
    """
    A custom iterator for the NDataContainer class.

    The iterator walks the container by integer position, from index 0 up
    to ``len(container) - 1``, using the container's ``__len__`` and
    ``__getitem__``.

    Parameters
    ----------
    container : NDataContainer
        The collection to iterate over. Any object supporting ``len()``
        and integer indexing works.

    """

    def __init__(self, container):
        """Start the container at index 0."""
        self._index = 0
        self._container = container

    def __iter__(self):
        """Return self so that the object satisfies the iterator protocol."""
        return self

    def __next__(self):
        """
        Return the item at the current index and advance by one.

        Raises
        ------
        StopIteration
            When every index of the container has been visited.

        """
        if self._index < len(self._container):
            # The index is advanced before the lookup, so a failing lookup
            # does not cause the same item to be retried forever.
            self._index += 1
            return self._container[self._index - 1]
        raise StopIteration
[docs]class NDataContainer():
"""
Class for storing multiple file locations for NData objects.
Optionally, the NData objects themselves can be stored.
Parameters
----------
share_positions : bool
Share the same position file between the data objects
load_on_fly : bool
If True, don't store all the data in memory,
instead load it as needed, or on the fly
share_stms : bool
Share the same stimulation file between the data objects
Attributes
----------
_file_names_dict : Dict
A dictionary containing all the file names
_container : List
The underlying NData object list
_units : List
The list of all units available for these data objects
_unit_count : List
The number of units available for these data objects
_share_positions : bool
Whether to share positional data between the data objects
_share_stms : bool
Whether to share stimulation data between the data objects
_load_on_fly : bool
Whether all the data should be stored in memory or loaded as needed
_smoothed_speed : bool
Stores has the data been smoothed yet.
This is usually not needed as NeuroChaT automatically smooths speed
_last_data_pt : tuple (int, NData)
A cached version of the last data point used from the collection.
This is stored to avoid reloading if the same object is used multiple
times in succession.
"""
def __init__(
self, share_positions=False, load_on_fly=False, share_stms=False):
"""See the class description."""
self._file_names_dict = {}
self._units = []
self._container = []
self._unit_count = []
self._share_positions = share_positions
self._share_stms = share_stms
self._load_on_fly = load_on_fly
self._last_data_pt = (1, None)
self._smoothed_speed = False
[docs] class EFileType(Enum):
"""The different filetypes that can be added to an object."""
Spike = 1
Position = 2
LFP = 3
STM = 4
[docs] def setup(self):
"""Perform data initialisation based on the input filenames."""
if self._load_on_fly:
self._last_data_pt = (1, None)
else:
self._load_all_data()
[docs] def get_num_data(self):
"""Return the number of Ndata objects in the container."""
if self._load_on_fly:
for _, vals in self.get_file_dict().items():
return len(vals)
return len(self._container)
[docs] def get_file_dict(self, key=None):
"""
Return the key value filename dictionary for this collection.
Parameters
----------
key : str, optional
The key to retrieve the filenames for.
See EFileType for options.
Defaults to None, in which case the whole dict is returned.
Returns
-------
list or dict
"""
if key:
return self._file_names_dict.get(key, None)
return self._file_names_dict
[docs] def get_units(self, index=None):
"""
Return the units in this collection, optionally at a given index.
Parameters
----------
index : int, optional
Collection data index to get the units for.
Defaults to None. If None all units are returned.
Returns
-------
list
Either a list containing lists of all units in the collection
or the list of units for the given data index
"""
if index is None:
return self._units
if index >= self.get_num_data() and (not self._load_on_fly):
logging.error("Input index to get_data out of range")
return
return self._units[index]
[docs] def set_units(self, units='all'):
"""
Set the list of units for the collection.
Parameters
----------
units : list or str:
If a list, indicates the units to use for each data object
stored in the collection. "all" is accepted as a single element.
Otherwise, a string "all" is expected, which sets all available
units picked up from clustering files.
"""
self._units = []
if self.get_file_dict() == {}:
raise ValueError("Can't set units for empty collection")
if units == 'all':
if self._load_on_fly:
vals = self.get_file_dict()["Spike"]
for descriptor in vals:
result = NData()
self._load("Spike", descriptor, ndata=result)
self._units.append(result.get_unit_list())
else:
for data in self.get_data():
self._units.append(data.get_unit_list())
elif isinstance(units, list):
for idx, unit in enumerate(units):
if unit == 'all':
if self._load_on_fly:
vals = self.get_file_dict()["Spike"]
descriptor = vals[idx]
result = NData()
self._load("Spike", descriptor, ndata=result)
all_units = result.get_unit_list()
else:
all_units = self.get_data(idx).get_unit_list()
self._units.append(all_units)
elif isinstance(unit, int):
self._units.append([unit])
elif isinstance(unit, list):
self._units.append(unit)
else:
logging.error(
"Unrecognised type {} passed to set units".format(
type(unit)))
else:
logging.error(
"Unrecognised type {} passed to set units".format(type(units)))
self._count_num_units()
[docs] def get_data(self, index=None):
"""
Return the NData objects in this collection, or a specific object.
Do not call this with no index if loading data on the fly.
Parameters
----------
index : int
Optional index to get data at
Defaults to None, in which case all data is returned.
Returns
-------
NData or list of NData objects
"""
if self._load_on_fly:
if index is None:
logging.error("Can't load all data when loading on the fly")
result = NData()
for key, vals in self.get_file_dict().items():
descriptor = vals[index]
self._load(key, descriptor, ndata=result)
return result
if index is None:
return self._container
if index >= self.get_num_data():
logging.error("Input index to get_data out of range")
return
return self._container[index]
[docs] def add_data(self, data):
"""
Add an NData object to this container.
Parameters
----------
data : NData
The NData object to add to this container.
"""
if isinstance(data, NData):
self._container.append(data)
else:
logging.error("Adding incorrect object to data container")
return
[docs] def add_files(self, f_type, descriptors):
"""
Add a list of filenames of the given type to the container.
Parameters
----------
f_type : EFileType:
The type of file being added (Spike, LFP, Position)
descriptors : list
Either a list of filenames, or a tuple of lists in the order
(filenames_list, obj_names_list, data_system_list).
Filenames should be absolute.
Returns
-------
None
"""
if isinstance(descriptors, list):
descriptors = (descriptors, None, None)
filenames, _, _ = descriptors
if not isinstance(f_type, self.EFileType):
logging.error(
"Parameter f_type in add files must be of EFileType\n" +
"given {}".format(f_type))
return
if f_type.name == "Position" and self._share_positions and len(
filenames) == 1:
for _ in range(len(self.get_file_dict()["Spike"]) - 1):
filenames.append(filenames[0])
if f_type.name == "STM" and self._share_stms and len(filenames) == 1:
for _ in range(len(self.get_file_dict()["Spike"]) - 1):
filenames.append(filenames[0])
# Ensure lists are empty or of equal size
for desc in descriptors:
if desc is not None:
if len(desc) != len(filenames):
logging.error(
"add_files called with differing number of filenames"
+ " and other data")
return
for idx in range(len(filenames)):
description = []
for el in descriptors:
if el is not None:
description.append(el[idx])
else:
description.append(None)
self._file_names_dict.setdefault(
f_type.name, []).append(description)
[docs] def add_all_files(self, spats, spikes, lfps):
"""
Quickly add a list of positions, spikes and lfps.
Parameters
----------
spats : list
The list of spatial files
spikes : list
The list of spike files
lfps : list
The list of lfp files
Returns
-------
None
"""
self.add_files(self.EFileType.Position, spats)
self.add_files(self.EFileType.Spike, spikes)
self.add_files(self.EFileType.LFP, lfps)
[docs] def add_files_from_excel(self, file_loc, unit_sep=" "):
"""
Add filepaths from an excel file.
These should be setup to be in the order:
directory | position file | spike file | unit numbers | eeg extension
Parameters
----------
file_loc : str
Name of the excel file that contains the data specifications
unit_sep : str
Optional separator character for unit numbers, default " "
Returns
-------
excel_info :
The raw info parsed from the excel file for further use
"""
pos_files = []
spike_files = []
units = []
lfp_files = []
to_merge = []
if os.path.exists(file_loc):
excel_info = pd.read_excel(file_loc, index_col=None)
if excel_info.shape[1] % 5 != 0:
logging.error(
"Incorrect excel file format, it should be:\n" +
"directory | position file | spike file" +
"| unit numbers | eeg extension")
return
# excel_info = excel_info.iloc[:, 1:] # Can be used to remove index
count = 0
for full_row in excel_info.itertuples():
split = [full_row[i:i + 5]
for i in range(1, len(full_row), 5)
if not pd.isna(full_row[i])]
merge = True if len(split) > 1 else False
merge_list = []
for row in split:
base_dir = row[0]
pos_name = row[1]
tetrode_name = row[2]
if pos_name[-4:] == '.txt':
spat_file = base_dir + os.sep + pos_name
else:
spat_file = base_dir + os.sep + pos_name + '.txt'
spike_file = base_dir + os.sep + tetrode_name
# Load the unit numbers
unit_info = row[3]
if unit_info == "all":
unit_list = "all"
elif isinstance(unit_info, int):
unit_list = unit_info
elif isinstance(unit_info, float):
unit_list = int(unit_info)
else:
unit_list = [
int(x) for x in unit_info.split(" ") if x != ""]
# Load the lfp
lfp_ext = row[4]
if lfp_ext[0] != ".":
lfp_ext = "." + lfp_ext
spike_name = remove_extension(spike_file, keep_dot=False)
lfp_file = spike_name + lfp_ext
pos_files.append(spat_file)
spike_files.append(spike_file)
lfp_files.append(lfp_file)
units.append(unit_list)
merge_list.append(count)
count += 1
if merge:
to_merge.append(merge_list)
# Complete the file setup based on parsing from the excel file
self.add_all_files(pos_files, spike_files, lfp_files)
self.setup()
self.set_units(units)
for idx, merge_list in enumerate(to_merge):
self.merge(merge_list)
for j in range(idx + 1, len(to_merge)):
to_merge[j] = [
k - len(merge_list) + 1 for k in to_merge[j]]
return excel_info
else:
logging.error('Excel file does not exist!')
return None
[docs] def add_axona_files_from_dir(
self, directory, recursive=False, verbose=False, **kwargs):
"""
Go through a directory, extracting Axona files from it automatically.
Parameters
----------
directory : str
The directory to parse through
recursive : bool, optional.
Defaults to False.
Whether to recurse through dirs
verbose: bool, optional.
Defaults to False.
Whether to print the files being added.
**kwargs: keyword arguments
tetrode_list : list
list of tetrodes to consider
default is 1 to 16
data_extension : str
default .set
cluster_extension : str
default .cut
pos_extension : str
default .txt
lfp_extension : str
default .eeg
re_filter : str
default None - no regex performed
regex string for matching filenames
save_result : bool
default True
should save the resulting collection to a file
unit_cutoff : tuple of ints
don't consider any recordings with units outside this range
e.g. if the cutoff is set at 10, any clustering containing
11 or more units will not be considered valid and won't be
added to the container.
Returns
-------
List or str:
If save_result is true,
returns a string indicated where the result was saved
Otherwise returns a list of the cluster files which were used.
"""
default_tetrode_list = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
tetrode_list = kwargs.get("tetrode_list", default_tetrode_list)
data_extension = kwargs.get("data_extension", ".set")
cluster_extension = kwargs.get("cluster_extension", ".cut")
clu_extension = kwargs.get("clu_extension", ".clu.X")
pos_extension = kwargs.get("pos_extension", ".txt")
lfp_extension = kwargs.get("lfp_extension", ".eeg")
stm_extension = kwargs.get("stm_extension", ".stm")
re_filter = kwargs.get("re_filter", None)
save_result = kwargs.get("save_result", True)
unit_cutoff = kwargs.get("unit_cutoff", None)
if verbose:
print("Finding set files:")
files = get_all_files_in_dir(
directory, data_extension,
recursive=recursive, verbose=verbose,
re_filter=re_filter, return_absolute=True)
if verbose:
print("Finding txt files:")
txt_files = get_all_files_in_dir(
directory, pos_extension,
recursive=recursive, verbose=verbose,
re_filter=re_filter, return_absolute=True)
if verbose:
print("Finding stm files:")
stm_files = get_all_files_in_dir(
directory, stm_extension,
recursive=recursive, verbose=verbose,
re_filter=re_filter, return_absolute=True)
num_found = 0
cluster_files = []
for filename in files:
filename = filename[:-len(data_extension)]
for tetrode in tetrode_list:
spike_name = filename + '.' + str(tetrode)
cut_name = filename + '_' + str(tetrode) + cluster_extension
clu_name = filename + clu_extension[:-1] + str(tetrode)
lfp_name = filename + lfp_extension
stm_name = ""
if not os.path.isfile(os.path.join(directory, spike_name)):
continue
# Don't consider files that have not been clustered
if not (
os.path.isfile(os.path.join(directory, cut_name)) or
os.path.isfile(os.path.join(directory, clu_name))):
logging.info(
"Skipping tetrode {} - no clust file {} or {}".format(
tetrode, cut_name, os.path.basename(clu_name)))
continue
for fname in txt_files:
if pos_extension == ".txt":
if fname[:(len(filename) + 1)] == filename + "_":
pos_name = fname
break
elif pos_extension == ".pos":
if fname[:(len(filename))] == filename:
pos_name = fname
break
else:
logging.info(
"Skipping tetrode {} - no position file {}".format(
tetrode, filename))
continue
for fname in stm_files:
if fname[:len(filename)] == filename:
stm_name = fname
break
if os.path.isfile(os.path.join(directory, cut_name)):
cluster_name = cut_name
else:
cluster_name = clu_name
cluster_files.append(cluster_name)
logging.info(
"Adding tetrode {} with spikes {}, clusters {}, positions {}".format(
tetrode, os.path.basename(spike_name),
os.path.basename(cluster_name),
os.path.basename(pos_name))
)
num_found += 1
self.add_files(NDataContainer.EFileType.Spike, [spike_name])
self.add_files(NDataContainer.EFileType.Position, [pos_name])
self.add_files(NDataContainer.EFileType.LFP, [lfp_name])
self.add_files(NDataContainer.EFileType.STM, [stm_name])
if num_found == 0:
logging.warning("Did not find any Axona files to add")
return
self.set_units()
if unit_cutoff:
self.remove_recordings_units(
unit_cutoff[0], unit_cutoff[1], verbose=verbose)
if save_result:
friendly_re = ""
if re_filter:
friendly_re = "_" + \
"-".join(re.findall("[a-zA-Z0-9_]+", re_filter))
name = (
"file_list_" + os.path.basename(directory) +
friendly_re + ".txt")
out_loc = os.path.join(directory, "nc_results", name)
make_dir_if_not_exists(out_loc)
with open(out_loc, 'w') as f:
f.write(str(self))
logging.info(
"Wrote list of files considered to {}".format(out_loc))
return out_loc
return cluster_files
[docs] def subsample(self, key):
"""
Return a subsample of the original data collection.
This subsample is not a reference, but a deep copy
Parameters
----------
key : Slice, list of int, or int
How to sample the original collection
Returns
-------
NDataContainer
The deep copied subsample
"""
try:
result = copy.deepcopy(self)
except BaseException:
print("Could not copy the original collection, will modify in place")
result = NDataContainer()
for k in self._file_names_dict:
if isinstance(key, int):
result._file_names_dict[k] = [self._file_names_dict[k][key]]
if isinstance(key, list):
result._file_names_dict[k] = [
self._file_names_dict[k][i] for i in key
]
else:
result._file_names_dict[k] = self._file_names_dict[k][key]
if len(self._units) > 0:
if isinstance(key, int):
result._units = [self._units[key]]
if isinstance(key, list):
result._units = [self._units[i] for i in key]
else:
result._units = self._units[key]
if len(self._container) > 0:
if isinstance(key, int):
result._container = [self._container[key]]
if isinstance(key, list):
result._container = [self._container[i] for i in key]
else:
result._container = self._container[key]
result._count_num_units()
return result
[docs] def sort_units_spatially(self, should_sort_list=None, mode="vertical"):
"""
Sort the units in the collection based on the place field centroid.
Parameters
----------
should_sort_list: list
Optional list of boolean values indicating what objects
mode: str
"horizontal" or "vertical", indicating what axis to sort on.
Returns
-------
None
"""
if mode == "vertical":
h = 1
elif mode == "horizontal":
h = 0
else:
logging.error(
"NDataContainer: "
+ "Only modes horizontal and vertical are supported")
if should_sort_list is None:
should_sort_list = [True for _ in range(self.get_num_data())]
for idx, bool_val in enumerate(should_sort_list):
if bool_val:
centroids = []
data = self.get_data(idx)
for unit in self.get_units()[idx]:
data.set_unit_no(unit)
place_info = data.place()
centroid = place_info["centroid"]
centroids.append(centroid)
self._units[idx] = [unit for _, unit in sorted(
zip(centroids, self.get_units()[idx]),
key=lambda pair: pair[0][h])]
[docs] def get_index_info(self, idx, absolute=False):
"""Return the Spike, LFP, Position and Unit info at idx."""
str_info = {}
dirnames = []
if absolute:
idx, u_idx = self._index_to_data_pos(idx)
for key in ["Spike", "LFP", "Position", "STM"]:
name = self.get_file_dict(key)[idx][0]
str_info[key] = (os.path.basename(name))
if name != "":
dirnames.append(os.path.dirname(name))
if absolute:
str_info["Units"] = (self.get_units(idx)[u_idx])
else:
str_info["Units"] = (self.get_units(idx))
if len(set(dirnames)) == 1:
str_info["Root"] = dirnames[0]
else:
print("Not all files are in the same directory {} {}".format(
":Spike, LFP, Position, STM: ", dirnames))
str_info["Root"] = dirnames
return str_info
[docs] def string_repr(self, pretty=True):
"""
Return a string representation of this class.
Parameters
----------
pretty : str, Default True
Should return a pretty version or all the info.
Returns
-------
str
"""
if pretty:
return self._pretty_string()
else:
return self._full_string()
[docs] def remove_recordings_units(
self, unit_lb=0, unit_ub=10000, verbose=False):
"""
Remove data objects in this recording with units outside bounds.
Parameters
----------
unit_lb : int, optional.
Lower bound on the number of allowed units.
Defaults to 0.
unit_ub : int, optional.
Upper bound on the number of allowed units.
Defaults to 10000.
verbose : bool, optional.
Whether to print extra information on what is being removed.
Defaults to False.
Returns
-------
None
"""
start_size = self.get_num_data()
start_total = len(self)
for i in range(self.get_num_data() - 1, -1, -1):
unit_count = len(self.get_units(i))
if (unit_count > unit_ub) or (unit_count < unit_lb):
for key in ("Spike", "LFP", "Position"):
name = self._file_names_dict[key].pop(i)
if (key == "Spike") and verbose:
print("Removed {} with {} units".format(
os.path.basename(name[0]), unit_count))
if self._unit_count.pop(i) != unit_count:
raise ValueError(
"Error in remove recording {}".format(name))
self._last_data_pt = (1, None)
self._units.pop(i)
if not self._load_on_fly:
self._container.pop(i)
end_size = self.get_num_data()
self._count_num_units()
end_total = len(self)
if end_total != start_total:
print(("{} tetrodes with {} units reduced to "
+ "{} tetrodes with {} units").format(
start_size, start_total, end_size, end_total))
else:
print("{} tetrodes with {} units".format(
start_size, start_total))
[docs] def get_data_at(self, data_index, unit_index):
"""
Return an NData object from the given indices.
Parameters
----------
data_index : int
The index in the container to return data at.
unit_index : int
The unit number to set on the returned data.
Returns
-------
NData
The ndata object at data_index with unit number unit_index
"""
if self._load_on_fly:
try:
if (data_index == self._last_data_pt[0] and
(self._last_data_pt[1] is not None)):
result = self._last_data_pt[1]
else:
result = NData()
for key, vals in self.get_file_dict().items():
if key == "STM":
continue
descriptor = vals[data_index]
self._load(key, descriptor,
idx=data_index, ndata=result)
self._last_data_pt = (data_index, result)
except Exception as e:
log_exception(e, "During loading data")
else:
result = self.get_data(data_index)
if len(self.get_units()) > 0:
result.set_unit_no(self.get_units(data_index)[unit_index])
return result
[docs] def get_name_at_idx(
self, idx, ext, opt_end="", base_dir=None, out_dirname="nc_plots"):
"""
Get the filename to save an index in the collection to.
Parameters
----------
idx : int
The index of the collection to get the filename for.
ext : str
The extension to append to the filename.
opt_end : str, optional.
Used like this default_name + opt_end + ext
base_dir : str, optional.
One can specify a directory that all files originated from.
It is used like so:
Say data1 is in test/foo/data.txt
and data2 is in test/bar/data.txt
Then passing base_dir as test
Would set the names to
out_dirname/foo--data.txt
out_dirname/bar--data.txt
outdirname : str, optional
The directory to save the plots to.
This is relative to the directory of the filename
if base dir is None.
Returns
-------
str
"""
data_idx, unit_idx = self._index_to_data_pos(idx)
filename = self.get_file_dict()["Spike"][data_idx][0]
unit_number = self.get_units(data_idx)[unit_idx]
spike_name = os.path.basename(filename)
final_bname, final_ext = os.path.splitext(spike_name)
final_ext = final_ext[1:]
f_dir = os.path.dirname(filename)
data_basename = (
final_bname + "_" + final_ext + "_" +
str(unit_number) + opt_end + "." + ext)
if base_dir is not None:
main_dir = base_dir
out_base = f_dir[len(base_dir + os.sep):]
if len(out_base) != 0:
out_base = ("--").join(out_base.split(os.sep))
data_basename = out_base + "--" + data_basename
else:
main_dir = f_dir
out_name = os.path.join(
main_dir, out_dirname, data_basename)
make_dir_if_not_exists(out_name)
return out_name
[docs] def list_all_units(self):
"""Print all the units in the container."""
if self._load_on_fly:
for key, vals in self.get_file_dict().items():
if key == "Spike":
for descriptor in vals:
result = NData()
self._load(key, descriptor, ndata=result)
print("units are {}".format(result.get_unit_list()))
else:
for data in self._container:
print("units are {}".format(data.get_unit_list()))
# Methods from here on should be for private class use
def _pretty_string(self):
"""Alternative string representation should be prettier."""
all_str_info = []
for i in range(self.get_num_data()):
str_info = self.get_index_info(i)
b_str = "{}: \n\tSpk {}\n\tUnt {}: {}\n\tLfp {}\n\tPos {}\n\tSTM {}\n\tDir {}".format(
i, str_info["Spike"], len(str_info["Units"]),
str_info["Units"], str_info["LFP"],
str_info["Position"], str_info["STM"],
str_info["Root"])
all_str_info.append(b_str)
return "\n".join(all_str_info)
def _full_string(self):
"""Full string representation of the container."""
string = (
"NData Container Object with {} objects:\n" +
"Set to Load on Fly? {}\n" +
"Files are:\n{}\n" +
"Units are:\n{}").format(
self.get_num_data(),
self._load_on_fly,
pprint.pformat(self.get_file_dict()),
pprint.pformat(self.get_units()))
return string
def _load_all_data(self):
"""Intended private function which loads all the data."""
if self._load_on_fly:
logging.error(
"Don't load all the data in container if loading on the fly")
for key, vals in self.get_file_dict().items():
for idx, _ in enumerate(vals):
if idx >= self.get_num_data():
self.add_data(NData())
for idx, descriptor in enumerate(vals):
self._load(key, descriptor, idx=idx)
def _load(self, key, descriptor, idx=None, ndata=None):
"""
Intended private function which loads data for a specific filetype.
The NData object loaded into is either passed in, or found by idx.
Parameters
----------
key : str
"Spike", "Position", or "LFP", which filetype to load
descriptor : tuple
(filename, objectname, system) tuple
idx : int
Optional parameter to get corresponding data from _collection
ndata : NData
Optional parameter to allow passing in an ndata object to load to
Returns
-------
None
"""
if ndata is None:
ndata = self.get_data(idx)
key_fn_pairs = {
"Spike": [
getattr(ndata, "set_spike_file"),
getattr(ndata, "set_spike_name"),
getattr(ndata, "load_spike")],
"Position": [
getattr(ndata, "set_spatial_file"),
getattr(ndata, "set_spatial_name"),
getattr(ndata, "load_spatial")],
"LFP": [
getattr(ndata, "set_lfp_file"),
getattr(ndata, "set_lfp_name"),
getattr(ndata, "load_lfp")],
}
filename, objectname, system = descriptor
if objectname is not None:
key_fn_pairs[key][1](objectname)
if system is not None:
ndata.set_system(system)
if key == "Position" and self._share_positions and idx != 0:
if self._load_on_fly:
ndata.spatial = self._last_data_pt[1].spatial
else:
ndata.spatial = self.get_data(0).spatial
return
if filename is not None:
key_fn_pairs[key][0](filename)
key_fn_pairs[key][2]()
[docs] def __str__(self):
"""Return a string representation of the collection."""
return self.string_repr(pretty=True)
[docs] def __getitem__(self, index):
"""Return the data object with corresponding unit at index."""
data_index, unit_index = self._index_to_data_pos(index)
return self.get_data_at(data_index, unit_index)
[docs] def __len__(self):
"""Return the number of units in the collection."""
counts = self._unit_count
if len(counts) == 0:
self._count_num_units()
counts = self._unit_count
return sum(counts)
def _count_num_units(self):
"""Intended private function to count units in the collection."""
counts = []
for unit_list in self.get_units():
counts.append(len(unit_list))
self._unit_count = counts
def _index_to_data_pos(self, index):
"""
Intended private function to turn an index into a tuple indices.
Parameters
----------
index : int
The unit index to convert to a data index and unit index for that
Returns
-------
tuple
(data collection index, unit index for this data object)
"""
counts = self._unit_count
if len(counts) == 0:
print("Recounting units")
self._count_num_units()
counts = self._unit_count
if index >= len(self):
raise IndexError("index {} is out of range {} for {}".format(
index, len(self) - 1, self))
else:
running_sum, running_idx = 0, 0
for count in counts:
if index < (running_sum + count):
return running_idx, (index - running_sum)
else:
running_sum += count
running_idx += 1
[docs] def __iter__(self):
"""Iterate over all units in the container."""
return NDataContainerIterator(self)
# def merge(self, indices, force_equal_units=True):
# """
# Merge the data from multiple indices together into the first index.
# ONLY FUNCTIONS FOR POSITIONS AND SPIKES CURRENTLY - DOES NOT MERGE LFP.
# ALSO DOES NOT MERGE THE SPIKE WAVEFORMS, MERELY THE SPIKE TIMES.
# Only call this after loading the data, and not while loading on the fly
# Parameters
# ----------
# indices: list
# The list of indices in the data to merge together
# force_equal_units:
# The merged indexes must have the same unit numbers available
# Returns
# -------
# The merged data point
# """
# if self._load_on_fly:
# logging.error("Don't call merge when loading on the fly")
# return
# target_index = indices[0]
# data_to_merge = []
# target_data = self.get_data(target_index)
# for idx in indices[1:]:
# data = self.get_data(idx)
# data_to_merge.append(data)
# units1 = self.get_units(target_index)
# for idx, data in zip(indices[1:], data_to_merge):
# units2 = self.get_units(idx)
# if force_equal_units and (not units1 == units2):
# logging.error(
# "Can't merge files with unequal units\n" +
# "Units are {} , {}".format(units1, units2))
# return
# # Merge the spikes based on times (waveforms not done yet)
# new_spike_times = (
# data.spike.get_timestamp() +
# target_data.spike.get_duration())
# new_duration = (
# target_data.spike.get_duration() +
# data.spike.get_duration())
# new_tags = data.spike.get_unit_tags()
# target_data.spike._timestamp = np.append(
# target_data.spike._timestamp, new_spike_times)
# target_data.spike._unit_Tags = np.append(
# target_data.spike._unit_Tags, new_tags)
# target_data.spike._set_duration(new_duration)
# # Merge the spatial information based on times
# new_spat_times = (
# data.spatial._time +
# target_data.spike.get_duration())
# new_pos_x = data.spatial._pos_x
# new_pos_y = data.spatial._pos_y
# new_direction = data.spatial._direction
# new_speed = data.spatial._speed
# target_data.spatial._time = np.append(
# target_data.spatial._time, new_spat_times)
# # NB this may not work properly due to different borders
# target_data.spatial._pos_x = np.append(
# target_data.spatial._pos_x, new_pos_x)
# target_data.spatial._pos_y = np.append(
# target_data.spatial._pos_y, new_pos_y)
# target_data.spatial._direction = np.append(
# target_data.spatial._direction, new_direction)
# target_data.spatial._speed = np.append(
# target_data.spatial._speed, new_speed)
# self._container[target_index] = target_data
# for idx in indices[1:]:
# self._container.pop(idx)
# self._units.pop(idx)
# indices[1:] = [a - 1 for a in indices[1:]]
# self._count_num_units()
# self._container[target_index].set_unit_no(
# self.get_units(target_index)[0])
# return self.get_data(target_index)