Quick Start Guide#
This guide will get you up and running with the SpikeRNN framework quickly.
Installation#
First, install both packages:
git clone https://github.com/NuttidaLab/spikeRNN.git
cd spikeRNN
pip install -e .
Basic Workflow#
Step 1: Train a Rate RNN#
import torch
import scipy.io as sio
from rate import FR_RNN_dale, set_gpu, create_default_config
from rate.model import generate_input_stim_go_nogo, generate_target_continuous_go_nogo
# Load trained model
model_path = 'models/go-nogo/P_rec_0.2_Taus_4.0_20.0/model.mat'
device = set_gpu('0', 0.3)
# Set up network configuration
config = create_default_config(N=200, P_inh=0.2, P_rec=0.2)
# Create and train network (simplified example)
net = FR_RNN_dale(200, 0.2, 0.2, w_in, som_N=0, w_dist='gaus',
gain=1.5, apply_dale=True, w_out=w_out, device=device)
# Training loop
settings = {'T': 200, 'stim_on': 50, 'stim_dur': 25, 'DeltaT': 1,
'taus': [10], 'task': 'go-nogo'}
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
for trial in range(1000):
# Generate stimulus and targets
u, label = generate_input_stim_go_nogo(settings)
target = generate_target_continuous_go_nogo(settings, label)
# Forward pass and training
u_tensor = torch.tensor(u, dtype=torch.float32, device=device)
outputs = net.forward(u_tensor, settings['taus'],
{'activation': 'sigmoid', 'P_rec': 0.2}, settings)
loss = loss_op(outputs, target, {'activation': 'sigmoid'})
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Save model in .mat format for spiking conversion
model_dict = {
'w': net.w.detach().cpu().numpy(),
'w_in': net.w_in.detach().cpu().numpy(),
'w_out': net.w_out.detach().cpu().numpy(),
'w0': net.w0.detach().cpu().numpy(),
'N': 200,
'm': net.m.cpu().numpy(),
'som_m': net.som_m.cpu().numpy(),
'inh': net.inh.cpu().numpy(),
'exc': net.exc.cpu().numpy(),
'taus': settings['taus'],
'taus_gaus': net.taus_gaus.detach().cpu().numpy(),
'taus_gaus0': net.taus_gaus0.detach().cpu().numpy(),
}
sio.savemat('trained_model.mat', model_dict)
Step 2: Convert to Spiking Network#
import numpy as np
from spiking import LIF_network_fnc, lambda_grid_search
# First, optimize the scaling factor
lambda_grid_search(
model_path='trained_model.mat',
scaling_range=(20, 80),
n_trials_per_factor=50,
task_type='go-nogo',
parallel=True
)
# Convert to spiking network with optimal scaling
scaling_factor = 50.0 # Use value from grid search
# Create stimulus
u = np.zeros((1, 201))
u[0, 30:50] = 1 # Go trial stimulus
# Convert and simulate
stims = {'mode': 'none'}
W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(
'trained_model.mat', scaling_factor, u, stims,
downsample=1, use_initial_weights=False
)
print(f"Spike conversion completed!")
print(f"Generated {np.sum(spk)} spikes")
print(f"Output: {out[-1]:.4f}")
Step 3: Analyze Results#
from spiking import eval_go_nogo, format_spike_data
import matplotlib.pyplot as plt
# Evaluate performance
eval_go_nogo(
model_path='trained_model.mat',
scaling_factor=50.0,
n_trials=100,
plot_results=True
)
# Analyze spike patterns
spike_data = format_spike_data(spk, params['dt'])
print(f"Active neurons: {len(spike_data['active_neurons'])}")
print(f"Mean firing rate: {np.mean(spike_data['firing_rates']):.2f} Hz")
# Plot spike raster
plt.figure(figsize=(12, 8))
spike_times, spike_neurons = np.where(spk)
plt.scatter(spike_times * params['dt'], spike_neurons, s=1, c='black', alpha=0.6)
plt.xlabel('Time (s)')
plt.ylabel('Neuron Index')
plt.title('Spike Raster Plot')
plt.show()
Working with Different Tasks#
Go-NoGo Task#
from rate.model import generate_input_stim_go_nogo, generate_target_continuous_go_nogo
settings = {'T': 200, 'stim_on': 50, 'stim_dur': 25, 'DeltaT': 1,
'taus': [10], 'task': 'go-nogo'}
# Go trial
u_go = np.zeros((1, 201))
u_go[0, 30:50] = 1
# NoGo trial
u_nogo = np.zeros((1, 201))
u_nogo[0, 30:50] = -1
XOR Task#
from rate.model import generate_input_stim_xor, generate_target_continuous_xor
settings = {'T': 300, 'stim_on': [50, 110], 'stim_dur': 50, 'DeltaT': 1,
'taus': [10], 'task': 'xor'}
# XOR stimulus with two sequential inputs
u = np.zeros((2, 301))
u[0, 50:100] = 1 # First input
u[1, 110:160] = -1 # Second input (XOR = 1 × -1 = -1)
Mante Task#
from rate.model import generate_input_stim_mante, generate_target_continuous_mante
settings = {'T': 500, 'stim_on': 50, 'stim_dur': 200, 'DeltaT': 1,
'taus': [10], 'task': 'mante'}
# Context-dependent integration
u = np.zeros((4, 501))
u[0, 50:250] = np.random.randn(200) + 0.5 # Motion coherence
u[1, 50:250] = np.random.randn(200) - 0.5 # Color coherence
u[2, :] = 1 # Motion context
Model File Requirements#
Important: The spiking package only supports MATLAB .mat files because they contain complete parameter sets required for accurate spiking conversion:
Required Parameters in .mat Files#
# Complete parameter set for spiking conversion
model_data = {
'w': recurrent_weights, # NxN trained weights
'w_in': input_weights, # Nx1 input weights
'w_out': output_weights, # 1xN output weights
'w0': initial_weights, # NxN initial random weights
'N': network_size, # Number of neurons
'm': connectivity_mask, # NxN Dale's principle mask
'som_m': som_mask, # NxN SOM connectivity mask
'inh': inhibitory_indices, # Boolean array for inhibitory neurons
'exc': excitatory_indices, # Boolean array for excitatory neurons
'taus': time_constants, # Synaptic time constants
'taus_gaus': gaussian_taus, # Gaussian time constants
'taus_gaus0': initial_taus, # Initial time constants
}
Saving Models for Spiking Conversion#
When training rate models, save them in .mat format:
import scipy.io as sio
# After training rate RNN...
model_dict = {
'w': net.w.detach().cpu().numpy(),
'w_in': net.w_in.detach().cpu().numpy(),
'w_out': net.w_out.detach().cpu().numpy(),
'w0': net.w0.detach().cpu().numpy(),
'N': N,
'm': net.m.cpu().numpy(),
'som_m': net.som_m.cpu().numpy(),
'inh': net.inh.cpu().numpy(),
'exc': net.exc.cpu().numpy(),
'taus': settings['taus'],
'taus_gaus': net.taus_gaus.detach().cpu().numpy(),
'taus_gaus0': net.taus_gaus0.detach().cpu().numpy(),
}
sio.savemat('trained_model.mat', model_dict)
Advanced Usage#
Loading and Validating Models#
from spiking import load_rate_model
# Load and validate .mat model
model_data = load_rate_model('trained_model.mat')
# Check for required parameters
required_keys = ['w', 'w_in', 'w_out', 'N', 'inh', 'exc', 'taus']
missing = [k for k in required_keys if k not in model_data]
if missing:
print(f"Warning: Missing critical parameters: {missing}")
Scaling Factor Optimization#
from spiking import lambda_grid_search
# Comprehensive grid search
lambda_grid_search(
model_path='models/go-nogo/model.mat',
scaling_range=(20, 100), # Wide range
n_trials_per_factor=100, # More trials for accuracy
task_type='go-nogo',
parallel=True # Use multiprocessing
)
Next Steps#
Explore the Examples for detailed use cases
Review the API Reference for all available functions
Check out advanced features in the individual package documentation: