Source code for kilosort.run_kilosort

import time
from pathlib import Path
import logging
import warnings
import platform
logger = logging.getLogger(__name__)

import numpy as np
import torch

import kilosort
from kilosort import (preprocessing, datashift, template_matching, clustering_qr, 
                      clustering_qr, io, spikedetect, CCG, PROBE_DIR)
from kilosort.parameters import DEFAULT_SETTINGS
from kilosort.utils import (
    log_performance, log_cuda_details, probe_as_string, ops_as_string,
    get_performance, log_sorting_summary, log_thread_count
    )
import kilosort.plots as kplots

# Keys accepted in a `settings` dictionary: every key of DEFAULT_SETTINGS plus
# the file / probe location keys that `set_files` looks up in `settings`.
RECOGNIZED_SETTINGS = [
    *DEFAULT_SETTINGS,
    'filename', 'data_dir', 'results_dir', 'probe_name', 'probe_path',
]
# Extra keys that show up alongside the other parameters when sorting is
# launched from the GUI. API users should NOT put these in a settings
# dictionary, even where a key shares its name with a `run_kilosort` option.
GUI_SETTINGS = [
    'data_file_path',
    'probe',
    'data_dtype',
    'save_preprocessed_copy',
    'clear_cache',
    'do_CAR',
    'invert_sign',
    'verbose_log',
]


def run_kilosort(settings, probe=None, probe_name=None, filename=None,
                 data_dir=None, file_object=None, results_dir=None,
                 data_dtype=None, do_CAR=True, invert_sign=False, device=None,
                 progress_bar=None, save_extra_vars=False, clear_cache=False,
                 save_preprocessed_copy=False, bad_channels=None,
                 shank_idx=None, verbose_console=False, verbose_log=False,
                 torch_thread_lim=None):
    """Run full spike sorting pipeline on specified data.

    Parameters
    ----------
    settings : dict
        Specifies a number of configurable parameters used throughout the
        spike sorting pipeline. See `kilosort/parameters.py` for a full list of
        available parameters.
        NOTE: `n_chan_bin` must be specified here, but all other settings are
              optional.
    probe : dict; optional.
        A Kilosort4 probe dictionary, as returned by `kilosort.io.load_probe`.
    probe_name : str; optional.
        Filename of probe to use, within the default `PROBE_DIR`. Only include
        the filename without any preceding directories. Will only be used if
        `probe is None`. Alternatively, the full filepath to a probe stored in
        any directory can be specified with `settings = {'probe_path': ...}`.
        See `kilosort.utils` for default `PROBE_DIR` definition.
    filename : Path-like or list of Path-likes; optional.
        Full path to binary data file(s). If specified, will also set
        `data_dir = filename.parent`. If `filename` is a list, files will be
        treated as a single recording concatenated in time in the order
        provided.
    data_dir : str or Path; optional.
        Specifies directory where binary data file is stored. Kilosort will
        attempt to find the binary file. This works best if there is exactly
        one file in the directory with a .bin, .bat, .dat, or .raw extension.
        Only used if `filename is None`.
        Also see `kilosort.io.find_binary`.
    file_object : array-like file object; optional.
        Must have 'shape' and 'dtype' attributes and support array-like
        indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy
        array or memmap. Must specify a valid `filename` as well, even though
        data will not be directly loaded from that file.
    results_dir : str or Path; optional.
        Directory where results will be stored. By default, will be set to
        `data_dir / 'kilosort4'`.
    data_dtype : str or type; optional.
        dtype of data in binary file, like `'int32'` or `np.uint16`. By default,
        dtype is assumed to be `'int16'`.
    do_CAR : bool; default=True.
        If True, apply common average reference during preprocessing
        (recommended).
    invert_sign : bool; default=False.
        If True, flip positive/negative values in data to conform to standard
        expected by Kilosort4.
    device : torch.device; optional.
        CPU or GPU device to use for PyTorch calculations. By default, PyTorch
        will use the first detected GPU. If no GPUs are detected, CPU will be
        used. To set this manually, specify `device = torch.device(<device_name>)`.
        See PyTorch documentation for full description.
    progress_bar : tqdm.std.tqdm or QtWidgets.QProgressBar; optional.
        Used by sorting steps and GUI to track sorting progress. Users should
        not need to specify this.
    save_extra_vars : bool; default=False.
        If True, save tF and Wall to disk after sorting.
    clear_cache : bool; default=False.
        If True, force pytorch to free up memory reserved for its cache in
        between memory-intensive operations.
        Note that setting `clear_cache=True` is NOT recommended unless you
        encounter GPU out-of-memory errors, since this can result in slower
        sorting.
    save_preprocessed_copy : bool; default=False.
        If True, save a pre-processed copy of the data (including drift
        correction) to `temp_wh.dat` in the results directory and format Phy
        output to use that copy of the data.
    bad_channels : list; optional.
        A list of channel indices (rows in the binary file) that should not be
        included in sorting. Listing channels here is equivalent to excluding
        them from the probe dictionary.
    shank_idx : float or list; optional.
        If not None, only channels from the specified shank index will be used.
        If a list is provided, each shank will be sorted sequentially and
        results will be saved in separate subfolders. Note that the shank_idx
        value(s) must match the actual value specified in `probe['kcoords']`.
        For example, `shank_idx=0` will not work if `probe['kcoords']` uses
        1,2,3,4.
    verbose_console : bool; default=False.
        If True, set logging level for console output to `DEBUG` instead
        of `INFO`, so that additional information normally only saved to the
        log file will also show up in real time while sorting.
    verbose_log : bool; default=False.
        If True, include additional debug-level logging statements for some
        steps. This provides more detail for debugging, but may impact
        performance.
    torch_thread_lim : int; optional.
        If set, this will limit the number of pytorch threads on CPU. See docs
        for `torch.set_num_threads`.

    Raises
    ------
    ValueError
        If settings[`n_chan_bin`] is None (default). User must specify, for
        example: `run_kilosort(settings={'n_chan_bin': 385})`.
    ValueError
        If `shank_idx` is an empty list, since there would be nothing to sort.

    Returns
    -------
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    st : np.ndarray
        3-column array of peak time (in samples), template, and threshold
        amplitude for each spike.
    clu : np.ndarray
        1D vector of cluster ids indicating which spike came from which cluster,
        same shape as `st[:,0]`.
    tF : torch.Tensor
        PC features for each spike, with shape
        (n_spikes, nearest_chans, n_pcs)
    Wall : torch.Tensor
        PC feature representation of spike waveforms for each cluster, with shape
        (n_clusters, n_channels, n_pcs).
    similar_templates : np.ndarray.
        Similarity score between each pair of clusters, computed as correlation
        between clusters. Shape (n_clusters, n_clusters).
    is_ref : np.ndarray.
        1D boolean array with shape (n_clusters,) indicating whether each
        cluster is refractory.
    est_contam_rate : np.ndarray.
        Contamination rate for each cluster, computed as fraction of refractory
        period violations relative to expectation based on a Poisson process.
        Shape (n_clusters,).
    kept_spikes : np.ndarray.
        Boolean mask with shape (n_spikes,) that is False for spikes that were
        removed by `kilosort.postprocessing.remove_duplicate_spikes`
        and True otherwise.

    Notes
    -----
    For documentation of saved files, see `kilosort.io.save_to_phy`.

    When `shank_idx` is a list, every shank is sorted and saved, but only the
    results of the last shank in the list are returned.

    """

    if torch_thread_lim is not None:
        torch.set_num_threads(torch_thread_lim)

    # Configure settings, ops, and file paths
    if settings is None or settings.get('n_chan_bin', None) is None:
        raise ValueError(
            '`n_chan_bin` is a required setting. This is the total number of '
            'channels in the binary file, which may or may not be equal to the '
            'number of channels specified by the probe.'
            )

    if not isinstance(shank_idx, list):
        shank_idx = [shank_idx]
    if len(shank_idx) == 0:
        # Without this check the loop below never runs and the final `return`
        # fails with an uninformative UnboundLocalError.
        raise ValueError(
            '`shank_idx` is an empty list, so there are no shanks to sort. '
            'Use `shank_idx=None` to sort all channels.'
            )

    settings = {**DEFAULT_SETTINGS, **settings}

    for idx in shank_idx:
        # NOTE: This modifies settings in-place
        _filename, _data_dir, _results_dir, _probe = \
            set_files(settings, filename, probe, probe_name, data_dir,
                      results_dir, bad_channels, idx)
        setup_logger(_results_dir, verbose_console=verbose_console)

        ops, st, clu, tF, Wall, similar_templates, \
            is_ref, est_contam_rate, kept_spikes = _sort(
                _filename, _results_dir, _probe, settings, data_dtype, device,
                do_CAR, clear_cache, invert_sign, save_preprocessed_copy,
                verbose_log, save_extra_vars, file_object, progress_bar,
                )

    return ops, st, clu, tF, Wall, similar_templates, \
           is_ref, est_contam_rate, kept_spikes


