Examples#
This section provides detailed examples of using the complete SpikeRNN framework for training rate RNNs and converting them to spiking networks.
Complete Workflow Example#
This example shows the complete workflow from training a rate RNN to converting it to a spiking network:
import torch
import numpy as np
from rate import FR_RNN_dale, set_gpu, create_default_config
from rate.model import generate_input_stim_go_nogo, generate_target_continuous_go_nogo
from spiking import LIF_network_fnc, lambda_grid_search, eval_go_nogo
# Step 1: Set up GPU and configuration
device = set_gpu('0', 0.3)
config = create_default_config(N=200, P_inh=0.2, P_rec=0.2)
# Step 2: Train rate RNN (simplified example)
# ... training code ...
# Step 3: Save trained model in .mat format for spiking conversion
import scipy.io as sio
# `net`, `N`, and `settings` come from the training step above
model_dict = {
    'w': net.w.detach().cpu().numpy(),
    'w_in': net.w_in.detach().cpu().numpy(),
    'w_out': net.w_out.detach().cpu().numpy(),
    'w0': net.w0.detach().cpu().numpy(),
    'N': N,
    'm': net.m.cpu().numpy(),
    'som_m': net.som_m.cpu().numpy(),
    'inh': net.inh.cpu().numpy(),
    'exc': net.exc.cpu().numpy(),
    'taus': settings['taus'],
    'taus_gaus': net.taus_gaus.detach().cpu().numpy(),
    'taus_gaus0': net.taus_gaus0.detach().cpu().numpy(),
}
sio.savemat('trained_model.mat', model_dict)
# Step 4: Convert to spiking network
scaling_factor = 50.0
u = np.zeros((1, 201))
u[0, 30:50] = 1 # Go trial stimulus
W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(
    'trained_model.mat', scaling_factor, u, {'mode': 'none'},
    downsample=1, use_initial_weights=False
)
print(f"Rate-to-spike conversion completed!")
print(f"Generated {np.sum(spk)} spikes")
Rate RNN Training Examples#
Go-NoGo Task#
The Go-NoGo task trains the network to respond to “Go” stimuli and withhold responses to “NoGo” stimuli:
import torch
import torch.optim as optim
from rate import FR_RNN_dale, set_gpu
from rate.model import (generate_input_stim_go_nogo,
                        generate_target_continuous_go_nogo, loss_op)
# Setup
device = set_gpu('0', 0.4)
N = 200
P_inh = 0.2
P_rec = 0.2
# Network initialization
w_in = torch.randn(N, 1, device=device)
w_out = torch.randn(1, N, device=device) / 100
net = FR_RNN_dale(N, P_inh, P_rec, w_in, som_N=0, w_dist='gaus',
                  gain=1.5, apply_dale=True, w_out=w_out, device=device)
# Task settings
settings = {
    'T': 200, 'stim_on': 50, 'stim_dur': 25,
    'DeltaT': 1, 'taus': [10], 'task': 'go-nogo'
}
training_params = {
    'learning_rate': 0.01, 'loss_threshold': 7,
    'eval_freq': 100, 'P_rec': 0.20, 'activation': 'sigmoid'
}
# Training loop
optimizer = optim.Adam(net.parameters(), lr=training_params['learning_rate'])
n_trials = 1000
for tr in range(n_trials):
    optimizer.zero_grad()
    # Generate task data
    u, label = generate_input_stim_go_nogo(settings)
    target = generate_target_continuous_go_nogo(settings, label)
    u_tensor = torch.tensor(u, dtype=torch.float32, device=device)
    # Forward pass
    outputs = net.forward(u_tensor, settings['taus'], training_params, settings)
    # Compute loss and update
    loss = loss_op(outputs, target, training_params)
    loss.backward()
    optimizer.step()
    if tr % 100 == 0:
        print(f"Trial {tr}, Loss: {loss.item():.4f}")
XOR Task#
The XOR task requires temporal working memory to compute the XOR of two sequential inputs:
from rate.model import generate_input_stim_xor, generate_target_continuous_xor
# Task settings
settings = {
    'T': 300, 'stim_on': 50, 'stim_dur': 50, 'delay': 10,
    'DeltaT': 1, 'taus': [10], 'task': 'xor'
}
# Network with 2 inputs for XOR
w_in = torch.randn(N, 2, device=device)
net = FR_RNN_dale(N, P_inh, P_rec, w_in, som_N=0, w_dist='gaus',
                  gain=1.5, apply_dale=True, w_out=w_out, device=device)
# Training loop (recreate the optimizer for the new network)
optimizer = optim.Adam(net.parameters(), lr=training_params['learning_rate'])
for tr in range(n_trials):
    optimizer.zero_grad()
    u, label = generate_input_stim_xor(settings)
    target = generate_target_continuous_xor(settings, label)
    u_tensor = torch.tensor(u, dtype=torch.float32, device=device)
    outputs = net.forward(u_tensor, settings['taus'], training_params, settings)
    loss = loss_op(outputs, target, training_params)
    loss.backward()
    optimizer.step()
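As in the Complete Workflow example, the trained XOR network must be saved in .mat format before spiking conversion; the same scipy.io.savemat pattern applies (the path below is illustrative):
import scipy.io as sio
# Save the trained XOR model with the same fields as the Go-NoGo example
sio.savemat('models/xor/trained_model.mat', {
    'w': net.w.detach().cpu().numpy(),
    'w_in': net.w_in.detach().cpu().numpy(),
    'w_out': net.w_out.detach().cpu().numpy(),
    'w0': net.w0.detach().cpu().numpy(),
    'N': N,
    'm': net.m.cpu().numpy(),
    'som_m': net.som_m.cpu().numpy(),
    'inh': net.inh.cpu().numpy(),
    'exc': net.exc.cpu().numpy(),
    'taus': settings['taus'],
    'taus_gaus': net.taus_gaus.detach().cpu().numpy(),
    'taus_gaus0': net.taus_gaus0.detach().cpu().numpy(),
})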
Spiking Network Examples#
Basic Rate-to-Spike Conversion#
Convert a trained rate RNN to a spiking network:
from spiking import LIF_network_fnc
import numpy as np
import matplotlib.pyplot as plt
# Load trained model (.mat files only for spiking conversion)
model_path = 'trained_model.mat'
scaling_factor = 50.0
# Create test stimulus
u = np.zeros((1, 201))
u[0, 30:50] = 1 # Go trial stimulus
# Convert to spiking network
stims = {'mode': 'none'}
W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(
    model_path, scaling_factor, u, stims,
    downsample=1, use_initial_weights=False
)
print(f"Conversion completed!")
print(f"Generated {np.sum(spk)} spikes")
Scaling Factor Optimization#
Finding the optimal scaling factor is crucial for good performance:
from spiking import lambda_grid_search
# Comprehensive grid search
lambda_grid_search(
    model_path='models/go-nogo/trained_model.mat',
    scaling_range=(20, 100),
    n_trials_per_factor=100,
    task_type='go-nogo',
    parallel=True
)
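If you want to inspect the sweep yourself, the search can also be sketched manually with repeated LIF_network_fnc calls. This is only a quick diagnostic that reports spike counts and the final readout per candidate factor, not a replacement for lambda_grid_search, which evaluates many trials per factor:
import numpy as np
from spiking import LIF_network_fnc

u = np.zeros((1, 201))
u[0, 30:50] = 1  # Go trial stimulus
# Manual sweep over candidate scaling factors
for sf in np.linspace(20, 100, 9):
    W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(
        'models/go-nogo/trained_model.mat', sf, u, {'mode': 'none'},
        downsample=1, use_initial_weights=False
    )
    print(f"scaling={sf:.0f}: {np.sum(spk)} spikes, final output {out[-1]}")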
Task Performance Evaluation#
Evaluate converted spiking networks:
from spiking import eval_go_nogo
# Evaluate Go-NoGo performance
eval_go_nogo(
    model_path='models/go-nogo/trained_model.mat',
    scaling_factor=50.0,
    n_trials=200,
    plot_results=True
)
Batch Processing#
Process multiple models:
import os
import numpy as np
from spiking import LIF_network_fnc, lambda_grid_search

# Process a set of trained .mat models
model_paths = [
    'models/go-nogo/trained_model.mat',
    'models/xor/trained_model.mat',
    'models/mante/trained_model.mat'
]
# Shared test stimulus (Go trial)
stimulus = np.zeros((1, 201))
stimulus[0, 30:50] = 1
results = {}
for model_path in model_paths:
    if os.path.exists(model_path):
        # Find optimal scaling
        lambda_grid_search(model_path=model_path, task_type='go-nogo')
        # Convert and analyze
        W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(
            model_path, 50.0, stimulus, {'mode': 'none'}, 1, False
        )
        results[model_path] = {
            'spikes': np.sum(spk),
            'output': out[-1]
        }
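A short summary pass over the collected results:
# Summarize spike counts and final outputs across models
for model_path, stats in results.items():
    print(f"{model_path}: {stats['spikes']} spikes, final output {stats['output']}")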
Advanced Analysis#
from spiking import format_spike_data, LIF_network_fnc
import numpy as np
import matplotlib.pyplot as plt

# Convert a trained model to obtain spikes and simulation parameters
model_path = 'models/go-nogo/trained_model.mat'
u = np.zeros((1, 201))
u[0, 30:50] = 1  # Go trial stimulus
W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(
    model_path, 50.0, u, {'mode': 'none'}, 1, False
)
# Format spike data for analysis
spike_data = format_spike_data(spk, params['dt'])
# Print statistics
print(f"Total spikes: {spike_data['total_spikes']}")
print(f"Number of active neurons: {len(spike_data['active_neurons'])}")
print(f"Mean firing rate: {np.mean(spike_data['firing_rates']):.2f} Hz")
print(f"Spike rate: {spike_data['total_spikes'] / params['total_time']:.2f} spikes/s")
# Plot firing rate distribution
plt.figure(figsize=(8, 5))
plt.hist(spike_data['firing_rates'], bins=30, alpha=0.7)
plt.xlabel('Firing Rate (Hz)')
plt.ylabel('Number of Neurons')
plt.title('Firing Rate Distribution')
plt.show()
Advanced Examples#
Multi-Task Comparison#
Compare spiking network performance across different tasks:
tasks = ['go-nogo', 'xor', 'mante']
model_paths = [
    'models/go-nogo/trained_model.mat',
    'models/xor/trained_model.mat',
    'models/mante/trained_model.mat'
]
for task, model_path in zip(tasks, model_paths):
    print(f"\nEvaluating {task} task...")
    # Optimize scaling factor
    lambda_grid_search(
        model_path=model_path,
        task_type=task,
        parallel=True
    )
    # Evaluate performance
    if task == 'go-nogo':
        eval_go_nogo(model_path=model_path, plot_results=True)
Parameter Sensitivity Analysis#
Test how different LIF parameters affect conversion:
from spiking.utils import generate_lif_params
model_path = 'models/go-nogo/trained_model.mat'
scaling_factor = 50.0
u = np.zeros((1, 201))
u[0, 30:50] = 1
# Test different membrane time constants
time_constants = [0.01, 0.02, 0.05, 0.1]
for tm in time_constants:
    print(f"\nTesting membrane time constant: {tm}s")
    # This would require modifying LIF_network_fnc to accept custom
    # parameters or creating a custom implementation; results would
    # show how membrane dynamics affect spike timing.
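To make the idea concrete, here is a minimal, self-contained LIF sketch (hypothetical parameter names; not the LIF_network_fnc internals) showing how the membrane time constant shapes spike timing under a constant input current:
import numpy as np

def lif_spike_times(tm, I=1.5, dt=1e-4, T=0.5, v_th=1.0, v_reset=0.0):
    """Minimal LIF neuron, dv/dt = (-v + I) / tm.
    Hypothetical parameters for illustration only."""
    v, spikes = 0.0, []
    for step in range(int(T / dt)):
        v += dt * (-v + I) / tm  # leaky integration toward the input current
        if v >= v_th:            # threshold crossing emits a spike
            spikes.append(step * dt)
            v = v_reset          # reset after the spike
    return spikes

for tm in [0.01, 0.02, 0.05, 0.1]:
    spikes = lif_spike_times(tm)
    print(f"tm={tm}s -> {len(spikes)} spikes in 0.5s")
Longer membrane time constants integrate more slowly, so the neuron reaches threshold less often and fires fewer, later spikes.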