Source code for spiking.eval_tasks

#!/usr/bin/env python3
"""
Evaluation script for trained spiking RNN models on cognitive tasks.
"""

import argparse
import os
import sys
import scipy.io as sio
import numpy as np
from typing import Dict, Any, Optional

from .LIF_network_fnc import LIF_network_fnc
from .tasks import SpikingTaskFactory
from .abstract import AbstractSpikingRNN


[docs] class LIFNetworkAdapter(AbstractSpikingRNN): """ Adapter to use LIF_network_fnc with the spiking task interface. """ def __init__(self, model_path: str, scaling_factor: float): # Create a minimal config for the abstract class from .abstract import SpikingConfig config = SpikingConfig(N=200) # Default N, will be overridden by actual model super().__init__(config) self.model_path = model_path self.scaling_factor = scaling_factor self.use_initial_weights = False self.downsample = 1
[docs] def load_rate_weights(self, model_path: str) -> None: """Load weights from a trained rate RNN model.""" # This is handled by LIF_network_fnc internally pass
[docs] def initialize_lif_params(self) -> None: """Initialize LIF neuron parameters.""" # This is handled by LIF_network_fnc internally pass
[docs] def compute_firing_rates(self, spikes: np.ndarray) -> np.ndarray: """Compute firing rates from spike trains.""" # Simple firing rate computation if spikes.size == 0: return np.array([]) return np.mean(spikes, axis=0) if spikes.ndim > 1 else np.mean(spikes)
[docs] def simulate(self, stimulus: np.ndarray, stims: Dict[str, Any]): """Simulate the LIF network on given stimulus.""" W, REC, spikes, rs, all_fr, output, params = LIF_network_fnc( self.model_path, self.scaling_factor, stimulus, stims, self.downsample, self.use_initial_weights ) return spikes, None, output, params
[docs] def load_model_and_scaling_factor(model_dir: str, optimal_scaling_factor: Optional[float] = None) -> tuple: """ Load model file and determine scaling factor. Args: model_dir: Directory containing the .mat model file optimal_scaling_factor: Override scaling factor if provided Returns: Tuple of (model_path, scaling_factor) """ # Find .mat file mat_files = [f for f in os.listdir(model_dir) if f.endswith('.mat')] if not mat_files: raise FileNotFoundError(f"No .mat files found in {model_dir}") model_path = os.path.join(model_dir, mat_files[0]) print(f"Using model file: {model_path}") # Load scaling factor if optimal_scaling_factor is not None: scaling_factor = optimal_scaling_factor print(f"Using provided scaling factor: {scaling_factor}") else: model_data = sio.loadmat(model_path) if 'opt_scaling_factor' not in model_data: raise ValueError("opt_scaling_factor not found in .mat file. Please run lambda_grid_search first or provide --scaling_factor") scaling_factor = float(model_data['opt_scaling_factor'].item()) print(f"Using scaling factor from model: {scaling_factor}") return model_path, scaling_factor
[docs] def evaluate_task(task_name: str, model_dir: str, optimal_scaling_factor: Optional[float] = None, task_settings: Optional[Dict[str, Any]] = None, save_plots: bool = True) -> Dict[str, float]: """ Evaluate a spiking task on a trained model. Args: task_name: Name of the task ('go_nogo', 'xor', 'mante') model_dir: Directory containing the trained model optimal_scaling_factor: Override scaling factor task_settings: Override task settings save_plots: Whether to save visualization plots Returns: Performance metrics dictionary """ # Load model and scaling factor model_path, scaling_factor = load_model_and_scaling_factor(model_dir, optimal_scaling_factor) # Create spiking network adapter spiking_rnn = LIFNetworkAdapter(model_path, scaling_factor) # Create task task = SpikingTaskFactory.create_task(task_name, task_settings) print(f"Created {task.__class__.__name__} with settings: {task.settings}") # Evaluate performance stimulus, label = task.generate_stimulus() performance = task.evaluate_trial(spiking_rnn, stimulus, label) # Create visualizations if requested if save_plots: print(f"\nGenerating sample trials and visualizations...") results = [] # Generate sample trials using task's sample trial types sample_trial_types = task.get_sample_trial_types() if sample_trial_types: for trial_type in sample_trial_types: try: stimulus, label = task.generate_stimulus(trial_type) result = task.evaluate_trial(spiking_rnn, stimulus, label) results.append(result) except Exception as e: print(f"Warning: Failed to generate trial type '{trial_type}': {e}") else: # Fallback: generate a few random trials for _ in range(4): try: stimulus, label = task.generate_stimulus() result = task.evaluate_trial(spiking_rnn, stimulus, label) results.append(result) except Exception as e: print(f"Warning: Failed to generate random trial: {e}") # Save visualizations using task's built-in methods if hasattr(task, 'create_visualization') and results: try: task.create_visualization(results, model_dir) plot_dir = os.path.join(model_dir, 'plots') print(f"Plots saved to: {plot_dir}") except Exception as e: print(f"Warning: Failed to create visualizations: {e}") elif results: print(f"Generated {len(results)} sample trials (no visualization method available)") else: print("No sample trials were generated for visualization") return performance
[docs] def main(): parser = argparse.ArgumentParser( description='Evaluate trained spiking RNN models on cognitive tasks.', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: python -m spiking.eval_tasks --task go_nogo --model_dir models/go-nogo/ python -m spiking.eval_tasks --task xor --model_dir models/xor/ --n_trials 50 python -m spiking.eval_tasks --task mante --model_dir models/mante/ --scaling_factor 45.0 """ ) # Get available tasks from factory from .tasks import SpikingTaskFactory available_tasks = SpikingTaskFactory.list_available_tasks() parser.add_argument('--task', type=str, required=True, help=f'Task to evaluate. Available: {", ".join(available_tasks)}') parser.add_argument('--model_dir', type=str, required=True, help='Directory containing the trained model .mat file') parser.add_argument('--scaling_factor', type=float, default=None, help='Override scaling factor (uses value from .mat file if not provided)') parser.add_argument('--no_plots', action='store_true', help='Skip generating visualization plots') # Task-specific settings (advanced usage) parser.add_argument('--T', type=int, help='Trial duration (timesteps)') parser.add_argument('--stim_on', type=int, help='Stimulus onset time') parser.add_argument('--stim_dur', type=int, help='Stimulus duration') args = parser.parse_args() # Build task settings from arguments task_settings = {} for param in ['T', 'stim_on', 'stim_dur']: value = getattr(args, param) if value is not None: task_settings[param] = value task_settings = task_settings if task_settings else None try: performance = evaluate_task( task_name=args.task, model_dir=args.model_dir, optimal_scaling_factor=args.scaling_factor, task_settings=task_settings, save_plots=not args.no_plots ) print(f"\n✓ Evaluation completed successfully!") # print(f"Performance: {performance}") return 0 except Exception as e: print(f"\n✗ Evaluation failed: {e}") return 1
if __name__ == "__main__": sys.exit(main()) # Usage: """ python -m spiking.eval_tasks --task go_nogo --model_dir models/go-nogo/ python -m spiking.eval_tasks --task xor --model_dir models/xor/ --scaling_factor 45.0 """