def _sort(filename, results_dir, probe, settings, data_dtype, device, do_CAR,
          clear_cache, invert_sign, save_preprocessed_copy, verbose_log,
          save_extra_vars, file_object, progress_bar, gui_sorter=None):
    """Run sorting pipeline. See `run_kilosort` for documentation.

    The only additional parameter is `gui_sorter`: when it is not None,
    intermediate results are attached to it and its `plotDataReady` signal is
    emitted instead of writing plots to `results_dir`, and GUI-only keys are
    accepted in `settings`.

    Any exception is logged (with traceback) and re-raised, and the logger
    handlers are always closed on exit.

    Notes
    -----
    filename is expected to be a list of Paths at this point, even if it's
    a singleton list.

    """

    try:
        logger.info(f"Kilosort version {kilosort.__version__}")
        logger.info(f"Python version {platform.python_version()}")
        logger.info('-'*40)
        logger.info('System information:')
        logger.info(f'{platform.platform()} {platform.machine()}')
        logger.info(platform.processor())
        if device is None:
            if torch.cuda.is_available():
                logger.info('Using GPU for PyTorch computations. '
                            'Specify `device` to change this.')
                device = torch.device('cuda')
            else:
                logger.info('Using CPU for PyTorch computations. '
                            'Specify `device` to change this.')
                device = torch.device('cpu')

        if device != torch.device('cpu'):
            memory = torch.cuda.get_device_properties(device).total_memory/1024**3
            # Ask for the name of the device actually in use. Without the
            # argument, the *current* CUDA device is reported, which can
            # differ from `device` on multi-GPU machines.
            device_name = torch.cuda.get_device_name(device)
            logger.info(f'Using CUDA device: {device_name} {memory:.2f}GB')

        logger.info('-'*40)
        if len(filename) == 1:
            logger.info(f"Sorting {filename}")
        else:
            logger.info(f"Sorting {filename[0].parent}/... (multiple files)")

        if data_dtype is None:
            logger.info(
                "Interpreting binary file as default dtype='int16'. If data was "
                "saved in a different format, specify `data_dtype`."
                )
            data_dtype = 'int16'

        if not do_CAR:
            logger.info("Skipping common average reference.")

        if clear_cache:
            logger.info('clear_cache=True')

        if probe['chanMap'].max() >= settings['n_chan_bin']:
            raise ValueError(
                'Largest value of chanMap exceeds channel count of data, '
                'make sure chanMap is 0-indexed.'
            )

        tic0 = time.time()
        ops, settings = initialize_ops(
            settings, probe, data_dtype, do_CAR, invert_sign,
            device, save_preprocessed_copy, gui_mode=(gui_sorter is not None)
            )

        # Pretty-print ops and probe for log
        logger.debug(f"Initial ops:\n\n{ops_as_string(ops)}\n")
        logger.debug(f"Probe dictionary:\n\n{probe_as_string(ops['probe'])}\n")

        # Baseline performance metrics
        log_performance(logger, 'info', 'Resource usage before sorting')
        log_thread_count(logger)

        # Set preprocessing and drift correction parameters
        ops = compute_preprocessing(ops, device, tic0=tic0,
                                    file_object=file_object)
        # Fixed seeds for numpy and torch (CPU and CUDA), set before the
        # drift correction, spike detection and clustering steps.
        np.random.seed(1)
        torch.cuda.manual_seed_all(1)
        torch.random.manual_seed(1)
        ops, bfile, st0 = compute_drift_correction(
            ops, device, tic0=tic0, progress_bar=progress_bar,
            file_object=file_object, clear_cache=clear_cache,
            verbose=verbose_log
            )
        log_thread_count(logger)

        # Save preprocessing steps
        if save_preprocessed_copy:
            io.save_preprocessing(results_dir / 'temp_wh.dat', ops, bfile)
            log_performance(logger, 'info',
                            'Resource usage after saving preprocessing.',
                            reset=True)

        logger.info('Generating drift plots ...')
        # st0 will be None if nblocks = 0 (no drift correction)
        if st0 is not None:
            if gui_sorter is not None:
                gui_sorter.dshift = ops['dshift']
                gui_sorter.st0 = st0
                gui_sorter.plotDataReady.emit('drift')
            else:
                kplots.plot_drift_amount(ops, results_dir, tmin=settings['tmin'])
                kplots.plot_drift_scatter(st0, results_dir, tmin=settings['tmin'])

        # Sort spikes and save results
        st,tF, Wall0, clu0 = detect_spikes(
            ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
            clear_cache=clear_cache, verbose=verbose_log
            )
        log_thread_count(logger)

        logger.info('Generating diagnostic plots ...')
        if gui_sorter is not None:
            gui_sorter.Wall0 = Wall0
            gui_sorter.wPCA = torch.clone(ops['wPCA'].cpu()).numpy()
            gui_sorter.clu0 = clu0
            gui_sorter.plotDataReady.emit('diagnostics')
        else:
            kplots.plot_diagnostics(Wall0, clu0, ops, results_dir)

        clu, Wall, st, tF = cluster_spikes(
            st, tF, ops, device, bfile, tic0=tic0, progress_bar=progress_bar,
            clear_cache=clear_cache, verbose=verbose_log,
            )
        log_thread_count(logger)

        ops, similar_templates, is_ref, est_contam_rate, kept_spikes = \
            save_sorting(
                ops, results_dir, st, clu, tF, Wall, bfile.imin, tic0,
                save_extra_vars=save_extra_vars,
                save_preprocessed_copy=save_preprocessed_copy,
                skip_dat_path=(file_object is not None)
                )
        if torch.cuda.is_available():
            ops['cuda_postproc'] = torch.cuda.memory_stats(device)
        log_thread_count(logger)

        logger.info('Generating spike position plot ...')
        if gui_sorter is not None:
            gui_sorter.clu = clu[kept_spikes]
            gui_sorter.is_refractory = is_ref
            gui_sorter.plotDataReady.emit('probe')
        else:
            kplots.plot_spike_positions(clu[kept_spikes], is_ref, results_dir)

        logger.info('Sorting finished.')
        log_sorting_summary(ops, log=logger, level='info')

    except Exception as e:
        if isinstance(e, torch.cuda.OutOfMemoryError):
            logger.exception('Out of memory error, printing performance...')
            log_cuda_details(logger)
            log_performance(logger, level='info')

        # This makes sure the full traceback is written to log file.
        logger.exception('Encountered error in `run_kilosort`:')
        # Annoyingly, this will print the error message twice for console, but
        # I haven't found a good way around that.
        raise

    finally:
        # Release the log file handle whether sorting succeeded or failed.
        close_logger()

    return ops, st, clu, tF, Wall, similar_templates, \
           is_ref, est_contam_rate, kept_spikes


