Examples#

This section provides detailed examples of using the complete spikeRNN framework for training rate RNNs and converting them to spiking networks. The examples below can be run interactively in a Python environment or adapted into standalone scripts.

Complete Workflow Example#

This example shows the complete workflow from training a rate RNN to converting it to a spiking network:

import torch
import numpy as np
from rate import FR_RNN_dale, set_gpu, create_default_config
from rate.tasks import TaskFactory
from spiking import LIF_network_fnc, lambda_grid_search, evaluate_task

# Step 1: Set up GPU and configuration
device = set_gpu('0', 0.3)
config = create_default_config(N=200, P_inh=0.2, P_rec=0.2)

# Step 2: Train rate RNN (must define `net`, `settings`, and `N` used below)
# ... training code ...

# Step 3: Save trained model in .mat format for spiking conversion
import scipy.io as sio

model_dict = {
    'w': net.w.detach().cpu().numpy(),
    'w_in': net.w_in.detach().cpu().numpy(),
    'w_out': net.w_out.detach().cpu().numpy(),
    'w0': net.w0.detach().cpu().numpy(),
    'N': N,
    'm': net.m.cpu().numpy(),
    'som_m': net.som_m.cpu().numpy(),
    'inh': net.inh.cpu().numpy(),
    'exc': net.exc.cpu().numpy(),
    'taus': settings['taus'],
    'taus_gaus': net.taus_gaus.detach().cpu().numpy(),
    'taus_gaus0': net.taus_gaus0.detach().cpu().numpy(),
}
sio.savemat('trained_model.mat', model_dict)

# Step 4: Convert to spiking network
scaling_factor = 50.0
u = np.zeros((1, 201))
u[0, 30:50] = 1  # Go trial stimulus

W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(
    'trained_model.mat', scaling_factor, u, {'mode': 'none'},
    downsample=1, use_initial_weights=False
)

print("Rate-to-spike conversion completed!")
print(f"Generated {np.sum(spk)} spikes")

Rate RNN Training Examples#

Go-NoGo Task#

The Go-NoGo task trains the network to respond to “Go” stimuli and withhold responses to “NoGo” stimuli:

import torch
import torch.optim as optim
from rate import FR_RNN_dale, set_gpu
from rate.model import loss_op
from rate.tasks import TaskFactory

# Setup
device = set_gpu('0', 0.4)
N = 200
P_inh = 0.2
P_rec = 0.2

# Network initialization
w_in = torch.randn(N, 1, device=device)
w_out = torch.randn(1, N, device=device) / 100
net = FR_RNN_dale(N, P_inh, P_rec, w_in, som_N=0, w_dist='gaus',
                  gain=1.5, apply_dale=True, w_out=w_out, device=device)

# Task settings
settings = {
    'T': 200, 'stim_on': 50, 'stim_dur': 25,
    'DeltaT': 1, 'taus': [10], 'task': 'go-nogo'
}

training_params = {
    'learning_rate': 0.01, 'loss_threshold': 7,
    'eval_freq': 100, 'P_rec': 0.20, 'activation': 'sigmoid'
}

# Training loop
optimizer = optim.Adam(net.parameters(), lr=training_params['learning_rate'])
n_trials = 1000

for tr in range(n_trials):
    optimizer.zero_grad()

    # Generate task data
    task = TaskFactory.create_task('go_nogo', settings)
    u, target, label = task.simulate_trial()
    u_tensor = torch.tensor(u, dtype=torch.float32, device=device)

    # Forward pass
    outputs = net.forward(u_tensor, settings['taus'], training_params, settings)

    # Compute loss and update
    loss = loss_op(outputs, target, training_params)
    loss.backward()
    optimizer.step()

    if tr % 100 == 0:
        print(f"Trial {tr}, Loss: {loss.item():.4f}")

XOR Task#

The XOR task requires temporal working memory: the network must output the XOR of two sequential input pulses. The example below continues from the Go-NoGo setup, reusing the network size, sparsity settings, output weights (w_out), and training_params defined above:
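As a quick illustration of the target logic (plain NumPy, not part of the spikeRNN API), the four combinations of input pulses map to the following targets:

```python
import numpy as np

# Illustration only: the target the network must learn for each
# pair of sequential input pulses (1 = pulse present, 0 = absent).
pairs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
targets = np.logical_xor(pairs[:, 0], pairs[:, 1]).astype(int)
for (a, b), t in zip(pairs, targets):
    print(f"input1={a}, input2={b} -> target={t}")
# The target is 1 only when exactly one of the two pulses is present.
```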

from rate.tasks import TaskFactory

# Task settings
settings = {
    'T': 300, 'stim_on': 50, 'stim_dur': 50, 'delay': 10,
    'DeltaT': 1, 'taus': [10], 'task': 'xor'
}

# Network with 2 inputs for XOR
w_in = torch.randn(N, 2, device=device)
net = FR_RNN_dale(N, P_inh, P_rec, w_in, som_N=0, w_dist='gaus',
                  gain=1.5, apply_dale=True, w_out=w_out, device=device)

# Re-create the optimizer so it tracks the new network's parameters
optimizer = optim.Adam(net.parameters(), lr=training_params['learning_rate'])

# Training loop
for tr in range(n_trials):
    optimizer.zero_grad()
    task = TaskFactory.create_task('xor', settings)
    u, target, label = task.simulate_trial()
    u_tensor = torch.tensor(u, dtype=torch.float32, device=device)

    outputs = net.forward(u_tensor, settings['taus'], training_params, settings)
    loss = loss_op(outputs, target, training_params)
    loss.backward()
    optimizer.step()

Spiking Network Examples#

Basic Rate-to-Spike Conversion#

Convert a trained rate RNN to a spiking network:

from spiking import LIF_network_fnc
import numpy as np
import matplotlib.pyplot as plt

# Load trained model (.mat files only for spiking conversion)
model_path = 'trained_model.mat'
scaling_factor = 50.0

# Create test stimulus
u = np.zeros((1, 201))
u[0, 30:50] = 1  # Go trial stimulus

# Convert to spiking network
stims = {'mode': 'none'}
W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(
    model_path, scaling_factor, u, stims,
    downsample=1, use_initial_weights=False
)

print("Conversion completed!")
print(f"Generated {np.sum(spk)} spikes")
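For a quick visual check of the conversion, the spike matrix can be plotted as a raster. The sketch below assumes spk is an N x T binary matrix; a random stand-in is generated so the snippet runs on its own, but in practice you would pass the spk returned by LIF_network_fnc above:

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless-safe backend for saving figures
import matplotlib.pyplot as plt

# Stand-in spike matrix; replace with the `spk` returned by
# LIF_network_fnc (assumed shape: N neurons x T time steps).
rng = np.random.default_rng(0)
spk = rng.random((200, 201)) < 0.02

# One dot per spike: x = time step, y = neuron index
neuron_ids, spike_times = np.nonzero(spk)
plt.figure(figsize=(8, 4))
plt.scatter(spike_times, neuron_ids, s=2, color='k')
plt.xlabel('Time step')
plt.ylabel('Neuron index')
plt.title('Spike raster')
plt.savefig('spike_raster.png', dpi=150)
```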

Scaling Factor Optimization#

Finding the optimal scaling factor is crucial for good performance. You can run the grid search from the command line:

python -m spiking.lambda_grid_search \
    --model_dir "models/go-nogo/P_rec_0.2_Taus_4.0_20.0" \
    --task_name go-nogo \
    --n_trials 100 \
    --scaling_factors 20:76:5

Or call the function from within a Python script:

from spiking import lambda_grid_search

# Comprehensive grid search
lambda_grid_search(
    model_dir='models/go-nogo/P_rec_0.2_Taus_4.0_20.0',
    task_name='go-nogo',
    n_trials=100,
    scaling_factors=(20, 76, 5)
)
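Both forms above describe the same search range. Assuming the start:stop:step specification is expanded like NumPy's arange (stop exclusive), the scaling factors actually tested would be:

```python
import numpy as np

# Assumed interpretation of 20:76:5 / (20, 76, 5):
# start=20, stop=76 (exclusive), step=5 -- as in np.arange.
candidates = np.arange(20, 76, 5)
print(candidates.tolist())
# [20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75]
```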

Task Performance Evaluation#

You can also evaluate converted spiking networks directly from the command line. For example, to evaluate the Go-NoGo task for a specific model, run the following command from the spikeRNN directory:

python -m spiking.eval_tasks --task go_nogo \
    --model_dir models/go-nogo/P_rec_0.2_Taus_4.0_20.0

If you have a specific scaling factor you want to use, you can specify it:

python -m spiking.eval_tasks --task go_nogo \
    --model_dir models/go-nogo/P_rec_0.2_Taus_4.0_20.0 \
    --scaling_factor 50.0

Alternatively, you can call the evaluation function from a Python script:

from spiking.eval_tasks import evaluate_task

# Evaluate Go-NoGo performance
performance = evaluate_task(
    task_name='go_nogo',
    model_dir='models/go-nogo/P_rec_0.2_Taus_4.0_20.0'
)

All registered tasks can be evaluated using the same interface:

python -m spiking.eval_tasks --task xor --model_dir models/xor/
python -m spiking.eval_tasks --task mante --model_dir models/mante/