Source code for spiking.lambda_grid_search

import numpy as np
import scipy.io as sio
import os
from multiprocessing import Pool, cpu_count
import argparse
import warnings
import torch
import multiprocessing
multiprocessing.set_start_method('spawn', force=True)
warnings.filterwarnings("ignore")

def _init_worker():
    """Initialize each worker process with a fresh random seed and clear any GPU state"""

    np.random.seed(int.from_bytes(os.urandom(4), byteorder='little'))
    
    # Force CPU computation for worker processes
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    
    import sys
    import importlib
    custom_modules = ['spiking.LIF_network_fnc', 'spiking.utils']
    for module in custom_modules:
        if module in sys.modules:
            importlib.reload(sys.modules[module])
        
    np.random.seed()
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    torch.set_num_threads(1)  # Prevent thread contention

[docs] def evaluate_single_trial(args): """ Evaluate a single trial for a given task. Args: args (tuple): Tuple containing the current full path, scaling factor, task name, use initial weights, and down sample. Returns: int: 1 if the trial is successful, 0 otherwise. """ curr_full, scaling_factor, task_name, use_initial_weights, down_sample = args from spiking.LIF_network_fnc import LIF_network_fnc model_data = sio.loadmat(curr_full) try: if task_name == 'go-nogo': u = np.zeros((1, 201)) trial_type = 0 if np.random.rand() >= 0.50: u[0, 50:75] = 1.0 trial_type = 1 stims = {'mode': 'none'} W, REC, spk, rs, all_fr, out, params = LIF_network_fnc(model_data, scaling_factor, u, stims, down_sample, use_initial_weights) max_output = np.max(out[10000:]) if trial_type == 1: # Go trial success = max_output > 0.7 else: # NoGo trial success = max_output < 0.3 return 1 if success else 0 elif task_name == 'mante': u = np.zeros((4, 501)) u_lab = np.zeros(2) if np.random.rand() >= 0.5: u[0, 50:250] = np.random.randn(200) + 0.5 u_lab[0] = 1 else: u[0, 50:250] = np.random.randn(200) - 0.5 u_lab[0] = -1 if np.random.rand() >= 0.5: u[1, 50:250] = np.random.randn(200) + 0.5 u_lab[1] = 1 else: u[1, 50:250] = np.random.randn(200) - 0.5 u_lab[1] = -1 if np.random.rand() >= 0.5: u[2, :] = 1 label = u_lab[0] else: u[3, :] = 1 label = u_lab[1] stims = {'mode': 'none'} _, _, _, _, _, out, _ = LIF_network_fnc(model_data, scaling_factor, u, stims, down_sample, use_initial_weights) if (label == 1 and np.max(out[26000:]) > 0.7) or (label == -1 and np.min(out[26000:]) < -0.7): return 1 return 0 elif task_name == 'xor': u = np.zeros((2, 301)) u_lab = np.zeros(2) if np.random.rand() >= 0.5: u[0, 50:100] = 1 u_lab[0] = 1 else: u[0, 50:100] = -1 u_lab[0] = -1 if np.random.rand() >= 0.5: u[1, 110:160] = 1 u_lab[1] = 1 else: u[1, 110:160] = -1 u_lab[1] = -1 label = np.prod(u_lab) stims = {'mode': 'none'} _, _, _, _, _, out, _ = LIF_network_fnc(model_data, scaling_factor, u, stims, down_sample, use_initial_weights) if (label == 1 and np.max(out[20000:]) > 0.7) or (label == -1 and np.min(out[20000:]) < -0.7): return 1 return 0 except: return 0
def parse_range(range_str): parts = list(map(int, range_str.split(":"))) if len(parts) == 3: return list(range(parts[0], parts[1] + 1, parts[2])) elif len(parts) == 2: return list(range(parts[0], parts[1] + 1)) else: raise ValueError("Range must be in 'start:stop:step' or 'start:stop' format") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model_dir", type=str, required=True) parser.add_argument("--task_name", type=str, choices=["go-nogo", "xor", "mante"], required=True) parser.add_argument("--n_trials", type=int, default=100) parser.add_argument("--scaling_factors", type=str, default="20:75:5") args = parser.parse_args() scaling_factors = parse_range(args.scaling_factors) lambda_grid_search(args.model_dir, args.task_name, args.n_trials, scaling_factors) # Run the script with the following command: """ python -m spiking.lambda_grid_search \ --model_dir "./eg/models/xor/P_rec_0.2_Taus_4.0_20.0" \ --task_name xor \ --n_trials 50 \ --scaling_factors 20:76:5 """