def set_files(settings, filename, probe, probe_name, data_dir, results_dir,
              bad_channels, shank_idx):
    """Parse file and directory information for data, probe, and results.

    Explicit arguments take priority; the matching entry in `settings` is only
    used as a fallback when an argument is None.

    Parameters
    ----------
    settings : dict
        Settings dictionary. NOTE: modified in-place, `settings['filename']`
        and `settings['data_dir']` are overwritten with the resolved values.
    filename : Path-like or list of Path-likes; optional.
        Binary data file(s). If None, a file is searched for in `data_dir`.
    probe : dict; optional.
        Probe dictionary. If None, a probe is loaded from `probe_name` or
        from `settings['probe_name']` / `settings['probe_path']`.
        NOTE: if provided, `probe['xc']` and `probe['yc']` are replaced
        in-place with float32 copies.
    probe_name : str; optional.
        Probe filename within `PROBE_DIR`.
    data_dir : str or Path; optional.
        Directory containing the binary file, only used if no filename is
        available.
    results_dir : str or Path; optional.
        Where results are stored, defaults to `data_dir / 'kilosort4'`.
    bad_channels : list; optional.
        Channels to remove from the probe with `io.remove_bad_channels`.
    shank_idx : optional.
        If not None, results go to a `shank_{shank_idx}` subdirectory and the
        probe is restricted to that shank with `io.select_shank`.

    Raises
    ------
    ValueError
        If neither a filename nor a data directory is available, if no probe
        can be determined, or if a probe array has more than one dimension.
    FileExistsError
        If the data directory, a data file, or the probe file does not exist.

    Returns
    -------
    filename : list of Path
    data_dir : Path
    results_dir : Path
        Created on disk if it does not already exist.
    probe : dict

    """

    # Check for filename
    filename = settings.get('filename', None) if filename is None else filename

    # Use data_dir if filename not available
    if filename is None:
        data_dir = settings.get('data_dir', None) if data_dir is None else data_dir
        if data_dir is None:
            raise ValueError('no path to data provided, set "data_dir=" or "filename="')
        data_dir = Path(data_dir).resolve()
        if not data_dir.exists():
            raise FileExistsError(f"data_dir '{data_dir}' does not exist")

        # Find binary file in the folder
        filename = io.find_binary(data_dir=data_dir)
        filename = [filename]
    else:
        if not isinstance(filename, list):
            filename = [filename]
        filename = [Path(f) for f in filename]
        for f in filename:
            if not f.exists():
                # Report the specific missing file, not the whole list.
                raise FileExistsError(f"filename '{f}' does not exist")
        data_dir = filename[0].parent

    # NOTE: these are stored as Path objects here. Paths need to be converted
    # to strings when saving to ops, otherwise ops can only be loaded on the
    # operating system that originally ran the code.
    settings['filename'] = filename
    settings['data_dir'] = data_dir

    # Try to set results_dir based on settings, otherwise use default.
    results_dir = settings.get('results_dir', None) if results_dir is None else results_dir
    results_dir = Path(results_dir).resolve() if results_dir is not None else None
    if results_dir is None:
        results_dir = data_dir / 'kilosort4'
    if shank_idx is not None:
        results_dir = results_dir / f'shank_{shank_idx}'
    # Make sure results directory exists
    results_dir.mkdir(exist_ok=True, parents=True)

    # find probe configuration file and load
    if probe is None:
        if probe_name is not None:
            probe_path = PROBE_DIR / probe_name
        elif 'probe_name' in settings:
            probe_path = PROBE_DIR / settings['probe_name']
        elif 'probe_path' in settings:
            probe_path = Path(settings['probe_path']).resolve()
        else:
            raise ValueError('no probe_name or probe_path provided, set probe_name=')
        if not probe_path.exists():
            raise FileExistsError(f"probe_path '{probe_path}' does not exist")

        probe = io.load_probe(probe_path)
    else:
        # Make sure xc, yc are float32, otherwise there are casting problems
        # with some pytorch functions.
        probe['xc'] = probe['xc'].astype(np.float32)
        probe['yc'] = probe['yc'].astype(np.float32)

    # Let user know if there are too many dimensions in probe entries.
    # Don't want to automatically flatten them incase they've made assumptions
    # about higher-D ordering.
    for k in ['xc', 'yc', 'kcoords', 'chanMap']:
        if probe[k].ndim > 1:
            raise ValueError(f"Array-valued probe entries should have 1 dim, "
                             f"but key: {k} has ndim == {probe[k].ndim}.")

    if bad_channels is not None:
        probe = io.remove_bad_channels(probe, bad_channels)
    if shank_idx is not None:
        probe = io.select_shank(probe, shank_idx)

    return filename, data_dir, results_dir, probe


