Over the last few years, it has been exciting to see the xarray project evolve, add new functionality, and mature. This post is an attempt at giving xarray another visit to see how it could integrate into electrophysiology workflows.
A quick background on our data
It is common in neuroscience to ask individuals to perform a task over and over again. You record the activity in the brain each time they perform the task (called an “epoch” or a “trial”). Time is recorded relative to some onset when the task begins; that is t==0. The result is usually a matrix of epochs x channels x time. You can do a lot of stuff with this data, but our task in this post is to detect changes in neural activity at trial onset (t==0).
In our case, we’ve got a small dataset from an old paper of mine. The repository contains several tutorial notebooks and sample data to describe predictive modeling in cognitive neuroscience. You can find the repository here. The task that individuals were performing was passively listening to spoken sentences through a speaker. While they did this, we recorded electrical activity at the surface of their brain (these were surgical patients, and had implanted electrodes under their scalp).
In the Feature Extraction notebook, I covered how to do some simple data manipulation and feature extraction with timeseries analysis. Let’s try to re-create some of the main steps in that tutorial, but now using xarray as an in-memory structure for our data.
Note: The goal here is to learn a bit about xarray more so than to discuss ECoG modeling, so I’ll spend more time talking about my thoughts on the various functions/methods/etc. in xarray than talking about neuroscience.
In this post, we’ll perform a few common processing and extraction steps. The goal is to do a few munging operations that require manipulating data and visualizing simple statistics.
# Imports we'll use later
import mne
import numpy as np
import matplotlib.pyplot as plt
from download import download
import os
from sklearn.preprocessing import scale
import xarray as xr
xr.set_options(display_style="html")
import warnings
warnings.simplefilter('ignore')
%matplotlib inline
We’ll load the data from my GitHub repository (probably not the most efficient way to store or retrieve the data, but hey, this was 3 years ago :-) ).
url_epochs = "https://github.com/choldgraf/paper-encoding_decoding_electrophysiology/blob/master/raw_data/ecog-epo.fif?raw=true"
path_data = download(url_epochs, './ecog-epo.fif', replace=True)
ecog = mne.read_epochs(path_data, preload=True)
os.remove(path_data)
file_sizes: 0%| | 0.00/8.36M [00:00<?, ?B/s]
Downloading data from https://raw.githubusercontent.com/choldgraf/paper-encoding_decoding_electrophysiology/master/raw_data/ecog-epo.fif?raw=true (8.0 MB)
file_sizes: 100%|██████████████████████████| 8.36M/8.36M [00:00<00:00, 12.5MB/s]
Successfully downloaded file to ./ecog-epo.fif
Reading ./ecog-epo.fif ...
Isotrak not found
Found the data of interest:
t = -1500.00 ... 5996.67 ms
0 CTF compensation matrices available
29 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Here’s what the raw data looks like: each horizontal line is electrical activity in a channel over time. The faint vertical green lines show the onset of each trial (they are concatenated together, but in reality there’s a bit of time between trials). Hopefully, this will be one of the last times we use MNE’s own data objects.
_ = ecog.plot(scalings='auto', n_epochs=5, n_channels=10)
Converting to xarray
First off, we’ll define a helper function that converts the MNE Epochs object into an xarray DataArray object. DataArrays provide an N-dimensional representation of data, with the option to include a lot of extra metadata.
DataArrays are useful because you can include information about each dimension of the data. For example, we can tell our DataArray the name, values, and units of each dimension. In our case, one dimension is “time”, so we can label it as such.
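Here is a minimal sketch of what named dimensions and metadata give us, without needing MNE or the download. It runs on synthetic random data. The channel names, sampling rate, and units are made up for illustration. The `time` coordinate is given as a `(dims, values, attrs)` tuple so that it can carry its own units.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for an epochs x channels x time recording (random values)
n_epochs, n_channels, sfreq = 3, 2, 100.0
times = np.arange(-50, 100) / sfreq  # 150 samples, seconds relative to onset
data = np.random.default_rng(0).normal(size=(n_epochs, n_channels, times.size))

toy = xr.DataArray(
    data,
    dims=["epoch", "channel", "time"],
    coords={
        "epoch": np.arange(n_epochs),
        "channel": ["ch0", "ch1"],                # hypothetical channel names
        "time": ("time", times, {"units": "s"}),  # (dims, values, attrs) tuple
    },
    name="toy_ecog",
    attrs={"sfreq": sfreq, "units": "uV"},  # free-form metadata lives in .attrs
)

print(toy.dims)           # ('epoch', 'channel', 'time')
print(toy.sizes["time"])  # 150
print(toy.sfreq)          # attrs are also reachable as attributes: 100.0
```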
def epochs_to_dataarray(epochs):
    """A simple function to convert an Epochs object to a DataArray."""
    da = xr.DataArray(
        epochs.get_data(),  # shape: (n_epochs, n_channels, n_times)
        dims=['epoch', 'channel', 'time'],
        coords={
            'time': epochs.times,
            'channel': epochs.ch_names,
            'epoch': range(len(epochs)),
        },
        name='Sample dataset',
        attrs=dict(epochs.info),
    )
    return da
Just look at all the metadata that we were able to pack into the DataArray. Almost all of MNE’s metadata fit nicely into .attrs.
# There's quite a lot of output, so keep scrolling down!
da = epochs_to_dataarray(ecog)
da
The data consists of many trials, channels, and timepoints. Let’s start by selecting a time region within each trial that we can visualize more cleanly.
Subsetting our data with da.sel
In xarray, we select items with the sel and isel methods. These behave much like the pandas loc and iloc methods (sel selects by coordinate label, isel by integer position). However, because we have named dimensions, we can specify them directly in our call.
# Keep only the timepoints between -1 and 3 seconds to make visualization cleaner
da = da.sel(time=slice(-1, 3))
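Here is a small self-contained sketch of the difference between the two selectors, using a made-up 1-D signal. One detail is worth knowing: label-based slices in `sel` include both endpoints, while positional slices in `isel` exclude the stop, as in NumPy.

```python
import numpy as np
import xarray as xr

# Hypothetical 1-D signal sampled at 4 Hz between -1 and 1 s
times = np.arange(-4, 5) / 4.0  # [-1.0, -0.75, ..., 1.0]
sig = xr.DataArray(np.arange(times.size), dims=["time"], coords={"time": times})

# .sel uses coordinate labels; label slices include BOTH endpoints
print(sig.sel(time=slice(-0.5, 0.5)).values)  # [2 3 4 5 6]
# .isel uses integer positions; positional slices exclude the stop (like NumPy)
print(sig.isel(time=slice(2, 6)).values)      # [2 3 4 5]
# Nearest-neighbour lookup for a label that isn't exactly on the grid
print(sig.sel(time=0.1, method="nearest").values)  # 4
```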
Now let’s calculate the average across all epochs for each electrode/time point.
This is a reduction of our data array, in that it reduces the number of dimensions.
Xarray has many of the same statistical methods that NumPy does. An interesting
twist is that you can specify named dimensions instead of simply an axis=<integer>
argument. In addition, we’ll choose the colors that we’ll use for cycling through
our channels - because we can quickly reference the channels axis by name, we don’t
need to remember which axis corresponds to channels.
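As a quick sanity check on synthetic data (random values, arbitrary sizes), a named reduction gives the same numbers as the positional NumPy call:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
arr = xr.DataArray(rng.normal(size=(5, 3, 20)), dims=["epoch", "channel", "time"])

# Reduce by name instead of by axis number
evoked = arr.mean(dim="epoch")
print(evoked.dims)  # ('channel', 'time')

# Same numbers as the positional NumPy call, without remembering that epochs are axis 0
assert np.allclose(evoked.values, arr.values.mean(axis=0))

# Several dimensions can be reduced at once, and sizes are looked up by name too
print(arr.mean(dim=["epoch", "time"]).dims)  # ('channel',)
print(arr.sizes["channel"])                  # 3
```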
fig, ax = plt.subplots(figsize=(15, 5))
n_channels = da['channel'].shape[0]
ax.set_prop_cycle(color=plt.cm.viridis(np.linspace(0, 1, n_channels)))
da.mean(dim='epoch').plot.line(x='time', hue='channel')
ax.get_legend().remove()
It doesn’t look like much is going on. Let’s see if we can clean it up a bit.
De-meaning the data with da.where
First off - we’ll subtract the “pre-baseline mean” from each trial. This makes it easier to visualize how each channel’s activity changed at time == 0.
To accomplish this we’ll use da.where. This takes a boolean mask and broadcasts it against the data by matching dimension names. It returns a DataArray in which values where the mask is False are replaced with NaN, and all other values are unchanged. We can use this to calculate the mean of each channel / epoch using only the pre-baseline timepoints.
# This returns a version of the data array with NaNs where the query is False
# The dimensions will intelligently broadcast
prebaseline_mean = da.where(da.time < 0).mean(dim='time')
da_demeaned = da - prebaseline_mean
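To make the broadcasting concrete, here is a toy version with two epochs and four timepoints, using hand-picked values. The mask depends only on time, yet `where()` applies it to every epoch. The per-epoch baseline then lines up with the data by dimension name when we subtract.

```python
import numpy as np
import xarray as xr

# Toy data: 2 epochs x 4 timepoints, with a different offset per epoch
times = np.array([-0.2, -0.1, 0.1, 0.2])
toy = xr.DataArray(
    [[1.0, 3.0, 10.0, 12.0],
     [5.0, 7.0, 20.0, 22.0]],
    dims=["epoch", "time"],
    coords={"time": times},
)

# The mask has only a 'time' dimension; where() broadcasts it across epochs
masked = toy.where(toy.time < 0)  # post-onset values become NaN

# mean() skips NaNs by default for floats, giving one baseline value per epoch
baseline = masked.mean(dim="time")
print(baseline.values)  # [2. 6.]

# Subtracting an ('epoch',) array from an ('epoch', 'time') array aligns by name
demeaned = toy - baseline
```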
Now we can visualize the baseline-corrected data:
fig, ax = plt.subplots(figsize=(15, 5))
ax.set_prop_cycle(color=plt.cm.viridis(np.linspace(0, 1, da['channel'].shape[0])))
da_demeaned.mean(dim='epoch').plot.line(x='time', hue='channel')
ax.get_legend().remove()
Hmmm, there still doesn’t seem to be much going on (that channel down at the bottom looks noisy to me, rather than having a meaningful signal) so let’s transform this signal into something with a bit more SNR to it.
Extracting a more useful feature with xr.apply_ufunc
Without going into too much detail on the neuroscience, iEEG data is particularly useful because it carries information about neuronal activity in the higher-frequency parts of the signal (AKA, parts of the electrical signal that change very quickly but have very low amplitude). To pull that out, we’ll do the following:
- Band-pass filter the signal to a high-frequency range, which will remove all the slow-moving components.
- Calculate the envelope of the signal, which will tell us the power of high-frequency activity over time.
Band-pass filtering the signal
MNE has a lot of nice functions for filtering a timeseries, and most of these operate on NumPy arrays rather than MNE objects. We’ll use xarray’s apply_ufunc function to map one of them onto our DataArray. xarray should keep track of the metadata (e.g. coordinates) and output a new DataArray with updated values.
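# Pass band (Hz) for the high-frequency activity we care about
flow = 80
fhigh = 140
# filter_data filters along the last axis, which is 'time' for our DataArray.
# Despite the variable name, this is a band-pass filter (80-140 Hz).
da_lowpass = xr.apply_ufunc(
    mne.filter.filter_data, da,
    kwargs=dict(
        sfreq=da.sfreq,
        l_freq=flow,
        h_freq=fhigh,
    )
)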
Setting up band-pass filter from 80 - 1.4e+02 Hz
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 80.00
- Lower transition bandwidth: 20.00 Hz (-6 dB cutoff frequency: 70.00 Hz)
- Upper passband edge: 140.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 145.00 Hz)
- Filter length: 99 samples (0.330 sec)
Visualizing our data, we can see all the slower fluctuations (e.g. long arcs over time) are gone.
fig, ax = plt.subplots(figsize=(15, 5))
da_lowpass.mean(dim='epoch').plot.line(x='time')
ax.get_legend().remove()
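The `apply_ufunc` call above relies on `filter_data` working along the last axis, which happens to be time for our DataArray. Below is a sketch of a slightly more defensive pattern that assumes only NumPy and xarray. Declaring time as a core dimension makes `apply_ufunc` move it to the last axis before calling the function, whatever the input’s dimension order. The `demean_last_axis` function is a hypothetical stand-in for an MNE routine.

```python
import numpy as np
import xarray as xr

def demean_last_axis(x):
    """Hypothetical stand-in for a NumPy routine that works on the last axis."""
    return x - x.mean(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
# Deliberately put 'time' FIRST to show that apply_ufunc reorders it for us
toy = xr.DataArray(
    rng.normal(size=(50, 3)),
    dims=["time", "channel"],
    coords={"time": np.arange(50) / 100.0, "channel": ["a", "b", "c"]},
    attrs={"sfreq": 100.0},
)

out = xr.apply_ufunc(
    demean_last_axis,
    toy,
    input_core_dims=[["time"]],   # move 'time' to the last axis of the input
    output_core_dims=[["time"]],  # the function returns 'time' on the last axis
    keep_attrs=True,              # carry .attrs (e.g. sfreq) over to the output
)

print(out.dims)                            # ('channel', 'time'); core dims go last
print(out["channel"].values.tolist())      # ['a', 'b', 'c']
```

In the xarray versions I’m familiar with, `apply_ufunc` drops `.attrs` unless `keep_attrs=True` is passed. Metadata like `sfreq` may therefore need to be carried over explicitly.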
Calculate the envelope of this signal with da.groupby
Next, we’ll calculate the envelope of the filtered data. This is roughly the power present in these high frequencies over time. We do so by using something called a Hilbert transform.
MNE also has a function for applying Hilbert transforms to data, but it has a quirk: it expects the data to be of a particular shape. We can work around this by using our DataArray’s groupby method. This works similarly to DataFrame.groupby: we’ll iterate through each epoch, which gives us a DataArray with shape channels x timepoints. We can then calculate the Hilbert transform on each and re-combine the results into the original shape.
Note: This can be an expensive operation depending on the number of channels/epochs and the length of each trial. This might be a good place to insert parallelization via Dask.
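def hilbert_2d(array):
    """Compute the Hilbert envelope of each row of an (n_channels, n_times) array."""
    # Work on a copy: each group can be a view of the source data,
    # so assigning in place could silently modify da_lowpass
    array = array.copy()
    for ii, channel in enumerate(array.values):
        array[ii] = mne.filter._my_hilbert(channel, envelope=True)
    return array

# Iterate over epochs; each group is a (channel x time) DataArray
da_hf_power = da_lowpass.groupby('epoch').apply(hilbert_2d)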
The output dataarray should be the exact same shape, because we haven’t done any dimensional reductions. If we take a look at the resulting data, we can see what seems to be more structure in there:
fig, ax = plt.subplots(figsize=(15, 5))
da_hf_power.mean(dim='epoch').plot.line(x='time', hue='channel')
ax.get_legend().remove()
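One alternative avoids both the per-epoch loop and MNE’s private `_my_hilbert` helper. `scipy.signal.hilbert` accepts an `axis` argument, so it can be mapped over the whole array with `apply_ufunc`. This swaps in SciPy’s function and is not what was run above; `hilbert_envelope` is a hypothetical helper name. The synthetic check uses cosines with a whole number of cycles, whose envelope should be flat and equal to each channel’s amplitude.

```python
import numpy as np
import xarray as xr
from scipy.signal import hilbert

def hilbert_envelope(da, dim="time"):
    """Amplitude envelope along `dim` via SciPy's FFT-based Hilbert transform."""
    return xr.apply_ufunc(
        lambda x: np.abs(hilbert(x, axis=-1)),  # |analytic signal| = envelope
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        keep_attrs=True,
    )

# Synthetic check: cosines with a whole number of cycles have a flat envelope
n_times = 200
t = np.arange(n_times) / n_times
amplitudes = np.array([1.0, 2.5])  # hypothetical per-channel amplitudes
signals = amplitudes[:, None] * np.cos(2 * np.pi * 10 * t)[None, :]
toy = xr.DataArray(signals, dims=["channel", "time"], coords={"time": t})

env = hilbert_envelope(toy)
print(env.max(dim="time").values)  # close to the per-channel amplitudes
```

If the data were backed by Dask arrays, `apply_ufunc`’s `dask="parallelized"` option is the usual hook for running a function like this chunk by chunk. The core `time` dimension would need to sit in a single chunk.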
Cleaning up our HFA data
Next let’s clean up this high-frequency activity (HFA) data.
Z-scoring our array
Instead of simply de-meaning the data like before, we’ll re-scale our data using the same baseline timepoints. What we’d like to do is the following:
- Calculate the mean and standard deviation across trials of all pre-baseline data values, per channel
- Z-score each channel using this mean and standard deviation
Once again we’ll use the groupby / apply combination to apply our function to subsets of the data.
# For each channel, apply a z-score that uses the mean/std of pre-baseline activity for all trials
def z_score(activity):
    """Take a DataArray and apply a z-score using the baseline"""
    baseline = activity.where(activity.time < -.1)
    return (activity - np.nanmean(baseline)) / np.nanstd(baseline)

da_hf_zscored = da_hf_power.groupby('channel').apply(z_score)
Taking a look at the result, we can see a much cleaner separation of activity for some of the channels after time==0.
fig, ax = plt.subplots(figsize=(15, 5))
da_hf_zscored.mean(dim='epoch').plot.line(x='time', hue='channel')
ax.get_legend().remove()
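The groupby/apply version works well, but the same baseline z-score can also be written without a Python-level loop. First mask everything outside the baseline. Then reduce over epoch and time together, so that one mean and one standard deviation remain per channel. Here is a sketch on synthetic data, where the channel offsets and scales are arbitrary. It assumes the same t < -0.1 s baseline as above, and `zscore_to_baseline` is a hypothetical helper name.

```python
import numpy as np
import xarray as xr

def zscore_to_baseline(x, tmax=-0.1):
    """Z-score each channel using all pre-`tmax` samples pooled across epochs."""
    baseline = x.where(x.time < tmax)
    # NaNs outside the baseline are skipped (skipna defaults to True for floats)
    mu = baseline.mean(dim=["epoch", "time"])     # one value per channel
    sigma = baseline.std(dim=["epoch", "time"])   # ddof=0, like np.nanstd
    return (x - mu) / sigma                       # broadcasts over epoch and time

rng = np.random.default_rng(1)
times = np.arange(-5, 6) / 10.0  # -0.5 ... 0.5 s
toy = xr.DataArray(
    rng.normal(loc=[[[0.0], [5.0]]], scale=[[[1.0], [3.0]]], size=(4, 2, 11)),
    dims=["epoch", "channel", "time"],
    coords={"time": times, "channel": ["a", "b"]},
)
z = zscore_to_baseline(toy)
```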
Smoothing our HFA data
Finally, let’s smooth this HFA so it has less jitter to it, and pick a smaller window that removes some of the filtering artifacts at the edges.
We’ll use the same filter_data function as before, but this time applied with the .groupby and .apply combination to show two ways of accomplishing the same thing. We’ll also use .sel to pick a subset of time for visualization.
da_hf_zscored_lowpass = da_hf_zscored.groupby('epoch').apply(
mne.filter.filter_data,
sfreq=da.sfreq,
l_freq=None,
h_freq=10,
verbose=False
)
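If MNE’s FIR filter isn’t available, a cruder option is a centered moving average with xarray’s `rolling`. This is a different technique from the low-pass filter above, since a boxcar window has a much less clean frequency response. The sampling rate, signal, and window length below are arbitrary choices on synthetic data.

```python
import numpy as np
import xarray as xr

sfreq = 100.0  # hypothetical sampling rate
times = np.arange(200) / sfreq
rng = np.random.default_rng(2)
noisy = xr.DataArray(
    np.sin(2 * np.pi * 1.0 * times) + rng.normal(scale=0.5, size=times.size),
    dims=["time"],
    coords={"time": times},
)

# ~100 ms centered window, selected by dimension name
window = int(0.1 * sfreq)
smooth = noisy.rolling(time=window, center=True).mean()

# Samples whose window runs off either edge come back as NaN
print(int(smooth.isnull().sum()))
```

Passing `min_periods` to `rolling` trades those edge NaNs for shorter windows at the edges.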
Note that selecting a subset of timepoints would be more verbose with NumPy alone. Here’s a quick comparison:
# NumPy alone: build the mask by hand and remember which axis is which
times = da_hf_zscored_lowpass['time'].values
data = da_hf_zscored_lowpass.values
mask_time = (times >= -.8) & (times <= 2.8)
epoch_dim = 0
data[..., mask_time].mean(axis=epoch_dim)
# xarray
da_hf_zscored_lowpass.sel(time=slice(-.8, 2.8)).mean(dim='epoch')
fig, ax = plt.subplots(figsize=(15, 5))
da_hf_zscored_lowpass.mean(dim='epoch').sel(time=slice(-.8, 2.8)).plot.line(x='time', hue='channel')
ax.get_legend().remove()
Now we can see there are clearly some channels that become active just after t==0.
We can reduce our dataarray to a single dimension of “mean post-baseline activity in each channel”
and convert it to a DataFrame for further processing:
# Find the channel with the most activity by first converting to a dataframe
total_activity = da_hf_zscored_lowpass.sel(time=slice(0, 2)).mean(dim=['epoch', 'time'])
total_activity = total_activity.to_dataframe()
total_activity.head()
Let’s grab the channel with maximal activation to look into a bit further.
max_chan = total_activity.squeeze().sort_values(ascending=False).index[0]
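For a single label, xarray can also do this lookup without the round trip through pandas. The sketch below uses made-up channel names and values. It assumes an xarray recent enough to have `DataArray.idxmax`.

```python
import xarray as xr

# Hypothetical per-channel mean activity (one value per channel)
activity = xr.DataArray(
    [0.2, 1.7, 0.9],
    dims=["channel"],
    coords={"channel": ["LA1", "LA2", "LA3"]},
)

# Label of the channel with the largest value
best = activity.idxmax(dim="channel")
print(best.item())  # LA2

# Or rank channels, largest first
ranked = activity.sortby(activity, ascending=False)
print(ranked["channel"].values.tolist())  # ['LA2', 'LA3', 'LA1']
```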
Time frequency analysis
As a final step, let’s expand our DataArray and add another dimension. In the above steps we specifically focused on high-frequency activity. A more common approach is to first create a spectrogram of your data to see activity across many frequencies.
To do this, we’ll use another MNE function for creating a Time-Frequency Representation, or TFR. We’ll define a range of frequencies and apply MNE’s function directly to our data. This returns a NumPy array of complex-valued coefficients with shape epochs x channels x frequencies x time.
frequencies = [2**ii for ii in np.arange(2, 9, .5)]
tfr = mne.time_frequency.tfr_array_morlet(
    da.values,
    sfreq=da.sfreq,
    freqs=frequencies,
    n_cycles=4,
)
# The output is complex-valued; its absolute value is the amplitude at each frequency
tfr = np.abs(tfr)
tfr[:2, :2, :2, :2]
array([[[[160.04103045, 160.38413909],
[171.73704543, 175.09249553]],
[[283.28699104, 285.21855726],
[241.65630528, 245.77098295]]],
[[[ 93.99546124, 94.4406537 ],
[ 78.02050045, 79.15324341]],
[[ 47.85148993, 49.90845151],
[ 53.97221461, 54.55793674]]]])
Convert this data into a DataArray with .expand_dims
Next, we’ll convert this into a DataArray by reusing the metadata from our original DataArray. We can use the expand_dims method to create a new dimension for our DataArray, which we’ll use to store frequency information. We’ll then transpose our new DataArray so that its dimension order matches the output of the MNE function, and use the copy method to create a new DataArray. By supplying the data= argument to copy, we directly insert the new data into the generated DataArray.
da_tfr = (da
.expand_dims(frequency=frequencies)
.transpose('epoch', 'channel', 'frequency', 'time')
.copy(data=np.log(tfr))
)
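The expand_dims / transpose / copy(data=...) pattern is easier to see on a tiny example. The template array, frequency list, and random stand-in for the TFR output here are all synthetic. The point is that the new array keeps the template’s coordinates and attrs while holding the new values.

```python
import numpy as np
import xarray as xr

# Template carrying the metadata we want to reuse: 2 epochs x 3 channels x 5 times
template = xr.DataArray(
    np.zeros((2, 3, 5)),
    dims=["epoch", "channel", "time"],
    coords={"channel": ["a", "b", "c"], "time": np.arange(5) / 10.0},
    attrs={"sfreq": 10.0},
)
freqs = [4.0, 8.0, 16.0, 32.0]

# Random stand-in for a TFR result shaped (epochs, channels, freqs, times)
tfr = np.random.default_rng(0).random((2, 3, len(freqs), 5))

da_tfr = (
    template
    .expand_dims(frequency=freqs)                         # new 'frequency' dim + coordinate
    .transpose("epoch", "channel", "frequency", "time")   # match the array's layout
    .copy(data=tfr)                                       # keep metadata, swap in new values
)
print(da_tfr.dims)  # ('epoch', 'channel', 'frequency', 'time')
```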
We can now visualize this time-frequency representation for the most active channel:
fig, ax = plt.subplots(figsize=(15, 5))
(da_tfr
.sel({'frequency': slice(None, 180), 'channel': max_chan})
.mean('epoch')
.plot.imshow(x='time', y='frequency')
)
Similar to our one-dimensional visualizations above, it can be hard to visualize relative changes in activity over a baseline (particularly because the amplitude scales inversely with the frequency).
Let’s apply a re-scaling function to our data so that we can see things more clearly. This time we’ll use MNE’s rescale function, which acts similarly to our z_score function above.
da_tfr_baselined = xr.apply_ufunc(
mne.baseline.rescale,
da_tfr,
kwargs={'times': da_tfr.coords['time'], 'baseline': (None, -.1), "mode": 'zscore'}
)
Applying baseline correction (mode: zscore)
Again, the result should be a DataArray, so we can directly visualize it:
(da_tfr_baselined
.sel({'frequency': slice(None, 180), 'channel': max_chan, 'time': slice(-.8, 2.5)})
.mean('epoch')
.plot.imshow(x='time', y='frequency')
)
Now we can see a clear increase in activity in the higher frequencies at t==0.
Combining the two with xr.merge
Finally, let’s combine these two DataArrays into one. We know that they share much of the same metadata: the first is “Amplitude of High-Frequency Activity” and the second is “Time-frequency power”. We should be able to merge these into a single xarray Dataset, which will allow us to perform operations across both by using their shared dimensions. Datasets are kind of like collections of DataArrays that share some metadata or coordinates.
First, we’ll rename each DataArray so that we can merge them nicely. Then, we’ll simply use the xr.merge function, which tries to automatically figure out which dimensions are shared based on their names and coordinate values.
da_tfr_baselined.name = "Time Frequency Representation"
da_hf_zscored_lowpass.name = "Low-pass filtered HFA"
ds = xr.merge([da_tfr_baselined, da_hf_zscored_lowpass])
ds
Since we’ve got a single dataset, we can grab subsets along each axis across both DataArrays at the same time. We’ll select a subset of channels, time, and frequency bands to visualize.
ds_plt = ds.sel({'channel': max_chan, 'frequency': slice(10, 150), 'time': slice(-.5, 2)})
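Here is the merge-then-select pattern on two small synthetic arrays, with made-up names, sizes, and coordinate values. A single `sel` call is applied to every variable that has the named dimension. `hfa` has no frequency dimension, so the frequency part of the selection simply doesn’t apply to it.

```python
import numpy as np
import xarray as xr

times = np.arange(6) / 10.0
channels = ["a", "b"]
freqs = [10.0, 50.0, 100.0]
rng = np.random.default_rng(0)

# Two hypothetical arrays that share 'epoch', 'channel', and 'time'
power = xr.DataArray(
    rng.random((3, 2, 3, 6)),
    dims=["epoch", "channel", "frequency", "time"],
    coords={"channel": channels, "frequency": freqs, "time": times},
    name="tfr",
)
hfa = xr.DataArray(
    rng.random((3, 2, 6)),
    dims=["epoch", "channel", "time"],
    coords={"channel": channels, "time": times},
    name="hfa",
)

# Variable names become the Dataset's data_vars
ds = xr.merge([power, hfa])

# One selection for both variables; a scalar label drops the 'channel' dim
sub = ds.sel(channel="b", time=slice(0.1, 0.3), frequency=slice(10.0, 50.0))
print(sub["tfr"].dims, sub["tfr"].shape)  # ('epoch', 'frequency', 'time') (3, 2, 3)
print(sub["hfa"].dims, sub["hfa"].shape)  # ('epoch', 'time') (3, 3)
```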
Now, we’ll plot both the spectrogram and the HFA in the same Matplotlib figure. As you can see, these plots contain somewhat redundant information. The top plot tells us that there is a general increase in power at high frequencies. The bottom plot gives us the average increase in power across the higher frequencies.
fig, (ax_tfr, ax_hfa) = plt.subplots(2, 1, figsize=(15, 10))
im = ds_plt['Time Frequency Representation'].mean('epoch').plot.imshow(x='time', y='frequency',
ax=ax_tfr)
ds_plt['Low-pass filtered HFA'].mean('epoch').plot.line(x='time', ax=ax_hfa)
Wrapping up
In all, I was pretty happy with what you can do using xarray’s DataArray structure. It’s pretty nice to be able to refer to axes by their names, and to make more intelligent selection / slicing operations using their coordinate values. Moreover, this post is just scratching the surface of how to use this information in a way that speeds up the exploration and analysis process.
For example, we might have sped up some feature extraction steps by using a distributed processing framework like Dask in the operations above. Dask integrates nicely with xarray, and offers a lot of interesting opportunities to parallelize interactive computation. I’ll explore that in another blog post.
Finally, the goal of this post has largely been to learn a bit more about xarray. This means I might be totally misusing functionality, or missing something that would have made the above process much easier. If anybody has tips or thoughts on the code above, please do reach out!
- Holdgraf, C. R., Rieger, J. W., Micheli, C., Martin, S., Knight, R. T., & Theunissen, F. E. (2017). Encoding and Decoding Models in Cognitive Electrophysiology. Frontiers in Systems Neuroscience, 11. 10.3389/fnsys.2017.00061