Kilosort4 API

run_kilosort

kilosort.run_kilosort.cluster_spikes(st, tF, ops, device, bfile, tic0=nan, progress_bar=None, clear_cache=False, verbose=False)

Cluster spikes using graph-based methods.

Parameters:
  • st (np.ndarray) – 3-column array of peak time (in samples), template, and threshold amplitude for each spike.

  • tF (torch.Tensor) – PC features for each spike, with shape (n_spikes, nearest_chans, n_pcs)

  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • device (torch.device) – Indicates whether pytorch operations should be run on cpu or gpu.

  • bfile (kilosort.io.BinaryFiltered) – Wrapped file object for handling data.

  • tic0 (float; default=np.nan.) – Start time of run_kilosort.

  • progress_bar (tqdm.std.tqdm or QtWidgets.QProgressBar; optional.) – Used to track sorting progress, as in run_kilosort.

  • clear_cache (bool; default=False.) – If True, force PyTorch to clear cached CUDA memory after some memory-intensive steps in the pipeline.

  • verbose (bool; default=False.) – If True, include additional debug-level logging statements.

Returns:

  • clu (np.ndarray) – 1D vector of cluster ids indicating which spike came from which cluster, same shape as st[:,0].

  • Wall (torch.Tensor) – PC feature representation of spike waveforms for each cluster, with shape (n_clusters, n_channels, n_pcs).
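Because clu has one entry per row of st, per-cluster quantities can be pulled out with ordinary NumPy indexing. A minimal sketch on made-up toy arrays (not real Kilosort output), assuming st and clu are NumPy arrays and cluster ids are non-negative integers:

```python
import numpy as np

# Toy stand-ins for sorting output: 6 spikes, columns are
# (peak time in samples, template index, threshold amplitude).
st = np.array([
    [100, 3, 12.0],
    [250, 5, 10.5],
    [400, 3, 11.2],
    [900, 7, 9.8],
    [1300, 5, 13.1],
    [1700, 4, 10.0],
])
clu = np.array([0, 1, 0, 2, 1, 0])  # one cluster id per row of st

# Number of spikes assigned to each cluster id (index = cluster id).
spike_counts = np.bincount(clu)

# Peak times (in samples) of every spike assigned to cluster 0.
times_cluster0 = st[clu == 0, 0]
```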

kilosort.run_kilosort.compute_drift_correction(ops, device, tic0=nan, progress_bar=None, file_object=None, clear_cache=False, verbose=False)

Compute drift correction parameters and save them to ops.

Parameters:
  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • device (torch.device) – Indicates whether pytorch operations should be run on cpu or gpu.

  • tic0 (float; default=np.nan.) – Start time of run_kilosort.

  • progress_bar (tqdm.std.tqdm or QtWidgets.QProgressBar; optional.) – Used to track sorting progress, as in run_kilosort.

  • file_object (array-like file object; optional.) – Must have ‘shape’ and ‘dtype’ attributes and support array-like indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy array or memmap.

  • clear_cache (bool; default=False.) – If True, force PyTorch to clear cached CUDA memory after some memory-intensive steps in the pipeline.

  • verbose (bool; default=False.) – If True, include additional debug-level logging statements.
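The file_object requirements are duck-typed: anything with shape, dtype and 2D slicing works. The sketch below writes a small int16 array to disk, reopens it as a NumPy memmap, and checks it with a hypothetical helper, looks_like_file_object, that simply exercises the attributes and indexing patterns listed above. The (n_samples, n_channels) orientation is an assumption for illustration; confirm the layout your Kilosort version expects.

```python
import os
import tempfile

import numpy as np


def looks_like_file_object(obj):
    """Hypothetical helper: check the attributes the docs require of file_object."""
    if not (hasattr(obj, "shape") and hasattr(obj, "dtype")):
        return False
    try:
        obj[:100, :]  # leading rows, all columns
        obj[5, 7:10]  # single row, range of columns
    except (TypeError, IndexError):
        return False
    return True


# Assumption: rows are samples and columns are channels.
n_samples, n_channels = 1000, 16
path = os.path.join(tempfile.mkdtemp(), "toy.bin")
np.zeros((n_samples, n_channels), dtype=np.int16).tofile(path)

# Reopen the raw file as a read-only, array-like memmap.
data = np.memmap(path, dtype=np.int16, mode="r", shape=(n_samples, n_channels))
```

A memmap is a natural fit here because it satisfies the interface without loading the whole recording into memory.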

Returns:

  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • bfile (kilosort.io.BinaryFiltered) – Wrapped file object for handling data.

  • st0 (np.ndarray.) – Intermediate spike times variable with 6 columns. This is only used for generating the ‘Drift Scatter’ plot through the GUI.

kilosort.run_kilosort.compute_preprocessing(ops, device, tic0=nan, file_object=None)

Compute preprocessing parameters and save them to ops.

Parameters:
  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • device (torch.device) – Indicates whether pytorch operations should be run on cpu or gpu.

  • tic0 (float; default=np.nan) – Start time of run_kilosort.

  • file_object (array-like file object; optional.) – Must have ‘shape’ and ‘dtype’ attributes and support array-like indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy array or memmap.

Returns:

ops

Return type:

dict

kilosort.run_kilosort.detect_spikes(ops, device, bfile, tic0=nan, progress_bar=None, clear_cache=False, verbose=False)

Detect spikes via template deconvolution.

Parameters:
  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • device (torch.device) – Indicates whether pytorch operations should be run on cpu or gpu.

  • bfile (kilosort.io.BinaryFiltered) – Wrapped file object for handling data.

  • tic0 (float; default=np.nan.) – Start time of run_kilosort.

  • progress_bar (tqdm.std.tqdm or QtWidgets.QProgressBar; optional.) – Used to track sorting progress, as in run_kilosort.

  • clear_cache (bool; default=False.) – If True, force PyTorch to clear cached CUDA memory after some memory-intensive steps in the pipeline.

  • verbose (bool; default=False.) – If True, include additional debug-level logging statements.

Returns:

  • st (np.ndarray) – 3-column array of peak time (in samples), template, and threshold amplitude for each spike.

  • clu (np.ndarray) – 1D vector of cluster ids indicating which spike came from which cluster, same shape as st[:,0].

  • tF (torch.Tensor) – PC features for each spike, with shape (n_spikes, nearest_chans, n_pcs)

  • Wall (torch.Tensor) – PC feature representation of spike waveforms for each cluster, with shape (n_clusters, n_channels, n_pcs).
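Spike times in st are sample indices. To work in seconds, divide by the recording's sampling rate. A minimal sketch; the helper name spike_times_seconds and the 30 kHz rate are assumptions for illustration. The helper accepts both the 3-column st returned here and the 1D st returned by load_sorting.

```python
import numpy as np


def spike_times_seconds(st, fs):
    """Hypothetical helper: convert spike times from samples to seconds."""
    st = np.asarray(st)
    # 3-column st (detect_spikes, run_kilosort): take the peak-time column.
    # 1D st (load_sorting): already just the spike times.
    samples = st[:, 0] if st.ndim == 2 else st
    return samples / float(fs)


fs = 30000.0  # assumption: 30 kHz sampling rate
st = np.array([[30000, 3, 10.0], [45000, 5, 12.0]])
times_s = spike_times_seconds(st, fs)
```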

kilosort.run_kilosort.get_run_parameters(ops) -> list

Get ops dict values needed by run_kilosort subroutines.

kilosort.run_kilosort.initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign, device, save_preprocessed_copy, gui_mode=False) -> dict

Package settings and probe information into a single ops dictionary.

kilosort.run_kilosort.load_sorting(results_dir, device=None, load_extra_vars=False)

Load saved sorting results into memory.

Parameters:
  • results_dir (str or pathlib.Path) – Directory where results were saved.

  • device (torch.device; optional.) – CPU or GPU device to use to load PyTorch tensors. By default, PyTorch will use the first detected GPU. If no GPUs are detected, CPU will be used. To set this manually, specify device = torch.device(<device_name>). See PyTorch documentation for full description.

  • load_extra_vars (bool; default=False.) – If True, load tF, Wall, and full copies of st, clu, and spike amplitudes in addition to the other variables.

Returns:

  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • st (np.ndarray) – 1D vector of spike times (in samples) for all clusters. This is only the first column of the 3-column array returned by run_kilosort.

  • clu (np.ndarray) – 1D vector of cluster ids indicating which spike came from which cluster, same shape as st.

  • similar_templates (np.ndarray.) – Similarity score between each pair of clusters, computed as correlation between clusters. Shape (n_clusters, n_clusters).

  • is_ref (np.ndarray.) – 1D boolean array with shape (n_clusters,) indicating whether each cluster is refractory.

  • est_contam_rate (np.ndarray.) – Contamination rate for each cluster, computed as fraction of refractory period violations relative to expectation based on a Poisson process. Shape (n_clusters,).

  • kept_spikes (np.ndarray.) – Boolean mask with shape (n_spikes,) that is False for spikes that were removed by kilosort.postprocessing.remove_duplicate_spikes and True otherwise.

  • tF (torch.Tensor.) – Only returned if load_extra_vars is True. PC features for each spike, with shape (n_spikes, nearest_chans, n_pcs)

  • Wall (torch.Tensor.) – Only returned if load_extra_vars is True. PC feature representation of spike waveforms for each cluster, with shape (n_clusters, n_channels, n_pcs).

  • full_st (np.ndarray.) – Only returned if load_extra_vars is True. 3-column array of peak time (in samples), template, and threshold amplitude for each spike. Includes spikes removed by kilosort.postprocessing.remove_duplicate_spikes.

  • full_clu (np.ndarray.) – Only returned if load_extra_vars is True. 1D vector of cluster ids indicating which spike came from which cluster, same shape as full_st[:,0]. Includes spikes removed by kilosort.postprocessing.remove_duplicate_spikes.

  • full_amp (np.ndarray.) – Only returned if load_extra_vars is True. Per-spike amplitudes, computed as the L2 norm of the PC features for each spike. Includes spikes removed by kilosort.postprocessing.remove_duplicate_spikes.
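The extra variables relate to the deduplicated ones in a simple way. A minimal NumPy sketch on random toy data, assuming tF has been moved to the CPU and converted to a NumPy array (e.g. with tF.cpu().numpy()) and that kept_spikes indexes the full, pre-deduplication spike arrays: per-spike amplitudes are the L2 norm over all PC features of each spike, and applying kept_spikes to the full arrays drops the duplicates.

```python
import numpy as np

rng = np.random.default_rng(0)
n_spikes, nearest_chans, n_pcs = 5, 4, 3

# Toy PC features with shape (n_spikes, nearest_chans, n_pcs).
tF = rng.normal(size=(n_spikes, nearest_chans, n_pcs))

# L2 norm over channels and PCs together: one amplitude per spike.
full_amp = np.sqrt((tF ** 2).sum(axis=(-2, -1)))

# Toy full (pre-deduplication) arrays; spike 3 is marked as a duplicate.
full_st = np.column_stack([np.arange(n_spikes) * 100, np.zeros(n_spikes), np.ones(n_spikes)])
full_clu = np.array([0, 0, 1, 1, 2])
kept_spikes = np.array([True, True, True, False, True])

# Deduplicated views, mirroring the non-extra return values.
st = full_st[kept_spikes, 0]  # 1D spike times, like load_sorting's st
clu = full_clu[kept_spikes]
amp = full_amp[kept_spikes]
```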

kilosort.run_kilosort.run_kilosort(settings, probe=None, probe_name=None, filename=None, data_dir=None, file_object=None, results_dir=None, data_dtype=None, do_CAR=True, invert_sign=False, device=None, progress_bar=None, save_extra_vars=False, clear_cache=False, save_preprocessed_copy=False, bad_channels=None, shank_idx=None, verbose_console=False, verbose_log=False, torch_thread_lim=None)

Run full spike sorting pipeline on specified data.

Parameters:
  • settings (dict) –

    Specifies a number of configurable parameters used throughout the spike sorting pipeline. See kilosort/parameters.py for a full list of available parameters. NOTE: n_chan_bin must be specified here, but all other settings are optional.

  • probe (dict; optional.) – A Kilosort4 probe dictionary, as returned by kilosort.io.load_probe.

  • probe_name (str; optional.) – Filename of probe to use, within the default PROBE_DIR. Only include the filename without any preceding directories. Will only be used if probe is None. Alternatively, the full filepath to a probe stored in any directory can be specified with settings = {‘probe_path’: …}. See kilosort.utils for default PROBE_DIR definition.

  • filename (Path-like or list of Path-likes; optional.) – Full path to binary data file(s). If specified, will also set data_dir = filename.parent. If filename is a list, files will be treated as a single recording concatenated in time in the order provided.

  • data_dir (str or Path; optional.) – Specifies directory where binary data file is stored. Kilosort will attempt to find the binary file. This works best if there is exactly one file in the directory with a .bin, .bat, .dat, or .raw extension. Only used if filename is None. Also see kilosort.io.find_binary.

  • file_object (array-like file object; optional.) – Must have ‘shape’ and ‘dtype’ attributes and support array-like indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy array or memmap. Must specify a valid filename as well, even though data will not be directly loaded from that file.

  • results_dir (str or Path; optional.) – Directory where results will be stored. By default, will be set to data_dir / ‘kilosort4’.

  • data_dtype (str or type; optional.) – dtype of data in binary file, like ‘int32’ or np.uint16. By default, dtype is assumed to be ‘int16’.

  • do_CAR (bool; default=True.) – If True, apply common average reference during preprocessing (recommended).

  • invert_sign (bool; default=False.) – If True, flip positive/negative values in data to conform to standard expected by Kilosort4.

  • device (torch.device; optional.) – CPU or GPU device to use for PyTorch calculations. By default, PyTorch will use the first detected GPU. If no GPUs are detected, CPU will be used. To set this manually, specify device = torch.device(<device_name>). See PyTorch documentation for full description.

  • progress_bar (tqdm.std.tqdm or QtWidgets.QProgressBar; optional.) – Used by sorting steps and GUI to track sorting progress. Users should not need to specify this.

  • save_extra_vars (bool; default=False.) – If True, save tF and Wall to disk after sorting.

  • clear_cache (bool; default=False.) – If True, force pytorch to free up memory reserved for its cache in between memory-intensive operations. Note that setting clear_cache=True is NOT recommended unless you encounter GPU out-of-memory errors, since this can result in slower sorting.

  • save_preprocessed_copy (bool; default=False.) – If True, save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data.

  • bad_channels (list; optional.) – A list of channel indices (rows in the binary file) that should not be included in sorting. Listing channels here is equivalent to excluding them from the probe dictionary.

  • shank_idx (float or list; optional.) – If not None, only channels from the specified shank index will be used. If a list is provided, each shank will be sorted sequentially and results will be saved in separate subfolders. Note that the shank_idx value(s) must match the actual values specified in probe[‘kcoords’]. For example, shank_idx=0 will not work if probe[‘kcoords’] uses 1, 2, 3, 4.

  • verbose_console (bool; default=False.) – If True, set logging level for console output to DEBUG instead of INFO, so that additional information normally only saved to the log file will also show up in real time while sorting.

  • verbose_log (bool; default=False.) – If True, include additional debug-level logging statements for some steps. This provides more detail for debugging, but may impact performance.

  • torch_thread_lim (int; optional.) – If set, this will limit the number of pytorch threads on CPU. See docs for torch.set_num_threads.

Raises:

ValueError – If settings[‘n_chan_bin’] is None (default). User must specify, for example: run_kilosort(settings={‘n_chan_bin’: 385}).

Returns:

  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • st (np.ndarray) – 3-column array of peak time (in samples), template, and threshold amplitude for each spike.

  • clu (np.ndarray) – 1D vector of cluster ids indicating which spike came from which cluster, same shape as st[:,0].

  • tF (torch.Tensor) – PC features for each spike, with shape (n_spikes, nearest_chans, n_pcs)

  • Wall (torch.Tensor) – PC feature representation of spike waveforms for each cluster, with shape (n_clusters, n_channels, n_pcs).

  • similar_templates (np.ndarray.) – Similarity score between each pair of clusters, computed as correlation between clusters. Shape (n_clusters, n_clusters).

  • is_ref (np.ndarray.) – 1D boolean array with shape (n_clusters,) indicating whether each cluster is refractory.

  • est_contam_rate (np.ndarray.) – Contamination rate for each cluster, computed as fraction of refractory period violations relative to expectation based on a Poisson process. Shape (n_clusters,).

  • kept_spikes (np.ndarray.) – Boolean mask with shape (n_spikes,) that is False for spikes that were removed by kilosort.postprocessing.remove_duplicate_spikes and True otherwise.
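A minimal sketch of calling run_kilosort on a single binary file. The wrapper name sort_recording is hypothetical; the keyword arguments and the nine return values follow the signature and Returns list documented above. Only n_chan_bin is required in settings. The probe is selected by probe_name here, though a probe dictionary could be passed via probe instead.

```python
from pathlib import Path


def sort_recording(bin_path, n_chan_bin, probe_name, results_dir=None):
    """Hypothetical wrapper around run_kilosort for a single binary file."""
    # Imported inside the function so this module loads even without Kilosort.
    from kilosort.run_kilosort import run_kilosort

    # n_chan_bin is the only required setting; run_kilosort raises ValueError without it.
    settings = {"n_chan_bin": n_chan_bin}
    (ops, st, clu, tF, Wall, similar_templates,
     is_ref, est_contam_rate, kept_spikes) = run_kilosort(
        settings=settings,
        filename=Path(bin_path),
        probe_name=probe_name,
        results_dir=results_dir,
    )
    return ops, st, clu
```

For example, ops, st, clu = sort_recording("/path/to/recording.bin", 385, "my_probe.json"), where the path and probe filename are placeholders.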

Notes

For documentation of saved files, see kilosort.io.save_to_phy.

kilosort.run_kilosort.save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=nan, save_extra_vars=False, save_preprocessed_copy=False, skip_dat_path=False)

Save sorting results and format them for use with Phy.

Parameters:
  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • results_dir (pathlib.Path) – Directory where results should be saved.

  • st (np.ndarray) – 3-column array of peak time (in samples), template, and threshold amplitude for each spike.

  • clu (np.ndarray) – 1D vector of cluster ids indicating which spike came from which cluster, same shape as st[:,0].

  • tF (torch.Tensor) – PC features for each spike, with shape (n_spikes, nearest_chans, n_pcs)

  • Wall (torch.Tensor) – PC feature representation of spike waveforms for each cluster, with shape (n_clusters, n_channels, n_pcs).

  • imin (int) – Minimum sample index used by BinaryRWFile. Exported spike times will be shifted forward by this number of samples.

  • tic0 (float; default=np.nan.) – Start time of run_kilosort.

  • save_extra_vars (bool; default=False.) – If True, save tF and Wall to disk along with copies of st, clu and amplitudes with no postprocessing applied.

  • save_preprocessed_copy (bool; default=False.) – If True, save a pre-processed copy of the data (including drift correction) to temp_wh.dat in the results directory and format Phy output to use that copy of the data.

  • skip_dat_path (bool; default=False.) – If True, will save dat_path = ‘no_path.bin’ in params.py in place of a real filename. This is done to prevent an error in Phy when filename has an unexpected format, like when using a file_object loaded from an external data format through SpikeInterface. The full filename(s) will still be included in params.py for reference, but will be commented out.

Returns:

  • ops (dict) – Dictionary storing settings and results for all algorithmic steps.

  • similar_templates (np.ndarray.) – Similarity score between each pair of clusters, computed as correlation between clusters. Shape (n_clusters, n_clusters).

  • is_ref (np.ndarray.) – 1D boolean array with shape (n_clusters,) indicating whether each cluster is refractory.

  • est_contam_rate (np.ndarray.) – Contamination rate for each cluster, computed as fraction of refractory period violations relative to expectation based on a Poisson process. Shape (n_clusters,).

  • kept_spikes (np.ndarray.) – Boolean mask with shape (n_spikes,) that is False for spikes that were removed by kilosort.postprocessing.remove_duplicate_spikes and True otherwise.
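est_contam_rate compares refractory-period violations with what a Poisson process would produce. The sketch below illustrates that idea with a deliberately simplified inter-spike-interval (ISI) estimate. It is not Kilosort's own estimator; the function name isi_contamination_ratio, the 2 ms refractory period and the homogeneous-Poisson model are assumptions chosen for illustration.

```python
import numpy as np


def isi_contamination_ratio(spike_times_s, duration_s, refractory_s=0.002):
    """Hypothetical helper: observed ISI violations / violations expected under Poisson.

    Simplified illustration only. A homogeneous Poisson process with the same
    mean rate puts each ISI below refractory_s with probability
    1 - exp(-rate * refractory_s).
    """
    t = np.sort(np.asarray(spike_times_s, dtype=float))
    if t.size < 2:
        return np.nan
    isis = np.diff(t)
    observed = np.count_nonzero(isis < refractory_s)
    rate = t.size / duration_s  # mean firing rate in Hz
    p_violation = 1.0 - np.exp(-rate * refractory_s)
    expected = isis.size * p_violation
    return observed / expected


# A perfectly regular 100 Hz train has no violations.
regular = np.arange(100) * 0.01
ratio_regular = isi_contamination_ratio(regular, 1.0)
```

A ratio near 0 suggests a well-isolated, refractory unit; a ratio near 1 means violations occur about as often as they would for unrelated Poisson spikes.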

Notes

For documentation of saved files, see kilosort.io.save_to_phy.

kilosort.run_kilosort.set_files(settings, filename, probe, probe_name, data_dir, results_dir, bad_channels, shank_idx)

Parse file and directory information for data, probe, and results.