def setup_logger(results_dir, verbose_console=False):
    """Attach console and file handlers to the 'kilosort' logger.

    Parameters
    ----------
    results_dir : str or Path
        Directory in which `kilosort4.log` is written (opened with mode 'w',
        so an existing log there is overwritten).
    verbose_console : bool; default=False.
        If True, console output uses the `DEBUG` level and the same detailed
        format as the log file. Otherwise it uses `INFO` and a shorter format.

    """
    log_path = Path(results_dir) / 'kilosort4.log'

    # Root logger of the Kilosort application, passes everything on to the
    # handlers which then filter by their own level.
    root = logging.getLogger('kilosort')
    root.setLevel(logging.DEBUG)

    # Detailed format with timestamps and logging level for the log file.
    detailed = logging.Formatter(
        '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    )
    file_handler = logging.FileHandler(log_path, mode='w')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(detailed)

    # The console handler is only needed once. It is skipped if handlers are
    # already present, like when running multiple times in a single session.
    if not root.handlers:
        console_handler = logging.StreamHandler()
        if verbose_console:
            console_level, console_format = logging.DEBUG, detailed
        else:
            console_level = logging.INFO
            console_format = logging.Formatter('%(name)-12s: %(message)s')
        console_handler.setLevel(console_level)
        console_handler.setFormatter(console_format)
        root.addHandler(console_handler)

    # The file handler is always added since the log file location can change
    # from one run to the next.
    root.addHandler(file_handler)


def close_logger():
    """Detach and close every handler attached to the 'kilosort' logger."""
    root = logging.getLogger('kilosort')
    # Iterate over a snapshot, the handler list is mutated by removeHandler.
    for attached in list(root.handlers):
        root.removeHandler(attached)
        attached.close()


def initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign,
                   device, save_preprocessed_copy, gui_mode=False) -> tuple:
    """Package settings and probe information into a single `ops` dictionary.

    Parameters
    ----------
    settings : dict
        Sorting parameters. Not modified, a shallow copy is validated, filled
        in and returned instead.
    probe : dict
        Probe dictionary. All of its entries are also merged into the top
        level of `ops`.
    data_dtype, do_CAR, invert_sign, save_preprocessed_copy
        Stored in `ops` under the same names.
    device : torch.device
        Stored as a string in `ops['torch_device']`.
    gui_mode : bool; default=False.
        If True, keys listed in `GUI_SETTINGS` are also accepted in `settings`.

    Raises
    ------
    ValueError
        If `nt` is even, if `settings` contains unrecognized keys, or if
        `templates_from_data` is False while `nt != 61`.

    Returns
    -------
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    settings : dict
        Validated copy of the input settings, also stored as
        `ops['settings']`.

    """

    settings = settings.copy()
    if settings['nt'] % 2 == 0:
        raise ValueError(f'`nt` must be odd, but got nt={settings["nt"]}')
    if settings['nt0min'] is None:
        settings['nt0min'] = int(20 * settings['nt']/61)
    if settings['max_channel_distance'] is None:
        # Default used to be None, now it's a constant. Adding this so that
        # cached settings values in the GUI don't cause disruption.
        settings['max_channel_distance'] = DEFAULT_SETTINGS['max_channel_distance']

    if settings['nearest_chans'] > len(probe['chanMap']):
        msg = (
            " Parameter `nearest_chans` must be less than or equal to the "
            "number of data channels being sorted.\n "
            f"Changing from {settings['nearest_chans']} to "
            f"{len(probe['chanMap'])}. "
        )
        warnings.warn(msg, UserWarning)
        settings['nearest_chans'] = len(probe['chanMap'])

    if 'duplicate_spike_bins' in settings:
        msg = (
            " The `duplicate_spike_bins` parameter has been replaced with "
            "`duplicate_spike_ms`. Specifying the former will have no effect, "
            "since it gets overwritten based on sampling rate. "
        )
        warnings.warn(msg, DeprecationWarning)
    # Duplicate-spike window converted from milliseconds to samples.
    dup_bins = int(settings['duplicate_spike_ms'] * (settings['fs']/1000))

    # If running through GUI, also allow some additional relevant keys in
    # settings dictionary.
    recognized = RECOGNIZED_SETTINGS.copy()
    if gui_mode:
        recognized.extend(GUI_SETTINGS.copy())
    # Raise an error if there are unrecognized settings entries to make users
    # aware if they've made a typo, are using a deprecated setting, etc.
    unrecognized = [k for k in settings if k not in recognized]
    if len(unrecognized) > 0:
        logger.info('Unrecognized keys found in `settings`')
        logger.info('See `kilosort.run_kilosort.RECOGNIZED_SETTINGS`')
        raise ValueError(f'Unrecognized settings: {unrecognized}')

    # TODO: Clean this up during refactor. Lots of confusing duplication here.
    ops = settings.copy()
    ops['settings'] = settings
    ops['probe'] = probe
    ops['data_dtype'] = data_dtype
    ops['do_CAR'] = do_CAR
    ops['invert_sign'] = invert_sign
    ops['NTbuff'] = ops['batch_size'] + 2 * ops['nt']
    ops['Nchan'] = len(probe['chanMap'])
    ops['n_chan_bin'] = settings['n_chan_bin']
    ops['duplicate_spike_bins'] = dup_bins
    ops['torch_device'] = str(device)
    ops['save_preprocessed_copy'] = save_preprocessed_copy

    if not settings['templates_from_data'] and settings['nt'] != 61:
        raise ValueError('If using pre-computed universal templates '
                         '(templates_from_data=False), nt must be 61')

    # Probe entries take priority over any ops entry with the same key.
    ops = {**ops, **probe}

    return ops, settings


