# Source code for spiking.lambda_grid_search

import numpy as np
import scipy.io as sio
import os
from multiprocessing import Pool, cpu_count
import argparse
import warnings
import torch
import multiprocessing
from typing import Dict, Any, Optional
# Use the 'spawn' start method so worker processes start from a fresh
# interpreter rather than a fork of this one (presumably to avoid inheriting
# PyTorch/CUDA state, which is not fork-safe -- TODO confirm). force=True
# overrides any start method already chosen. Because this runs at import time,
# it affects every program that imports this module, not only the CLI.
multiprocessing.set_start_method('spawn', force=True)
# Silence all warnings process-wide. NOTE(review): this also runs at import
# time, so it hides warnings for any importer of this module as well.
warnings.filterwarnings("ignore")

def _init_worker():
    """Initialize each worker process with a fresh random seed and clear any GPU state"""

    np.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))
    
    # Force CPU computation for worker processes
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    
    import sys
    import importlib
    custom_modules = ['spiking.LIF_network_fnc', 'spiking.utils']
    for module in custom_modules:
        if module in sys.modules:
            importlib.reload(sys.modules[module])
        
    np.random.seed()
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    torch.set_num_threads(1)  # Prevent thread contention

def evaluate_single_trial(task_name: str, model_path: str, scaling_factor: float,
                          settings: Optional[Dict[str, Any]] = None) -> int:
    """Evaluate a single trial for a given task using the appropriate evaluator class.

    Args:
        task_name: Name of the task ('go_nogo', 'xor', 'mante'). Hyphens are
            converted to underscores, so 'go-nogo' is accepted too.
        model_path: Path to the model .mat file.
        scaling_factor: Scaling factor for the model.
        settings: Optional custom settings. Keys given here override the
            task's default settings; keys not given keep their default values.
            If None (or empty), the default settings are used unchanged.

    Returns:
        int: 1 if trial is correct, 0 if incorrect. Any exception raised while
        building the evaluator or running the trial is printed and reported
        as 0.
    """
    from spiking.eval_tasks import SpikingEvaluatorFactory, _get_default_task_settings

    try:
        task_name = task_name.replace('-', '_')

        # Layer the caller's overrides on top of a copy of the task defaults.
        # The command-line entry point forwards only the options the user
        # actually set (e.g. just {'T': 200}). Passing that partial dict
        # through unmerged would drop every other default setting.
        merged_settings = dict(_get_default_task_settings(task_name))
        if settings:
            merged_settings.update(settings)

        # Create the appropriate evaluator using the factory
        evaluator = SpikingEvaluatorFactory.create_evaluator(task_name, merged_settings)

        # Use the evaluator's evaluate_single_trial method
        return evaluator.evaluate_single_trial(model_path, scaling_factor)
    except Exception as e:
        print(f"Error in evaluate_single_trial: {e}")
        return 0
def parse_range(range_str):
    """Parse an inclusive integer range specification into a list of ints.

    Accepted formats:
        'start:stop:step' -> start, start+step, ... up to and including stop
        'start:stop'      -> start, start+1, ..., stop
        'value'           -> [value]

    A negative step counts down and still includes ``stop`` when it is hit.

    Raises:
        ValueError: if the string has more than three parts, a part is not an
            integer, or the step is zero.
    """
    parts = [int(part) for part in range_str.split(":")]
    if len(parts) == 1:
        return parts
    if len(parts) == 2:
        start, stop = parts
        step = 1
    elif len(parts) == 3:
        start, stop, step = parts
    else:
        raise ValueError(
            "Range must be in 'start:stop:step', 'start:stop' or 'value' format"
        )
    if step == 0:
        raise ValueError("Range step must not be zero")
    # range() excludes its stop bound. Move it one unit past ``stop`` in the
    # direction of travel so that ``stop`` itself is included for both
    # ascending and descending ranges.
    return list(range(start, stop + (1 if step > 0 else -1), step))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--task_name", type=str, choices=["go-nogo", "xor", "mante"], required=True)
    parser.add_argument("--n_trials", type=int, default=100)
    parser.add_argument(
        "--scaling_factors", type=str, default="20:75:5",
        help="Inclusive range 'start:stop[:step]' or a single integer value",
    )

    # Task-specific settings (advanced usage) - defaults are None so task-specific defaults apply
    parser.add_argument("--T", type=int, help="Trial duration (timesteps)")
    parser.add_argument("--stim_on", type=int, help="Stimulus onset time")
    parser.add_argument("--stim_dur", type=int, help="Stimulus duration")
    parser.add_argument("--delay", type=int, help="Delay between stimuli (XOR task)")
    parser.add_argument("--eval_amp_thresh", type=float, help="Evaluation amplitude threshold")

    args = parser.parse_args()

    # Forward only the task settings the user actually supplied; None means
    # "no overrides".
    task_settings = {
        param: getattr(args, param)
        for param in ('T', 'stim_on', 'stim_dur', 'delay', 'eval_amp_thresh')
        if getattr(args, param) is not None
    }
    task_settings = task_settings or None

    # Report a malformed range as a normal argparse usage error (exit code 2)
    # instead of an uncaught traceback.
    try:
        scaling_factors = parse_range(args.scaling_factors)
    except ValueError as err:
        parser.error(f"--scaling_factors: {err}")

    lambda_grid_search(args.model_path, args.task_name, args.n_trials, scaling_factors, task_settings)

# Run the script with the following commands:
#
# Basic usage with default settings:
#   python -m spiking.lambda_grid_search \
#       --model_path "./models/xor/rate_model_xor.mat" \
#       --task_name xor \
#       --n_trials 50 \
#       --scaling_factors 20:76:5
#
# Advanced usage with custom settings:
#   python -m spiking.lambda_grid_search \
#       --model_path "./models/go-nogo/rate_model_go_nogo.mat" \
#       --task_name go-nogo \
#       --n_trials 100 \
#       --scaling_factors 20:75:5 \
#       --T 200 \
#       --stim_on 50 \
#       --stim_dur 50 \
#       --eval_amp_thresh 0.7