"""
Functions for evaluating a trained LIF RNN model constructed to perform the Go-NoGo task.
This task requires the network to respond to “Go” stimuli and withhold responses to “NoGo” stimuli, testing impulse control and decision-making capabilities.
The evaluation includes:
* Performance comparison between rate and spiking networks
* Spike raster plot visualization
* Response time analysis
* Accuracy metrics for Go and NoGo trials
"""
# PyTorch adaptation of the script to evaluate a trained LIF RNN model
# constructed to perform the Go-NoGo task
# The original model is from the following paper:
# Kim, R., Hasson, D. V. Z. T., & Pehlevan, C. (2019). A framework for
# reconciling rate and spike-based neuronal models. arXiv preprint arXiv:1904.05831.
# Original repository: https://github.com/rkim35/spikeRNN
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
import os
import time
from .LIF_network_fnc import LIF_network_fnc
[docs]
def eval_go_nogo(model_path= '../models/go-nogo/P_rec_0.2_Taus_4.0_20.0'):
# First, load one trained rate RNN
# Make sure lambda_grid_search.py was performed on the model.
# Update model_path to point where the trained model is
mat_files = [f for f in os.listdir(model_path) if f.endswith('.mat')]
model_name = mat_files[0]
model_path = os.path.join(model_path, model_name)
# Load model data to get opt_scaling_factor
model_data = sio.loadmat(model_path)
opt_scaling_factor = model_data['opt_scaling_factor'].item()
use_initial_weights = False
scaling_factor = opt_scaling_factor
down_sample = 1
# --------------------------------------------------------------
# NoGo trial example
# --------------------------------------------------------------
u = np.zeros((1, 201)) # input stim
# Run the LIF simulation
stims = {'mode': 'none'}
W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(model_path, scaling_factor,
u, stims, down_sample, use_initial_weights)
dt = params['dt']
T = params['T']
t = np.arange(dt, T + dt, dt)
nogo_out = out # LIF network output
nogo_rs = rs # firing rates
nogo_spk = spk # spikes
# --------------------------------------------------------------
# Go trial example
# --------------------------------------------------------------
u = np.zeros((1, 201)) # input stim
u[0, 30:50] = 1 # Note: Python uses 0-based indexing, MATLAB uses 1-based
# Run the LIF simulation
stims = {'mode': 'none'}
W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(model_path, scaling_factor,
u, stims, down_sample, use_initial_weights)
dt = params['dt']
T = params['T']
t = np.arange(dt, T + dt, dt)
go_out = out # LIF network output
go_rs = rs # firing rates
go_spk = spk # spikes
# Load additional model data for plotting
model_data = sio.loadmat(model_path)
inh = model_data['inh'].flatten()
exc = model_data['exc'].flatten()
N = int(model_data['N'].item())
# --------------------------------------------------------------
# Plot the network output
# --------------------------------------------------------------
plt.figure()
plt.plot(t, nogo_out.flatten(), 'm', linewidth=2, label='NoGo')
plt.plot(t, go_out.flatten(), 'g', linewidth=2, label='Go')
plt.xlabel('Time (s)')
plt.ylabel('Network Output')
plt.legend()
plt.title('Network Output Comparison')
plt.tight_layout()
plt.show()
# --------------------------------------------------------------
# Plot the spike raster
# --------------------------------------------------------------
# NoGo spike raster
plt.figure(figsize=(8, 6))
inh_ind = np.where(inh == 1)[0]
exc_ind = np.where(exc == 1)[0]
all_ind = np.arange(N)
for i in range(len(all_ind)):
curr_spk = nogo_spk[all_ind[i], 10:] # Skip first 10 time steps
spike_times = t[10:][curr_spk > 0] # Get corresponding time points
if exc[all_ind[i]] == 1:
plt.plot(spike_times, np.ones(len(spike_times)) * i, 'r.', markersize=8)
else:
plt.plot(spike_times, np.ones(len(spike_times)) * i, 'b.', markersize=8)
plt.xlim([0, 1])
plt.ylim([-5, 205])
plt.xlabel('Time (s)')
plt.ylabel('Neuron Index')
plt.title('NoGo Spike Raster (Red: Excitatory, Blue: Inhibitory)')
plt.tight_layout()
plt.show()
# Go spike raster
plt.figure(figsize=(8, 6))
for i in range(len(all_ind)):
curr_spk = go_spk[all_ind[i], 10:] # Skip first 10 time steps
spike_times = t[10:][curr_spk > 0] # Get corresponding time points
if exc[all_ind[i]] == 1:
plt.plot(spike_times, np.ones(len(spike_times)) * i, 'r.', markersize=8)
else:
plt.plot(spike_times, np.ones(len(spike_times)) * i, 'b.', markersize=8)
plt.xlim([0, 1])
plt.ylim([-5, 205])
plt.xlabel('Time (s)')
plt.ylabel('Neuron Index')
plt.title('Go Spike Raster (Red: Excitatory, Blue: Inhibitory)')
plt.tight_layout()
plt.show()
if __name__ == "__main__":
eval_go_nogo()