def get_run_parameters(ops) -> list:
    """Get `ops` dict values needed by `run_kilosort` subroutines.

    Values are returned in a fixed order, callers unpack them positionally:
    n_chan_bin, fs, batch_size, nt, nt0min, chanMap, data_dtype, do_CAR,
    invert_sign, xc, yc, tmin, tmax, artifact_threshold, shift, scale,
    batch_downsampling.

    """
    # (section, key) pairs in output order. A section of None means the key
    # is read from the top level of `ops`.
    lookup_order = (
        ('settings', 'n_chan_bin'),
        ('settings', 'fs'),
        ('settings', 'batch_size'),
        ('settings', 'nt'),
        ('settings', 'nt0min'),  # also called twav_min
        ('probe', 'chanMap'),
        (None, 'data_dtype'),
        (None, 'do_CAR'),
        (None, 'invert_sign'),
        ('probe', 'xc'),
        ('probe', 'yc'),
        ('settings', 'tmin'),
        ('settings', 'tmax'),
        ('settings', 'artifact_threshold'),
        ('settings', 'shift'),
        ('settings', 'scale'),
        ('settings', 'batch_downsampling'),
    )
    return [
        ops[key] if section is None else ops[section][key]
        for section, key in lookup_order
    ]


def compute_preprocessing(ops, device, tic0=np.nan, file_object=None):
    """Compute preprocessing parameters and save them to `ops`.

    A high-pass filter and a whitening matrix are computed and stored under
    `ops['preprocessing']` as well as `ops['fwav']` and `ops['Wrot']`.

    Parameters
    ----------
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    device : torch.device
        Indicates whether `pytorch` operations should be run on cpu or gpu.
    tic0 : float; default=np.nan
        Start time of `run_kilosort`.
    file_object : array-like file object; optional.
        Must have 'shape' and 'dtype' attributes and support array-like
        indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy
        array or memmap.

    Returns
    -------
    ops : dict

    """

    t_start = time.time()
    logger.info(' ')
    logger.info('Computing preprocessing variables.')
    logger.info('-'*40)

    (n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert,
     xc, yc, tmin, tmax, artifact, shift, scale, batch_downsampling) = \
        get_run_parameters(ops)
    sort_settings = ops['settings']
    nskip = sort_settings['nskip']
    whitening_range = sort_settings['whitening_range']

    # High pass filter
    hp_filter = preprocessing.get_highpass_filter(
        fs, sort_settings['highpass_cutoff'], device=device
        )

    # Whitening matrix, computed from high-pass filtered data
    bfile = io.BinaryFiltered(
        ops['filename'], n_chan_bin, fs, NT, nt, twav_min, chan_map,
        hp_filter, device=device, do_CAR=do_CAR, invert_sign=invert,
        dtype=dtype, tmin=tmin, tmax=tmax, artifact_threshold=artifact,
        shift=shift, scale=scale, file_object=file_object,
        batch_downsampling=batch_downsampling
        )

    logger.info(f'N samples: {bfile.n_samples}')
    logger.info(f'N seconds: {bfile.n_samples/fs}')
    logger.info(f'N batches: {bfile.n_batches}')

    whiten_mat = preprocessing.get_whitening_matrix(
        bfile, xc, yc, nskip=nskip, nrange=whitening_range
        )

    # Store results. The filter and whitening matrix are each saved under
    # two different keys.
    ops['Nbatches'] = bfile.n_batches
    ops['preprocessing'] = {'whiten_mat': whiten_mat, 'hp_filter': hp_filter}
    ops['Wrot'] = whiten_mat
    ops['fwav'] = hp_filter

    elapsed = time.time() - t_start
    total = time.time() - tic0
    ops['runtime_preproc'] = elapsed
    ops['usage_preproc'] = get_performance()
    logger.info(
        f'Preprocessing filters computed in {elapsed:.2f}s; total {total:.2f}s'
        )
    logger.debug(f'hp_filter shape: {hp_filter.shape}')
    logger.debug(f'whiten_mat shape: {whiten_mat.shape}')

    # Record the scale of the data (first batch only) in the log file
    first_batch = bfile.padded_batch_to_torch(0).cpu().numpy()
    logger.debug(f"First batch min, max: {first_batch.min(), first_batch.max()}")

    log_performance(logger, 'info', 'Resource usage after preprocessing',
                    reset=True)

    return ops


def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
                             file_object=None, clear_cache=False,
                             verbose=False):
    """Compute drift correction parameters and save them to `ops`.

    Parameters
    ----------
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    device : torch.device
        Indicates whether `pytorch` operations should be run on cpu or gpu.
    tic0 : float; default=np.nan.
        Start time of `run_kilosort`.
    progress_bar : TODO; optional.
        Informs `tqdm` package how to report progress, type unclear.
    file_object : array-like file object; optional.
        Must have 'shape' and 'dtype' attributes and support array-like
        indexing (e.g. [:100,:], [5, 7:10], etc). For example, a numpy
        array or memmap.
    clear_cache : bool; False.
        If True, force pytorch to clear cached cuda memory after some
        memory-intensive steps in the pipeline.
    verbose : bool; False.
        If true, include additional debug-level logging statements.

    Returns
    -------
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    bfile : kilosort.io.BinaryFiltered
        Wrapped file object for handling data, with drift correction
        (`ops['dshift']`) applied.
    st0 : np.ndarray.
        Intermediate spike times variable with 6 columns. This is only used
        for generating the 'Drift Scatter' plot through the GUI. May be None.

    """

    tic = time.time()
    logger.info(' ')
    logger.info('Computing drift correction.')
    logger.info('-'*40)

    n_chan_bin, fs, NT, nt, twav_min, chan_map, dtype, do_CAR, invert, \
        _, _, tmin, tmax, artifact, shift, scale, batch_downsampling = \
        get_run_parameters(ops)
    hp_filter = ops['preprocessing']['hp_filter']
    whiten_mat = ops['preprocessing']['whiten_mat']

    # Options shared by both binary file wrappers below. Keeping them in one
    # place guarantees that the drift-corrected file is read with exactly the
    # same preprocessing (including sign inversion) as the file that the
    # drift estimate was computed from.
    bfile_kwargs = dict(
        hp_filter=hp_filter, whiten_mat=whiten_mat, device=device,
        do_CAR=do_CAR, invert_sign=invert, dtype=dtype, tmin=tmin, tmax=tmax,
        artifact_threshold=artifact, shift=shift, scale=scale,
        file_object=file_object, batch_downsampling=batch_downsampling
        )

    bfile = io.BinaryFiltered(
        ops['filename'], n_chan_bin, fs, NT, nt, twav_min, chan_map,
        **bfile_kwargs
        )

    ops, st = datashift.run(ops, bfile, device=device, progress_bar=progress_bar,
                            clear_cache=clear_cache, verbose=verbose)

    elapsed = time.time() - tic
    total = time.time() - tic0
    ops['runtime_drift'] = elapsed
    ops['usage_drift'] = get_performance()
    if torch.cuda.is_available():
        ops['cuda_drift'] = torch.cuda.memory_stats()
    logger.info(f'drift computed in {elapsed:.2f}s; total {total:.2f}s')
    if st is not None:
        logger.debug(f'st shape: {st.shape}')
        logger.debug(f'yblk shape: {ops["yblk"].shape}')
        logger.debug(f'dshift shape: {ops["dshift"].shape}')
        logger.debug(f'iKxx shape: {ops["iKxx"].shape}')

    # binary file with drift correction
    bfile = io.BinaryFiltered(
        ops['filename'], n_chan_bin, fs, NT, nt, twav_min, chan_map,
        dshift=ops['dshift'], **bfile_kwargs
        )

    log_cuda_details(logger)
    log_performance(logger, 'info', 'Resource usage after drift correction',
                    reset=True)

    return ops, bfile, st


