# -*- coding: utf-8 -*-
"""
This module implements utility functions and classes for NeuroChaT software.
@author: Md Nurul Islam; islammn at tcd dot ie
"""
import logging
import time
import datetime
import traceback
from collections import OrderedDict as oDict
import os
from os import listdir
from os.path import isfile, isdir, join
import re
import math
import pandas as pd
import numpy as np
import numpy.linalg as nalg
import scipy
import scipy.stats as stats
import scipy.signal as sg
from scipy.fftpack import fft
[docs]class NLog(logging.Handler):
"""
Class for handling log information (messages, errors and warnings).
It formats the incoming message in HTML and sends it to the log
interface of NeuroChaT.
"""
def __init__(self):
super().__init__()
self.setup()
[docs] def setup(self):
"""
Remove all the logging handlers and set up a logger in HTML format.
Parameters
----------
None
Returns
-------
None
"""
log = logging.getLogger()
for hdlr in log.handlers[:]: # remove all old handlers
log.removeHandler(hdlr)
fmt = logging.Formatter(
'%(asctime)s (%(filename)s) %(levelname)s-- %(message)s', '%H:%M:%S')
self.setFormatter(fmt)
log.addHandler(self)
# You can control the logging level
log.setLevel(logging.DEBUG)
logging.addLevelName(20, '')
[docs] def emit(self, record):
"""
Format the incoming record and display it.
Parameters
----------
record
Log record to display or store
Returns
-------
None
"""
msg = self.format(record)
level = record.levelname
msg = level + ':' + msg
print(msg)
time.sleep(0.25)
[docs]class Singleton(object):
"""Create a Singleton object created from a subclass of this class."""
[docs] def __new__(cls, *arg, **kwarg):
"""Create a Singleton object created from a subclass of this class."""
if not hasattr(cls, '_instance'):
cls._instance = super().__new__(cls, *arg, **kwarg)
return cls._instance
[docs]def bhatt(X1, X2):
"""
Calculate Bhattacharyya coefficient and distance between distributions.
Parameters
----------
X1, X2 : ndarray
Distributions under consideration
Returns
-------
bc, d : float
Bhattacharyya coefficient and Bhattacharyya distance
"""
r1, c1 = X1.shape
r2, c2 = X2.shape
if c1 == c2:
mu1 = X1.mean(axis=0)
mu2 = X2.mean(axis=0)
C1 = np.cov(X1.T)
C2 = np.cov(X2.T)
C = (C1 + C2) / 2
chol = nalg.cholesky(C).T
dmu = (mu1 - mu2) @ nalg.inv(chol)
try:
d = 0.125 * dmu @ (dmu.T) + 0.5 * np.log(
nalg.det(C) / np.sqrt(nalg.det(C1) * nalg.det(C2)))
except BaseException:
d = 0.125 * dmu @ (dmu.T) + 0.5 * np.log(
np.abs(nalg.det(C @ nalg.inv(scipy.linalg.sqrtm(C1 @ C2)))))
bc = np.exp(-1 * d)
return bc, d
else:
logging.error(
'Cannot measure Bhattacharyya distance, column sizes do not match!')
[docs]def butter_filter(x, Fs, *args):
"""
Filter using bidirectional zero-phase shift Butterworth filter.
Parameters
----------
x : ndarray
Data or signal to filter
Fs : Sampling frequency
*kwargs
Arguments with filter paramters
Returns
-------
ndarray
Filtered signal
"""
gstop = 20 # minimum dB attenuation at stopabnd
gpass = 3 # maximum dB loss during ripple
for arg in args:
if isinstance(arg, str):
filttype = arg
if filttype == 'lowpass' or filttype == 'highpass':
wp = args[1] / (Fs / 2)
if wp > 1:
wp = 1
if filttype == 'lowpass':
logging.warning(
'Butterworth filter critical frequency Wp is capped at 1')
else:
logging.error('Cannot highpass filter over Nyquist frequency!')
elif filttype == 'bandpass':
if len(args) < 4:
logging.error('Insufficient Butterworth filter arguments')
else:
wp = np.array(args[1:3]) / (Fs / 2)
if wp[0] >= wp[1]:
logging.error(
'Butterworth filter lower cutoff frequency must be smaller than upper cutoff frequency!')
if wp[0] == 0 and wp[1] >= 1:
logging.error(
'Invalid filter specifications, check cut off frequencies and sampling frequency!')
elif wp[0] == 0:
wp = wp[1]
filttype = 'lowpass'
logging.warning('Butterworth filter type selected: lowpass')
elif wp[1] >= 1:
wp = wp[0]
filttype = 'highpass'
logging.warning('Butterworth filter type selected: highpass')
if filttype == 'lowpass':
ws = min([wp + 0.1, 1])
elif filttype == 'highpass':
ws = max([wp - 0.1, 0.01 / (Fs / 2)])
elif filttype == 'bandpass':
ws = np.zeros_like(wp)
ws[0] = max([wp[0] - 0.1, 0.01 / (Fs / 2)])
ws[1] = min([wp[1] + 0.1, 1])
min_order, min_wp = sg.buttord(wp, ws, gpass, gstop)
b, a = sg.butter(min_order, min_wp, btype=filttype, output='ba')
return sg.filtfilt(b, a, x)
[docs]def chop_edges(x, xlen, ylen):
"""
Chop the edges of a firing rate map.
They are considered to be the edges if they are not
visited at all or with zero firing rate.
Parameters
----------
x : ndarray
Matrix of firing rate
xlen : int
Maximum length of the x-axis
ylen : int
Maximum length of the y-axis
Returns
-------
low_ind : list of int
Index of low end of valid edges
hig_end :
Index of high end of valid edges
y : ndarray
Chopped firing map
"""
y = np.copy(x)
low_ind = [0, 0]
high_ind = [x.shape[0], x.shape[1]]
MOVEON = True
while y.shape[1] > xlen and MOVEON:
no_filled_bins1 = np.sum(y[:, 0] > 0)
no_filled_bins2 = np.sum(y[:, -1] > 0)
if no_filled_bins1 == 0:
low_ind[1] += 1
MOVEON = True
else:
MOVEON = False
if no_filled_bins2 == 0:
high_ind[1] -= 1
MOVEON = True
else:
MOVEON = False
y = x[low_ind[0]: high_ind[0], low_ind[1]:high_ind[1]]
MOVEON = True
while y.shape[0] > ylen and MOVEON:
no_filled_bins1 = np.sum(y[0, :] > 0)
no_filled_bins2 = np.sum(y[-1, :] > 0)
if no_filled_bins1 == 0:
low_ind[0] += 1
MOVEON = True
else:
MOVEON = False
if no_filled_bins2 == 0:
high_ind[0] -= 1
MOVEON = True
else:
MOVEON = False
y = x[low_ind[0]: high_ind[0], low_ind[1]:high_ind[1]]
return low_ind, high_ind, y
[docs]def corr_coeff(x1, x2):
"""
Correlation coefficient between two numeric series or two signals.
Parameters
----------
x1, x2 : ndarray
Input numeric array or signals
Returns
-------
float
Correlation coefficient of input arrays
"""
try:
return np.sum(np.multiply(x1 - x1.mean(), x2 - x2.mean())) / \
np.sqrt(np.sum((x1 - x1.mean())**2) * np.sum((x2 - x2.mean())**2))
except BaseException:
return 0
[docs]def extrema(x, mincap=None, maxcap=None):
"""
Find the extrema in a numeric array or a signal.
Parameters
----------
mincap
Maximum value for the minima
maxcap
Minimum value for the maxima
Returns
-------
xmax : ndarray
Maxima values
imax : ndarray
Maxima indices
xmin : ndarray
Minima values
imin : ndarray
Minima indices
"""
x = np.array(x)
# Flat peaks at the end of the series are not considered yet
dx = np.diff(x)
if not np.any(dx):
return [], [], [], []
a = find(dx != 0) # indices where x changes
lm = find(np.diff(a) != 1) + 1 # indices where a is not sequential
d = a[lm] - a[lm - 1]
a[lm] = a[lm] - np.floor(d // 2)
xa = x[a] # series without flat peaks
d = np.sign(xa[1:-1] - xa[:-2]) - np.sign(xa[2:] - xa[1:-1])
imax = a[find(d > 0) + 1]
xmax = x[imax]
imin = a[find(d < 0) + 1]
xmin = x[imin]
if mincap:
imin = imin[xmin <= mincap]
xmin = xmin[xmin <= mincap]
if maxcap:
imax = imax[xmax <= maxcap]
xmax = xmax[xmax <= maxcap]
return xmax, imax, xmin, imin
[docs]def fft_psd(x, Fs, nfft=None, side='one', ptype='psd'):
"""
Calculate the Fast Fourier Transform (FFT) of a signal.
Parameters
----------
x : ndarray
Input signal
Fs
Sampling frequency
nfft : int
Number of FFT points
side : str
'one'-sided or 'two'-sided FFT
ptype : str
Calculates power-spectral density if set to 'psd'
Returns
-------
x_fft : ndarray
FFT of input
f : ndarray
FFt frequency
"""
if nfft is None:
nfft = 2**(np.floor(np.log2(len(x))) + 1)
if nfft < Fs:
nfft = 2**(np.floor(np.log2(Fs)) + 1)
nfft = int(nfft)
dummy = np.zeros(nfft)
if nfft > len(x):
dummy[:len(x)] = x
x = dummy
winfun = np.hanning(nfft)
xf = np.arange(0, Fs, Fs / nfft)
f = xf[0: int(nfft / 2) + 1]
if side == 'one':
x_fft = fft(np.multiply(x, winfun), nfft)
if ptype == 'psd':
x_fft = np.absolute(x_fft[0: int(nfft / 2) + 1])**2 / nfft**2
x_fft[1:-1] = 2 * x_fft[1:-1]
return x_fft, f
[docs]def find(X, n=None, direction='all'):
"""
Find the non-zero entries of a signal or array.
Parameters
----------
X : ndarray or list
Array or list of numbers whose non-zero entries need to find out
n : int
Number of such entries
direction : str
If 'all', all entries of length n are returned.
If 'first', first n entries are returned.
If 'last', last n entries are returned.
Returns
-------
ndarray
Indices of non-zero entries.
"""
if isinstance(X, list):
X = np.array(X)
X = X.flatten()
if n is None:
n = len(X)
ind = np.where(X)[0]
if ind.size:
if direction == 'all' or direction == 'first':
ind = ind[:n]
elif direction == 'last':
ind = ind[np.flipud(np.arange(-1, -(n + 1), - 1))]
return np.array(ind)
[docs]def find2d(X, n=None):
"""
Find the non-zero entries of a matrix.
Parameters
----------
X : ndarray
Matrix whose non-zero entries need to find out
n : int
Number of such entries
Returns
-------
ndarray
x-indices of non-zero entries.
ndarray
y-indices of non-zero entries.
"""
if len(X.shape) == 2:
J = []
I = []
for r in np.arange(X.shape[0]):
I.extend(find(X[r, ]))
J.extend(r * np.ones((len(find(X[r, ])), ), dtype=int))
if len(I):
if n is not None and n < len(I):
I = I[:n]
J = J[:n]
return np.array(J), np.array(I)
else:
logging.error('ndrray is not 2D. Check shape attributes of the input!')
[docs]def find_chunk(x):
"""
Find size and indices of chunks of non-zero segments in an array.
Parameters
----------
x : ndarray
Inout array whose non-zero chunks are to be explored
Returns
-------
segsize : ndarray
Lengths of non-zero chunks
segind : ndarray
Indices of non-zero chunks
"""
# x is a binary array input i.e. x= data> 0.5 will find all the chunks in
# data where data is greater than 0.5
i = 0
segsize = []
segind = np.zeros(x.shape)
while i < len(x):
if x[i]:
c = 0
j = i
while i < len(x):
if x[i]:
c += 1
i += 1
else:
break
segsize.append(c)
segind[j:i] = c # indexing by size of the chunk
i += 1
return segsize, segind
[docs]def hellinger(X1, X2):
"""
Calculate Hellinger distance between two distributions.
Parameters
----------
X1, X2 : ndarray
Distributions under consideration
Returns
-------
d : float
Calculated Hellinger distance
"""
if X1.shape[1] != X2.shape[1]:
logging.error(
'Hellinger distance cannot be computed, column sizes do not match!')
else:
return np.sqrt(1 - bhatt(X1, X2)[0])
[docs]def histogram(x, bins):
"""
Calculate the histogram count of input array.
This function is not a replacement of np.histogram;
it is created for convenience of binned-based rate calculations
and mimicking matlab histc that includes digitized indices
Parameters
----------
x : ndarray
Array whose histogram needs to be calculated
bins
Number of histogram bins
Returns
-------
ndarray
Histogram count
ndarray
Histogram bins(lowers edges)
"""
if isinstance(bins, int):
bins = np.arange(np.min(x), np.max(x), (np.max(x) - np.min(x)) / bins)
bins = np.append(bins, bins[-1] + np.mean(np.diff(bins)))
return np.histogram(x, bins)[0], np.digitize(x, bins) - 1, bins[:- 1]
[docs]def histogram2d(y, x, ybins, xbins):
"""
Calculate the joint histogram count of two arrays.
This function is not a replacement of np.histogram2d;
it is created for convenience of binned-based rate calculations
and mimicking matlab histc that includes digitized indices
Parameters
----------
y, x : ndarray
Arrays whose histogram needs to be calculated
ybins
Number of histogram bins in y-axis
xbins
Number of histogram bins in x-axis
Returns
-------
ndarray
Histogram count
ndarray
Histogram bins in x-axis (lowers edges)
ndarray
Histogram bins in y-axis (lowers edges)
"""
if isinstance(xbins, int):
xbins = np.arange(np.min(x), np.max(
x), (np.max(x) - np.min(x)) / xbins)
xbins = np.append(xbins, xbins[-1] + np.mean(np.diff(xbins)))
if isinstance(ybins, int):
ybins = np.arange(np.min(y), np.max(
y), (np.max(y) - np.min(y)) / ybins)
ybins = np.append(ybins, ybins[-1] + np.mean(np.diff(ybins)))
return np.histogram2d(y, x, [ybins, xbins])[0], ybins[:-1], xbins[:-1]
[docs]def linfit(X, Y, getPartial=False):
"""
Calculate the linear regression coefficients in least-square sense.
Parameters
----------
X : ndarray
Matrix with input variables or factors (num_dim X num_obs)
Y : ndarray
Array of oservation data
getPartial : bool
Get the partial correlation coefficients if 'True'
Returns
-------
_results : dict
Dictionary with results of least-square optimization of linear regression
"""
_results = oDict()
if len(X.shape) == 2:
Nd, Nobs = X.shape
else:
Nobs = X.shape[0]
Nd = 1
if Nobs == len(Y):
A = np.vstack([X, np.ones(X.shape[0])]).T
B = np.linalg.lstsq(A, Y, rcond=-1)[0]
Y_fit = np.matmul(A, B)
_results['coeff'] = B[:-1]
_results['intercept'] = B[-1]
_results['yfit'] = Y_fit
_results.update(residual_stat(Y, Y_fit, 1))
else:
logging.error('linfit: Number of rows in X and Y does not match!')
if Nd > 1 and getPartial:
semiCorr = np.zeros(Nd) # Semi partial correlation
for d in np.arange(Nd):
part_results = linfit(np.delete(X, 1, axis=0), Y, getPartial=False)
semiCorr[d] = _results['Rsq'] - part_results['Rsq']
_results['semiCorr'] = semiCorr
return _results
[docs]def nxl_write(
file_name, data_frame, sheet_name='Sheet1', startRow=0, startColumn=0):
"""
Write Pandas DataFrame to excel file, wraps Pandas.ExcelWriter().
Parameters
----------
filename : str
Name of the output file
data_frame : pandas.DataFrame
DataFrame to export
sheet_name : str
Sheet name of the Excel file where the data is written
startRow : int
Which row in the file the data writing should start
startColumn : int
Which column in the file the data writing should start
Returns
-------
None
"""
# Create a Pandas Excel writer using XlsxWriter as the engine.
writer = pd.ExcelWriter(file_name, engine='xlsxwriter')
# Convert the dataframe to an XlsxWriter Excel object.
data_frame.to_excel(writer, sheet_name)
# Close the Pandas Excel writer and output the Excel file.
writer.save()
[docs]def residual_stat(y, y_fit, p):
"""
Calculate the goodness of fit and other residual statistics.
These are calculated between observed and fitted values from a model.
Parameters
----------
y : ndarray
Observed data
y_fit : ndarray
Fitted data to a linear model
p : int
Model order
Returns
-------
_results : dict
Dictionary of residual statistics
"""
# p= total explanatory variables excluding constants
_results = oDict()
res = y - y_fit
ss_res = np.sum(res**2)
ss_tot = np.sum((y - np.mean(y))**2)
r_sq = 1 - ss_res / ss_tot
adj_r_sq = 1 - (ss_res / ss_tot) * ((len(y) - 1) / (len(y) - p - 1))
_results['Pearson R'], _results['Pearson P'] = stats.pearsonr(y, y_fit)
_results['Rsq'] = r_sq
_results['adj Rsq'] = adj_r_sq
return _results
[docs]def rot_2d(x, theta):
"""
Rotate a firing map by a specified angle.
Parameters
----------
x : ndarray
Matrix of firing rate map
theta
Angle of rotation in theta
Returns
-------
ndarray
Rotated matrix
"""
return scipy.ndimage.interpolation.rotate(
x, theta, reshape=False, mode='constant', cval=np.min(x))
[docs]def angle_between_points(a, b, c):
"""
Return the angle between the lines ab and bc, <abc.
This function always returns an angle less than 180degrees.
The orientation of the lines can be used to determine which
side of the lines this angle is formed from.
Returns np.nan if ab and bc are the same point.
Parameters
----------
a : ndarray
The first point
b : ndarray
The second point
c : the last point
Returns
-------
float
The angle in degrees
"""
ba = a - b
bc = c - b
length_ba = np.linalg.norm(ba)
length_bc = np.linalg.norm(bc)
if length_bc != 0 and length_ba != 0:
cosine_angle = np.dot(ba, bc) / (length_ba * length_bc)
angle = np.arccos(cosine_angle)
else:
logging.error(
"Angle between points: Two points are the same" +
" can't measure angle as a result")
angle = np.NAN
return np.degrees(angle)
[docs]def centre_of_mass(co_ords, weights, axis=0):
"""
Calculate the co-ordinate centre of mass for a 2D system of particles.
The particles all have co-ords and weights.
Parameters
----------
co_ords : ndarray
Array of co-ordinate positions,
assumed to have co_ords.shape[axis] co-ordinates.
weights : ndarray
Array of corresponding weights
axis : int, default 0
The axis along which the co-ordinates are specified, expected 0 or 1
Returns
-------
ndarray
Co-ordinate of the centre of mass
"""
shape = co_ords.shape
if axis == 0:
weighted = np.multiply(
co_ords,
np.repeat(weights, shape[1]).reshape(shape))
elif axis == 1:
weighted = np.multiply(
co_ords,
np.tile(weights, shape[0]).reshape(shape))
else:
logging.error("centre_of_mass: Expected axis to be 0 or 1")
return np.sum(weighted, axis=axis) / np.sum(weights)
[docs]def smooth_1d(x, kernel_type='b', kernel_size=5, axis=0, **kwargs):
"""
Filter a 1D array or signal.
Parameters
----------
x : ndarray
Array or signal to be filtered.
If matrix, each column or row is filtered
individually depending on 'dir' parameter that takes either '0' for along-column
and '1' for along-row filtering.
kernel_type : str
'b' for moving average or box filter. 'g' for Gaussian filter.
'hs' for Heaviside filter
'hg' for half-Gaussian filter
kernel_size : int
Box size for box filter and sigma for Gaussian filter
axis : int
Defaults to 0. The axis along which to smooth matrices.
Returns
-------
ndarray
Filtered data
"""
def pad_and_convolve(xx, kernel):
npad = len(kernel)
xx = np.pad(xx, (npad, npad), 'edge')
yy = np.convolve(xx, kernel, mode='same')
return yy[npad:-npad]
x = np.array(x)
half_width = kernel_size / 2
xx = np.arange(-half_width, half_width + 1, 1)
if kernel_type == 'g':
half_width = kernel_size / 2
xx = np.arange(-half_width, half_width + 1, 1)
sigma = kernel_size / (2 * 2.7)
kernel = np.exp(-(xx**2) / (2 * sigma**2)) / \
(np.sqrt(2 * np.pi) * sigma)
elif kernel_type == 'b':
kernel = np.ones(kernel_size) / kernel_size
elif kernel_type == 'hs':
half_width = kernel_size
xx = np.arange(-half_width, half_width + 1, 1)
sigma = kernel_size / 2 / np.sqrt(3)
kernel = (0.5 / (np.sqrt(3) * sigma)) * \
(xx < 2 * np.sqrt(3) * sigma and xx >= 0)
elif kernel_type == 'hg':
half_width = kernel_size
xx = np.arange(-half_width, half_width + 1, 1)
sigma = 2 * kernel_size / 2 / np.sqrt(3)
kernel = np.exp(-(xx**2) / (2 * sigma**2)) / \
(np.sqrt(2 * np.pi) * sigma)
kernel[xx < 0] = 0
kernel[xx > 0] = 2 * kernel[xx > 0]
result = np.apply_along_axis(
lambda xx: pad_and_convolve(xx, kernel), axis, x)
return result
[docs]def smooth_2d(x, filttype='b', filtsize=5):
"""
Filter a 2D array or signal.
Parameters
----------
x : ndarray
Matrix to be filtered
filttype : str
'b' for moving average or box filter. 'g' for Gaussian filter.
filtsize
Box size for box filter and sigma for Gaussian filter
Returns
-------
smoothX
Filtered matrix
"""
nanInd = np.isnan(x)
x[nanInd] = 0
if filttype == 'g':
halfwid = np.round(3 * filtsize)
xx, yy = np.meshgrid(np.arange(-halfwid, halfwid + 1, 1),
np.arange(-halfwid, halfwid + 1, 1), copy=False)
# /(2*np.pi*filtsize**2) # This is the scaling used before;
filt = np.exp(-(xx**2 + yy**2) / (2 * filtsize**2))
# But tested with ones(50, 50); gives a hogher value
filt = filt / np.sum(filt)
elif filttype == 'b':
filt = np.ones((filtsize, filtsize)) / filtsize**2
smoothX = sg.convolve2d(x, filt, mode='same')
smoothX[nanInd] = np.nan
return smoothX
[docs]def find_true_ranges(arr, truth_arr, min_range, return_idxs=False):
"""
Return a list of ranges where truth values occur in sorted array.
Also return the corresponding values from the input array.
Note
----
The input array arr is assumed to be a sorted list.
Parameters
----------
arr : ndarray
list of values to get ranges from, equal in length to truth_arr
truth_arr : ndarray
list of truth values to make the ranges
min_range : int or float
the minimum length of range
Returns
-------
list
A list of tuples, ranges in arr where truth values are truth_arr
"""
in_range = False
ranges = []
range_idxs = []
for idx, b in enumerate(truth_arr):
if b and not in_range:
in_range = True
range_start = arr[idx]
range_start_idx = idx
if not b and in_range:
in_range = False
range_end = arr[idx - 1]
range_end_idx = idx
if range_end - range_start >= min_range:
ranges.append((range_start, range_end))
for i in range(range_start_idx, range_end_idx):
range_idxs.append(i)
if not return_idxs:
return ranges
else:
return ranges, range_idxs
[docs]def find_peaks(data, **kwargs):
"""
Return the peaks in the data based on gradient calculations.
Parameters
----------
kwargs
start : int
Where to start looking for peaks in the data, default 0
end : int
Where to stop looking for peaks in the data, default data.size - 1
thresh : float
Don't consider any peaks with a value below this, default 0
"""
data = np.array(data)
slope = np.diff(data)
start_at = kwargs.get('start', 0)
end_at = kwargs.get('end', slope.size)
thresh = kwargs.get('thresh', 0)
peak_loc = [j for j in np.arange(start_at, end_at - 1)
if slope[j] > 0 and slope[j + 1] <= 0]
peak_val = [data[peak_loc[i]] for i in range(0, len(peak_loc))]
valid_loc = [
i for i in range(0, len(peak_loc)) if peak_val[i] >= thresh]
if len(valid_loc) == 0:
return []
peak_val, peak_loc = zip(*((peak_val[i], peak_loc[i]) for i in valid_loc))
return np.array(peak_val), np.array(peak_loc)
[docs]def log_exception(ex, more_info=""):
"""
Log an expection and additional info.
Parameters
----------
ex : Exception
The python exception that occured
more_info : str, optional
Additional string to log. Default is "".
Returns
-------
None
"""
default_loc = os.path.join(
os.path.expanduser("~"), ".nc_saved", "nc_caught.txt")
now = datetime.datetime.now()
# tb = traceback.format_tb(ex.__traceback__)
make_dir_if_not_exists(default_loc)
with open(default_loc, "a+") as f:
f.write("\n----------Caught Exception at {}----------\n".format(now))
traceback.print_exc(file=f)
logging.error(
"{} failed with caught exception.\nSee {} for more information.".format(
more_info, default_loc), exc_info=False)
# template = "{0} because exception of type {1} occurred. Arguments:\n{2!r}"
# message = template.format(more_info, type(ex).__name__, ex.args)
# logging.error(message)
[docs]def window_rms(a, window_size, mode="same"):
"""
Calculate the rms envelope, similar to matlab.
Parameters
----------
a : ndarray
The input signal to envelope.
window_size : int
The length of the window to convolve the signal with.
mode : str
The mode determines how many points are output
mode "valid" will have no border effects
mode "same" will produce a value for each input
See np.convolve for more information.
Returns
-------
np.ndarray
The RMS envelope of the signal
"""
a2 = np.power(a, 2)
window = np.ones(window_size) / float(window_size)
return np.sqrt(np.convolve(a2, window, mode))
[docs]def distinct_window_rms(a, N):
"""
Calculate the rms of an array in windows of N data points.
Parameters
----------
a : np.ndarray
The input array to compute the RMS of.
N : int
The length of the window to compute RMS in.
Returns
-------
list
The RMS in each window.
"""
a = np.array(a)
a = np.square(a) / float(N)
rms_array = []
rms = 0
# For now, just throw away the last window if it does not fit
for idx, point in enumerate(a):
rms += point
if idx % N == N - 1:
rms_array.append(np.sqrt(rms))
rms = 0
return rms_array
[docs]def static_vars(**kwargs):
"""Return decorator to create a function with static variables."""
def decorate(func):
for k in kwargs:
setattr(func, k, kwargs[k])
return func
return decorate
[docs]@static_vars(colorcells=[])
def get_axona_colours(index=None):
"""
Create Axona cell colours.
Parameters
----------
index : int
Optional integer to get colours at
Returns
-------
list | tuple
A list of colours as rgb tuples with values in 0 to 1.
Or a single rgb tuple if index is specified.
"""
if len(get_axona_colours.colorcells) == 0:
# create Axona cell colours if don't exist
get_axona_colours.colorcells.append((0, 0, 200 / 255))
get_axona_colours.colorcells.append((80 / 255, 1, 80 / 255))
get_axona_colours.colorcells.append((1, 0, 0))
get_axona_colours.colorcells.append((245 / 255, 0, 1))
get_axona_colours.colorcells.append((75 / 255, 200 / 255, 255 / 255))
get_axona_colours.colorcells.append((0 / 255, 185 / 255, 0 / 255))
get_axona_colours.colorcells.append((255 / 255, 185 / 255, 50 / 255))
get_axona_colours.colorcells.append((0 / 255, 150 / 255, 175 / 255))
get_axona_colours.colorcells.append((150 / 255, 0 / 255, 175 / 255))
get_axona_colours.colorcells.append((170 / 255, 170 / 255, 0 / 255))
get_axona_colours.colorcells.append((200 / 255, 0 / 255, 0 / 255))
get_axona_colours.colorcells.append((255 / 255, 255 / 255, 0 / 255))
get_axona_colours.colorcells.append((140 / 255, 140 / 255, 140 / 255))
get_axona_colours.colorcells.append((0 / 255, 255 / 255, 255 / 255))
get_axona_colours.colorcells.append((255 / 255, 0 / 255, 160 / 255))
get_axona_colours.colorcells.append((175 / 255, 75 / 255, 75 / 255))
get_axona_colours.colorcells.append((255 / 255, 155 / 255, 175 / 255))
get_axona_colours.colorcells.append((190 / 255, 190 / 255, 190 / 255))
get_axona_colours.colorcells.append((255 / 255, 255 / 255, 75 / 255))
get_axona_colours.colorcells.append((154 / 255, 205 / 255, 50 / 255))
get_axona_colours.colorcells.append((255 / 255, 99 / 255, 71 / 255))
get_axona_colours.colorcells.append((0 / 255, 255 / 255, 127 / 255))
get_axona_colours.colorcells.append((255 / 255, 140 / 255, 0 / 255))
get_axona_colours.colorcells.append((32 / 255, 178 / 255, 170 / 255))
get_axona_colours.colorcells.append((255 / 255, 69 / 255, 0 / 255))
get_axona_colours.colorcells.append((240 / 255, 230 / 255, 140 / 255))
get_axona_colours.colorcells.append((100 / 255, 145 / 255, 237 / 255))
get_axona_colours.colorcells.append((255 / 255, 218 / 255, 185 / 255))
get_axona_colours.colorcells.append((153 / 255, 50 / 255, 204 / 255))
get_axona_colours.colorcells.append((250 / 255, 128 / 255, 114 / 255))
if index is None:
return get_axona_colours.colorcells
else:
if index >= len(get_axona_colours.colorcells):
logging.error("Passed colour index out of range")
return
return get_axona_colours.colorcells[index]
[docs]def has_ext(filename, ext, case_sensitive_ext=False):
"""
Check if the filename ends in the extension.
Parameters
----------
filename : str
The name of the file
ext : str
The extension, may have leading dot (e.g txt == .txt)
case_sensitive_ext: bool, optional. Defaults to False,
Whether to match the case of the file extension
Returns
-------
bool indicating if the filename has the extension
"""
if ext is None:
return True
if ext[0] != ".":
ext = "." + ext
if case_sensitive_ext:
return filename[-len(ext):] == ext
else:
return filename[-len(ext):].lower() == ext.lower()
[docs]def get_all_files_in_dir(
in_dir, ext=None, return_absolute=True,
recursive=False, verbose=False, re_filter=None,
case_sensitive_ext=False):
"""
Get all files in the directory with the given extension.
Parameters
----------
in_dir : str
The absolute path to the directory
ext : str, optional. Defaults to None.
The extension of files to get.
return_absolute : bool, optional. Defaults to True.
Whether to return the absolute filename or not.
recursive : bool, optional. Defaults to False.
Whether to recurse through directories.
verbose : bool, optional. Defaults to False.
Whether to print the files found.
re_filter : str, optional. Defaults to None
a regular expression used to filter the results
case_sensitive_ext : bool, optional. Defaults to False,
Whether to match the case of the file extension
Returns
-------
list
A list of filenames
"""
if not isdir(in_dir):
print("Non existant directory " + str(in_dir))
return []
def match_filter(f):
if re_filter is None:
return True
search_res = re.search(re_filter, f)
return search_res is not None
def ok_file(root_dir, f):
good_ext = has_ext(f, ext, case_sensitive_ext=case_sensitive_ext)
good_file = isfile(join(root_dir, f))
good_filter = match_filter(f)
return good_ext and good_file and good_filter
def convert_to_path(root_dir, f):
return join(root_dir, f) if return_absolute else f
if verbose:
print("Adding following files from {}".format(in_dir))
if recursive:
onlyfiles = []
for root, _, filenames in os.walk(in_dir):
start_root = root[:len(in_dir)]
if len(root) == len(start_root):
end_root = ""
else:
end_root = root[len(in_dir + os.sep):]
for filename in filenames:
filename = join(end_root, filename)
if ok_file(start_root, filename):
to_add = convert_to_path(start_root, filename)
if verbose:
print(to_add)
onlyfiles.append(to_add)
else:
onlyfiles = [
convert_to_path(in_dir, f) for f in sorted(listdir(in_dir))
if ok_file(in_dir, f)
]
if verbose:
for f in onlyfiles:
print(f)
if verbose:
print()
return onlyfiles
[docs]def make_dir_if_not_exists(location):
"""Make directory structure for given location."""
os.makedirs(os.path.dirname(location), exist_ok=True)
[docs]def remove_extension(filename, keep_dot=True, return_ext=False):
"""
Return the filename without the extension.
Very similar to os.path.splitext()
Parameters
----------
filename : str
The filename to remove extension from.
keep_dot : bool
Whether to return filename + ".".
return_ext : bool
Whether to return filename or filename, ext.
Returns
-------
str | tuple
str if return_ext is False, the filename with no ext
(str, str) if return_ext is True, (filename, ext)
"""
modifier = 0 if keep_dot else 1
ext = filename.split(".")[-1]
remove = len(ext) + modifier
if return_ext:
return filename[:-remove], ext
else:
return filename[:-remove]
[docs]class RecPos:
"""
Read .pos file.
Work in progress and does not support head direction.
TODO
----
Read different numbers of LEDs
Verbose file reading (prints info like number of untracked points)
Attributes
----------
pos_file : str
The path to the position file.
x : np.ndarray
The x position data
y : np.ndarray
The y position data
speed : np.ndarray
The speed in cm/s
head_direction : np.ndarray
The head direction data
raw_position : dict
The raw position data decoded from .pos file.
Parameters
----------
file_name : str
The path to the .set file or .pos file to load the data from
load : bool
If file_name is passed, load the data from this
"""
def __init__(self, file_name=None, load=True):
"""See help(RecPos)."""
self.pos_file = ""
self.x = np.array([])
self.y = np.array([])
self.speed = np.array([])
self.head_direction = np.array([])
self.raw_position = {}
if file_name is not None:
self.set_file(file_name)
if load:
self.load()
[docs] def set_file(self, file_name):
"""Set the input file - can be .pos or .set"""
file_directory, file_basename = os.path.split(file_name)
file_tag, file_extension = os.path.splitext(file_basename)
if file_extension != ".pos":
self.pos_file = os.path.join(file_directory, file_tag + ".pos")
else:
self.pos_file = file_name
[docs] def load(self, file_name=None):
"""Load data, optionally from given file name."""
if file_name is not None:
self.set_file(file_name)
self.load_raw()
self.calculate_position()
self.calculate_speed()
self.calculate_angular()
[docs] def load_raw(self):
"""Load raw position data."""
self.bytes_per_sample = 20 # Axona daqUSB manual
if os.path.isfile(self.pos_file):
with open(self.pos_file, "rb") as f:
while True:
line = f.readline()
try:
line = line.decode("latin-1")
except BaseException:
break
if line == "":
break
if line.startswith("trial_date"):
# Blank pos file
if line.strip() == "trial_date":
logging.error("No position data.")
return
# date = " ".join(line.replace(",", " ").split()[1:])
# if line.startswith("num_colours"):
# colors = int(line.split()[1])
if line.startswith("min_x"):
self.min_x = int(line.split()[1])
if line.startswith("max_x"):
self.max_x = int(line.split()[1])
if line.startswith("min_y"):
self.min_y = int(line.split()[1])
if line.startswith("max_y"):
self.max_y = int(line.split()[1])
if line.startswith("window_min_x"):
self.window_min_x = int(line.split()[1])
if line.startswith("window_max_x"):
self.window_max_x = int(line.split()[1])
if line.startswith("window_min_y"):
self.window_min_y = int(line.split()[1])
if line.startswith("window_max_y"):
self.window_max_y = int(line.split()[1])
if line.startswith("bytes_per_timestamp"):
self.bytes_per_tstamp = int(line.split()[1])
if line.startswith("bytes_per_coord"):
self.bytes_per_coord = int(line.split()[1])
if line.startswith("pixels_per_metre"):
self.pixels_per_metre = int(float(line.split()[1]))
self.pixels_per_cm = self.pixels_per_metre / 100.0
if line.startswith("num_pos_samples"):
self.total_samples = int(line.split()[1])
if line.startswith("pos_format"):
info = line.split(" ")[-1]
if info[:-2] != "t,x1,y1,x2,y2,numpix1,numpix2":
logging.error(
".pos reading only supports 2-spot mode currently")
print(info[:-2])
print("t,x1,y1,x2,y2,numpix1,numpix2")
return
if line.startswith("data_start"):
break
f.seek(0, 0)
header_offset = []
while True:
try:
buff = f.read(10).decode("UTF-8")
except BaseException:
break
if buff == "data_start":
header_offset = f.tell()
break
else:
f.seek(-9, 1)
if not header_offset:
print("Error: data_start marker not found!")
else:
f.seek(header_offset, 0)
byte_buffer = np.fromfile(f, dtype="uint8")
big_spotx = np.zeros([self.total_samples, 1])
big_spoty = np.zeros([self.total_samples, 1])
little_spotx = np.zeros([self.total_samples, 1])
little_spoty = np.zeros([self.total_samples, 1])
# pos format: t,x1,y1,x2,y2,numpix1,numpix2 => 20 bytes
for i, k in enumerate(
np.arange(0, self.total_samples * 20, 20)
): # Extract bytes from 20 bytes words
big_spotx[i] = int(
256 * byte_buffer[k + 4] + byte_buffer[k + 5]
) # 4,5 bytes for big LED x
big_spoty[i] = int(
256 * byte_buffer[k + 6] + byte_buffer[k + 7]
) # 6,7 bytes for big LED x
little_spotx[i] = int(
256 * byte_buffer[k + 8] + byte_buffer[k + 9]
)
little_spoty[i] = int(
256 * byte_buffer[k + 10] + byte_buffer[k + 11]
)
self.raw_position = {
"big_spotx": big_spotx,
"big_spoty": big_spoty,
"little_spotx": little_spotx,
"little_spoty": little_spoty,
}
else:
print(f"No pos file found for file {self.pos_file}")
# Methods
[docs] def get_cam_view(self):
self.cam_view = {
"min_x": self.min_x,
"max_x": self.max_x,
"min_y": self.min_y,
"max_y": self.max_y,
}
return self.cam_view
[docs] def get_window_view(self):
try:
self.windows_view = {
"window_min_x": self.window_min_x,
"window_max_x": self.window_max_x,
"window_min_y": self.window_min_y,
"window_max_y": self.window_max_y,
}
return self.windows_view
except BaseException:
print("No window view")
[docs] def get_pixel_per_metre(self):
return self.pixels_per_metre
[docs] def get_raw_pos(self):
bigx = [value[0] for value in self.raw_position["big_spotx"]]
bigy = [value[0] for value in self.raw_position["big_spoty"]]
smallx = [value[0] for value in self.raw_position["little_spotx"]]
smally = [value[0] for value in self.raw_position["little_spoty"]]
return bigx, bigy, smallx, smally
[docs] def filter_max_speed(self, x, y, max_speed=4):
tmp_x = x.copy()
tmp_y = y.copy()
# max speed * distance (m) / 50 samples (s)
threshold = max_speed * self.pixels_per_metre * 50
for i in range(1, len(tmp_x)):
distance = math.sqrt((x[i] - x[i - 1]) ** 2 + (y[i] - y[i - 1]) ** 2)
if distance > threshold:
tmp_x[i] = np.nan
tmp_y[i] = np.nan
return tmp_x, tmp_y
[docs] def get_position(self, raw=False):
if not raw:
return self.x, self.y
else:
return self.get_raw_pos()
[docs] def calculate_position(self, raw=False):
try:
count_missing = 0
bxx, sxx = [], []
byy, syy = [], []
bigx = [value[0] for value in self.raw_position["big_spotx"]]
bigy = [value[0] for value in self.raw_position["big_spoty"]]
smallx = [value[0] for value in self.raw_position["little_spotx"]]
smally = [value[0] for value in self.raw_position["little_spoty"]]
for bx, sx in zip(bigx, smallx): # Try to clean single blocked LED x
if bx == 1023 and sx != 1023:
bx = sx
elif bx != 1023 and sx == 1023:
sx = bx
elif bx == 1023 and sx == 1023:
count_missing += 1
bx = np.nan
sx = np.nan
bxx.append(bx)
sxx.append(sx)
for by, sy in zip(bigy, smally): # Try to clean single blocked LED y
if by == 1023 and sy != 1023:
by = sy
elif by != 1023 and sy == 1023:
sy = by
elif by == 1023 and sy == 1023:
by = np.nan
sy = np.nan
byy.append(by)
syy.append(sy)
# Remove coordinates with max_speed > 4ms
bxx, byy = self.filter_max_speed(bxx, byy)
sxx, syy = self.filter_max_speed(sxx, syy)
# Interpolate missing values
bxx = (pd.Series(bxx).astype(float)).interpolate("linear")
sxx = (pd.Series(sxx).astype(float)).interpolate("linear")
byy = (pd.Series(byy).astype(float)).interpolate("linear")
syy = (pd.Series(syy).astype(float)).interpolate("linear")
if raw:
return [(bxx, byy), (sxx, syy)]
# Average both LEDs
x = list((bxx + sxx) / 2)
y = list((byy + syy) / 2)
# Boxcar filter 400 ms
# sample rate = 20 ms
b = int(400 / 20)
kernel = np.ones(b) / b
def pad_and_convolve(xx, kernel):
npad = len(kernel)
xx = np.pad(xx, (npad, npad), "edge")
yy = np.convolve(xx, kernel, mode="same")
return yy[npad:-npad]
x = pad_and_convolve(x, kernel)
y = pad_and_convolve(y, kernel)
if np.count_nonzero(np.isnan(x)) != 0:
num_start_nan = 0
for val in x:
if np.isnan(val):
num_start_nan += 1
else:
break
if num_start_nan == len(x):
self.x = np.zeros(len(x))
self.y = np.zeros(len(y))
return self.x, self.y
if num_start_nan != 0:
np.put(x, np.arange(0, num_start_nan, 1), x[num_start_nan])
np.put(y, np.arange(0, num_start_nan, 1), y[num_start_nan])
if np.count_nonzero(np.isnan(x)) != 0:
num_end_nan = 0
for val in x[::-1]:
if np.isnan(val):
num_end_nan += 1
else:
break
from_end = len(x) - num_end_nan
back = num_end_nan + 1
if num_end_nan != 0:
np.put(x, np.arange(from_end, len(x), 1), x[-back])
np.put(y, np.arange(from_end, len(x), 1), y[-back])
self.x = x / self.pixels_per_cm
self.y = y / self.pixels_per_cm
return x, y
except BaseException:
print(f"No position information found in {self.pos_file}")
[docs] def get_speed(self):
return self.speed
[docs] def calculate_speed(self, num_samples=5, smooth_size=5, smooth=False):
"""
Calculate the speed.
Performs as follows:
1. Get the box smoothed position data.
2. Calculate the speed at 10Hz (real sample rate is 50Hz).
2a. Do this calculating the speed at time x by using positions at
time x + 0.1, and x - 0.1. Want the real time point in the middle.
4. Interpolate these values to get speed at every time point(50Hz)
5. Smooth the interpolated speeds to remove bumps around sample times.
"""
x, y = self.get_position()
def pad_and_convolve(xx, kernel):
npad = len(kernel)
xx = np.pad(xx, (npad, npad), "edge")
yy = np.convolve(xx, kernel, mode="same")
return yy[npad:-npad]
speed = [0]
s_rate = num_samples # 50 Hz is too fine grained
t_rate = 0.02 * s_rate
duration = len(x) * 0.02
for i in range(s_rate * 3 // 2, len(x), s_rate):
cm_dist = math.sqrt(
(x[i] - x[i - s_rate]) ** 2 + (y[i] - y[i - s_rate]) ** 2
)
# (pixel/s) - 300 pixels per metre * 100 (cm/s)
cms_speed = cm_dist / t_rate
speed.append(cms_speed)
xp = np.array(
[0.0] + [0.02 * i for i in range(s_rate, len(x) - (s_rate // 2), s_rate)]
)
xs = np.arange(0, duration, 0.02)
kernel_size = smooth_size
interp_speed = np.interp(xs, xp, speed)
if smooth:
kernel = np.ones(kernel_size) / kernel_size
interp_speed = pad_and_convolve(interp_speed, kernel)
self.speed = interp_speed
return interp_speed
[docs] def calculate_angular(self):
bigx, bigy, smallx, smally = self.get_position(raw=True)
angles = np.zeros(len(bigx))
# for i in range(len(bigx)):
# A = np.array([bigx[i], bigy[i]]) + np.array([1, 0])
# B = np.array([bigx[i], bigy[i]])
# C = np.array([smallx[i], smally[i]])
# try:
# angle = angle_between_points(A, B, C)
# angles[i] = angle
self.angular = angles
[docs] def get_angular_pos(self):
return self.angular