{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Reproducing Kilosort4 benchmarks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Download data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, choose a simulated dataset to use for the comparison. You'll need to download and extract the data before running the rest of the notebook. For this demo, I chose the medium drift dataset and saved it to `D:/sim_med_drift`.\n", "\n", "All datasets: https://janelia.figshare.com/articles/dataset/Simulations_from_kilosort4_paper/25298815/1\n", "\n", "Medium drift: https://janelia.figshare.com/ndownloader/files/44729869" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Decompress data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You will also need to decompress the binary file, which has been compressed with `mtscomp`. Details about the package are available here: https://github.com/int-brain-lab/mtscomp\n", "\n", "Install with:\n", "```\n", "pip install mtscomp\n", "```\n", "Then decompress the `.cbin` file in the unzipped directory downloaded in the previous step:\n", "```\n", "mtsdecomp sim.imec0.ap.cbin -o data.bin\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Get probe information" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The probe file for the recording can be generated from `sim.imec0.ap.meta` using a SpikeGLX script available here:\n", "\n", "https://github.com/jenniferColonell/SGLXMetaToCoords/blob/main/SGLXMetaToCoords.py\n", "\n", "Download the script and run it in your Kilosort environment with `python SGLXMetaToCoords.py`, then follow the prompt to select the `.meta` file. This will save the probe file as `sim.imec0.ap_kilosortChanMap.mat` in the same directory as the `.meta` file." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. 
Run Kilosort4" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "from kilosort import run_kilosort\n", "from kilosort.io import load_probe, load_ops\n", "from kilosort.bench import load_GT, load_phy, compare_recordings\n", "\n", "# Specify paths\n", "# NOTE: Make sure to update `download_path` to match your local directory.\n", "download_path = Path('D:/sim_med_drift')\n", "results_dir = download_path / 'kilosort4'\n", "filename = download_path / 'data.bin'\n", "probe_path = download_path / 'sim.imec0.ap_kilosortChanMap.mat'\n", "gt_path = download_path / 'sim.imec0.ap_params.npz'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Run Kilosort4 on downloaded dataset\n", "# You can also run it without drift correction by setting `nblocks = 0`,\n", "# or only rigid drift correction by setting `nblocks = 1`.\n", "probe = load_probe(probe_path)\n", "settings = {'n_chan_bin': 385, 'filename': filename, 'nblocks': 5}\n", "_ = run_kilosort(settings, probe=probe)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load `ops`\n", "# NOTE: If you've already sorted the data, you can skip the previous step.\n", "ops = load_ops(results_dir / 'ops.npy')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. 
Compute ground truth comparison" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 600/600 [04:25<00:00, 2.26it/s]\n" ] } ], "source": [ "# NOTE: Using this with the current version of Kilosort4 requires the\n", "# updates to `kilosort.bench.py` added in v4.0.32\n", "\n", "# Load ground truth results\n", "st_gt, clu_gt, yclu_gt, mu_gt, Wsub, nsp = load_GT(filename, ops, gt_path)\n", "# Load sorting results\n", "st_new, clu_new, yclu_new, Wsub = load_phy(filename, results_dir, ops)\n", "# Compare\n", "fmax, fmiss, fpos, best_ind, matched_all, top_inds = \\\n", "    compare_recordings(st_gt, clu_gt, yclu_gt, st_new, clu_new, yclu_new)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number correct: 553\n", "number missed: 47\n" ] } ], "source": [ "def num_correct(fpos, fmiss):\n", "    # Score each ground truth unit as 1 minus its false positive and miss\n", "    # fractions; units scoring at least 0.8 count as correct.\n", "    score = 1 - fpos - fmiss\n", "    return (score >= 0.8).sum()\n", "\n", "n = num_correct(fpos, fmiss)\n", "print(f'number correct: {n}')\n", "print(f'number missed: {600 - n}') # 600 ground truth units\n", "print('number correct in paper: 555') # for medium drift dataset, from fig 4j" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "kilosort", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.21" } }, "nbformat": 4, "nbformat_minor": 2 }