def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None,
                  clear_cache=False, verbose=False):
    """Detect spikes via template deconvolution.

    Runs three steps in sequence: spike detection with `spikedetect.run`, a
    first clustering of those spikes, and a second spike extraction using the
    post-processed cluster waveforms.

    Parameters
    ----------
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    device : torch.device
        Indicates whether `pytorch` operations should be run on cpu or gpu.
    bfile : kilosort.io.BinaryFiltered
        Wrapped file object for handling data.
    tic0 : float; default=np.nan.
        Start time of `run_kilosort`.
    progress_bar : TODO; optional.
        Informs `tqdm` package how to report progress, type unclear.
    clear_cache : bool; False.
        If True, force pytorch to clear cached cuda memory after some
        memory-intensive steps in the pipeline.
    verbose : bool; False.
        If true, include additional debug-level logging statements.

    Raises
    ------
    ValueError
        If the first detection step finds no spikes.

    Returns
    -------
    st : np.ndarray
        3-column array of peak time (in samples), template, and threshold
        amplitude for each spike, from the second extraction step.
    tF : torch.Tensor
        PC features for each spike in `st`, with shape
        (n_spikes, nearest_chans, n_pcs)
    Wall : torch.Tensor
        PC feature representation of spike waveforms for each cluster found
        by the first clustering, with shape (n_clusters, n_channels, n_pcs).
    clu : np.ndarray
        1D vector of cluster ids assigned by the first clustering.

    """

    # Step 1: spike detection
    t_start = time.time()
    logger.info(' ')
    logger.info('Extracting spikes using templates')
    logger.info('-'*40)
    st0, tF, ops = spikedetect.run(
        ops, bfile, device=device, progress_bar=progress_bar,
        clear_cache=clear_cache, verbose=verbose
        )
    tF = torch.from_numpy(tF)

    elapsed = time.time() - t_start
    total = time.time() - tic0
    ops['runtime_st0'] = elapsed
    ops['usage_st0'] = get_performance()
    if torch.cuda.is_available():
        ops['cuda_st0'] = torch.cuda.memory_stats(device)
    logger.info(
        f'{len(st0)} spikes extracted in {elapsed:.2f}s; total {total:.2f}s'
        )
    logger.debug(f'st0 shape: {st0.shape}')
    logger.debug(f'tF shape: {tF.shape}')

    # Nothing to cluster without spikes, stop here.
    if len(st0) == 0:
        raise ValueError('No spikes detected, cannot continue sorting.')

    log_performance(logger, 'info', 'Resource usage after spike detect (univ)',
                    reset=True)
    log_thread_count(logger)

    # Step 2: first clustering
    t_start = time.time()
    logger.info(' ')
    logger.info('First clustering')
    logger.info('-'*40)
    clu, Wall = clustering_qr.run(
        ops, st0, tF, mode='spikes', device=device, progress_bar=progress_bar,
        clear_cache=clear_cache, verbose=verbose
        )
    # Post-processed templates are only used for the extraction step below,
    # the unprocessed `Wall` is what gets returned.
    Wall3 = template_matching.postprocess_templates(
        Wall, ops, clu, st0, tF, device=device
        )

    elapsed = time.time() - t_start
    total = time.time() - tic0
    ops['runtime_clu0'] = elapsed
    ops['usage_clu0'] = get_performance()
    if torch.cuda.is_available():
        ops['cuda_clu0'] = torch.cuda.memory_stats(device)
    logger.info(
        f'{clu.max()+1} clusters found, in {elapsed:.2f}s; total {total:.2f}s'
        )
    logger.debug(f'clu shape: {clu.shape}')
    logger.debug(f'Wall shape: {Wall.shape}')
    log_performance(logger, 'info', 'Resource usage after first clustering',
                    reset=True)
    log_thread_count(logger)

    # Step 3: spike extraction with cluster waveforms
    t_start = time.time()
    logger.info(' ')
    logger.info('Extracting spikes using cluster waveforms')
    logger.info('-'*40)
    st, tF, ops = template_matching.extract(
        ops, bfile, Wall3, device=device, progress_bar=progress_bar
        )
    log_thread_count(logger)

    elapsed = time.time() - t_start
    total = time.time() - tic0
    ops['runtime_st'] = elapsed
    ops['usage_st'] = get_performance()
    if torch.cuda.is_available():
        ops['cuda_st'] = torch.cuda.memory_stats(device)
    logger.info(
        f'{len(st)} spikes extracted in {elapsed:.2f}s; total {total:.2f}s'
        )
    logger.debug(f'st shape: {st.shape}')
    logger.debug(f'tF shape: {tF.shape}')
    logger.debug(f'iCC shape: {ops["iCC"].shape}')
    logger.debug(f'iU shape: {ops["iU"].shape}')

    log_cuda_details(logger)
    log_performance(logger, 'info',
                    'Resource usage after spike detect (learned)', reset=True)

    return st, tF, Wall, clu


def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None,
                   clear_cache=False, verbose=False):
    """Cluster spikes using graph-based methods.

    Runs the final clustering followed by a merging step.

    Parameters
    ----------
    st : np.ndarray
        3-column array of peak time (in samples), template, and threshold
        amplitude for each spike.
    tF : torch.Tensor
        PC features for each spike, with shape
        (n_spikes, nearest_chans, n_pcs)
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    device : torch.device
        Indicates whether `pytorch` operations should be run on cpu or gpu.
    bfile : kilosort.io.BinaryFiltered
        Wrapped file object for handling data. Not used by this function.
    tic0 : float; default=np.nan.
        Start time of `run_kilosort`.
    progress_bar : TODO; optional.
        Informs `tqdm` package how to report progress, type unclear.
    clear_cache : bool; False.
        If True, force pytorch to clear cached cuda memory after some
        memory-intensive steps in the pipeline.
    verbose : bool; False.
        If True, include additional debug-level logging statements.

    Returns
    -------
    clu : np.ndarray
        1D vector of cluster ids indicating which spike came from which
        cluster, cast to int32 after merging.
    Wall : torch.Tensor
        PC feature representation of spike waveforms for each cluster, with
        shape (n_clusters, n_channels, n_pcs).
    st : np.ndarray
        Spike array as returned by the merging step.
    tF : torch.Tensor
        PC features as returned by the merging step.

    """

    # Final clustering
    t_start = time.time()
    logger.info(' ')
    logger.info('Final clustering')
    logger.info('-'*40)
    clu, Wall = clustering_qr.run(
        ops, st, tF, mode = 'template', device=device,
        progress_bar=progress_bar, clear_cache=clear_cache, verbose=verbose
        )

    elapsed = time.time() - t_start
    total = time.time() - tic0
    ops['runtime_clu'] = elapsed
    ops['usage_clu'] = get_performance()
    if torch.cuda.is_available():
        ops['cuda_clu'] = torch.cuda.memory_stats(device)
    logger.info(
        f'{clu.max()+1} clusters found, in {elapsed:.2f}s; total {total:.2f}s'
        )
    logger.debug(f'clu shape: {clu.shape}')
    logger.debug(f'Wall shape: {Wall.shape}')
    log_thread_count(logger)

    # Merging. The refractoriness flags returned here are not needed by the
    # caller and are discarded.
    t_start = time.time()
    logger.info(' ')
    logger.info('Merging clusters')
    logger.info('-'*40)
    Wall, clu, _is_ref, st, tF = template_matching.merging_function(
        ops, Wall, clu, st, tF, device=device, check_dt=True
        )
    clu = clu.astype('int32')

    elapsed = time.time() - t_start
    total = time.time() - tic0
    ops['runtime_merge'] = elapsed
    ops['usage_merge'] = get_performance()
    if torch.cuda.is_available():
        ops['cuda_merge'] = torch.cuda.memory_stats(device)
    logger.info(
        f'{clu.max()+1} units found, in {elapsed:.2f}s; total {total:.2f}s'
        )
    logger.debug(f'clu shape: {clu.shape}')
    logger.debug(f'Wall shape: {Wall.shape}')

    log_cuda_details(logger)
    log_performance(logger, 'info', 'Resource usage after clustering',
                    reset=True)

    return clu, Wall, st, tF


def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
                 save_extra_vars=False, save_preprocessed_copy=False,
                 skip_dat_path=False):
    """Save sorting results, and format them for use with Phy

    Parameters
    ----------
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
        Updated in place with `n_units_total`, `n_units_good`, `n_spikes`,
        `mean_drift`, `runtime_postproc`, `usage_postproc` and `runtime`
        before being written to disk with `io.save_ops`.
    results_dir : pathlib.Path
        Directory where results should be saved.
    st : np.ndarray
        3-column array of peak time (in samples), template, and threshold
        amplitude for each spike.
    clu : np.ndarray
        1D vector of cluster ids indicating which spike came from which
        cluster, same shape as `st[:,0]`.
    tF : torch.Tensor
        PC features for each spike, with shape
        (n_spikes, nearest_chans, n_pcs)
    Wall : torch.Tensor
        PC feature representation of spike waveforms for each cluster, with
        shape (n_clusters, n_channels, n_pcs).
    imin : int
        Minimum sample index used by BinaryRWFile, exported spike times will
        be shifted forward by this number.
    tic0 : float; default=np.nan.
        Start time of `run_kilosort`. If left at the default (or otherwise
        not finite), the total runtime is logged as unknown and
        `ops['runtime']` is stored as the non-finite value.
    save_extra_vars : bool; default=False.
        If True, save tF and Wall to disk along with copies of st, clu and
        amplitudes with no postprocessing applied.
    save_preprocessed_copy : bool; default=False.
        If True, save a pre-processed copy of the data (including drift
        correction) to `temp_wh.dat` in the results directory and format Phy
        output to use that copy of the data.
    skip_dat_path : bool; default=False.
        If True, will save `dat_path = 'no_path.bin'` in `params.py` in place
        of a real filename. This is done to prevent an error in Phy when
        filename has an unexpected format, like when using a `file_object`
        loaded from an external data format through SpikeInterface. The full
        filename(s) will still be included in `params.py` for reference, but
        will be commented out.

    Returns
    -------
    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    similar_templates : np.ndarray.
        Similarity score between each pair of clusters, computed as
        correlation between clusters. Shape (n_clusters, n_clusters).
    is_ref : np.ndarray.
        1D boolean array with shape (n_clusters,) indicating whether each
        cluster is refractory.
    est_contam_rate : np.ndarray.
        Contamination rate for each cluster, computed as fraction of
        refractory period violations relative to expectation based on a
        Poisson process. Shape (n_clusters,).
    kept_spikes : np.ndarray.
        Boolean mask with shape (n_spikes,) that is False for spikes that were
        removed by `kilosort.postprocessing.remove_duplicate_spikes`
        and True otherwise.

    Notes
    -----
    For documentation of saved files, see `kilosort.io.save_to_phy`.

    """
    tic = time.time()
    logger.info(' ')
    logger.info('Saving to phy and computing refractory periods')
    logger.info('-'*40)
    # `save_to_phy` returns the results directory it used, which replaces
    # the `results_dir` argument for the rest of this function.
    results_dir, similar_templates, is_ref, est_contam_rate, kept_spikes = \
        io.save_to_phy(
            st, clu, tF, Wall, ops['probe'], ops, imin,
            results_dir=results_dir, data_dtype=ops['data_dtype'],
            save_extra_vars=save_extra_vars,
            save_preprocessed_copy=save_preprocessed_copy,
            skip_dat_path=skip_dat_path
            )
    logger.info(f'{int(is_ref.sum())} units found with good refractory periods')

    # Summary statistics stored alongside the other results in `ops`.
    ops['n_units_total'] = np.unique(clu).size
    ops['n_units_good'] = int(is_ref.sum())
    ops['n_spikes'] = st[kept_spikes].shape[0]
    if ops.get('dshift', None) is not None:
        ops['mean_drift'] = np.abs(ops['dshift']).mean(axis=0)[0]
    else:
        ops['mean_drift'] = np.nan

    elapsed = time.time() - tic
    ops['runtime_postproc'] = elapsed
    ops['usage_postproc'] = get_performance()
    logger.info(f'Exporting to Phy took: {elapsed:.2f}s')

    runtime = time.time() - tic0
    if np.isfinite(runtime):
        # Split the rounded total so every field is a zero-padded integer
        # and the seconds field can never be reported as 60.
        mins, secs = divmod(int(round(runtime)), 60)
        hrs, mins = divmod(mins, 60)
        logger.info(f'Total runtime: {runtime:.2f}s = {hrs:02d}:' +
                    f'{mins:02d}:{secs:02d} h:m:s')
    else:
        # With the default `tic0=np.nan` the runtime is NaN; converting it
        # to int would raise ValueError before `ops` is saved.
        logger.info('Total runtime: unknown (no valid start time provided)')
    ops['runtime'] = runtime

    io.save_ops(ops, results_dir)
    logger.info(f'Sorting output saved in: {results_dir}.')
    log_cuda_details(logger)
    log_performance(logger, 'info', 'Resource usage after saving', reset=True)

    return ops, similar_templates, is_ref, est_contam_rate, kept_spikes
def load_sorting(results_dir, device=None, load_extra_vars=False):
    '''Read previously saved sorting results back into memory.

    Parameters
    ----------
    results_dir : str or pathlib.Path
        Directory where results were saved.
    device : torch.device; optional.
        CPU or GPU device to use to load Pytorch tensors. When omitted,
        'cuda' is used if `torch.cuda.is_available()` and 'cpu' otherwise.
        To set this manually, specify `device = torch.device(<device_name>)`.
        See PyTorch documentation for full description.
    load_extra_vars : default=False.
        If True, load tF, Wall, and full copies of st, clu, and spike
        amplitudes in addition to the other variables.

    Returns
    -------
    The values below are returned together as a single list, in this order.

    ops : dict
        Dictionary storing settings and results for all algorithmic steps.
    st : np.ndarray
        1D vector of spike times (in samples) for all clusters. This is *only*
        the first column of the 3-column array returned by `run_kilosort`.
    clu : np.ndarray
        1D vector of cluster ids indicating which spike came from which
        cluster, same shape as `st`.
    similar_templates : np.ndarray.
        Similarity score between each pair of clusters, computed as
        correlation between clusters. Shape (n_clusters, n_clusters).
    is_ref : np.ndarray.
        1D boolean array with shape (n_clusters,) indicating whether each
        cluster is refractory. Recomputed on load with `CCG.refract`.
    est_contam_rate : np.ndarray.
        Contamination rate for each cluster, computed as fraction of
        refractory period violations relative to expectation based on a
        Poisson process. Shape (n_clusters,). Recomputed on load with
        `CCG.refract`.
    kept_spikes : np.ndarray.
        Boolean mask with shape (n_spikes,) that is False for spikes that were
        removed by `kilosort.postprocessing.remove_duplicate_spikes`
        and True otherwise.
    tF : torch.Tensor.
        Only returned if `load_extra_vars` is True. PC features for each
        spike, with shape (n_spikes, nearest_chans, n_pcs)
    Wall : torch.Tensor.
        Only returned if `load_extra_vars` is True. PC feature representation
        of spike waveforms for each cluster, with shape
        (n_clusters, n_channels, n_pcs).
    full_st : np.ndarray.
        Only returned if `load_extra_vars` is True. 3-column array of peak
        time (in samples), template, and threshold amplitude for each spike.
        Includes spikes removed by
        `kilosort.postprocessing.remove_duplicate_spikes`.
    full_clu : np.ndarray.
        Only returned if `load_extra_vars` is True. 1D vector of cluster ids
        indicating which spike came from which cluster, same shape as
        `st[:,0]`. Includes spikes removed by
        `kilosort.postprocessing.remove_duplicate_spikes`.
    full_amp : np.ndarray.
        Only returned if `load_extra_vars` is True. Per-spike amplitudes,
        computed as the L2 norm of the PC features for each spike. Includes
        spikes removed by `kilosort.postprocessing.remove_duplicate_spikes`.

    '''
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    results_dir = Path(results_dir)

    def _load_npy(name):
        # Every array lives directly inside the results directory.
        return np.load(results_dir / name)

    ops = io.load_ops(results_dir / 'ops.npy', device=device)
    similar_templates = _load_npy('similar_templates.npy')
    clu = _load_npy('spike_clusters.npy')
    st = _load_npy('spike_times.npy')
    kept_spikes = _load_npy('kept_spikes.npy')

    # Refractoriness is not read from disk; it is recomputed from the saved
    # spike times (converted from samples to seconds) and cluster ids.
    sort_settings = ops['settings']
    acg_threshold = sort_settings['acg_threshold']
    ccg_threshold = sort_settings['ccg_threshold']
    spike_seconds = st / ops['fs']
    is_ref, est_contam_rate = CCG.refract(
        clu, spike_seconds,
        acg_threshold=acg_threshold, ccg_threshold=ccg_threshold
        )

    results = [ops, st, clu, similar_templates, is_ref, est_contam_rate,
               kept_spikes]
    if not load_extra_vars:
        return results

    # NOTE: tF and Wall always go on CPU, not CUDA
    tF = torch.from_numpy(_load_npy('tF.npy'))
    Wall = torch.from_numpy(_load_npy('Wall.npy'))
    full_st = _load_npy('full_st.npy')
    full_clu = _load_npy('full_clu.npy')
    full_amp = _load_npy('full_amp.npy')
    results += [tF, Wall, full_st, full_clu, full_amp]

    return results