Source code for rate.model

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

# PyTorch adaptation of the continuous rate-based RNN 
# The original model is from the following paper:
# Kim, R., Li, Y., & Sejnowski, T. J. (2019). Simple framework for constructing
# functional spiking recurrent neural networks. Proceedings of the National
# Academy of Sciences, 116(45), 22811-22820.
# Original TensorFlow repository: https://github.com/rkim35/spikeRNN

import os, sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io
from typing import Dict, Any, Tuple, List, Union

'''
CONTINUOUS FIRING-RATE RNN CLASS
'''

class FR_RNN_dale(nn.Module):
    """
    Firing-rate RNN model with excitatory and inhibitory neurons.
    Initializes the firing-rate model with recurrent connections.
    """
    def __init__(self, N: int, P_inh: float, P_rec: float, w_in: np.ndarray,
                 som_N: int, w_dist: str, gain: float, apply_dale: bool,
                 w_out: np.ndarray, device: torch.device) -> None:
        """
        Network initialization method.

        Args:
            N (int): Number of units (neurons).
            P_inh (float): Probability of a neuron being inhibitory.
            P_rec (float): Recurrent connection probability.
            w_in (np.ndarray): N x (input dim) weight matrix for the input stimuli.
            som_N (int): Number of SOM neurons (set to 0 for no SOM neurons).
            w_dist (str): Recurrent weight distribution ('gaus' or 'gamma').
            gain (float): Gain applied to the recurrent weights.
            apply_dale (bool): Whether to apply Dale's principle.
            w_out (np.ndarray): 1 x N readout weights.
            device (torch.device): PyTorch device.

        Note:
            Based on the probability (P_inh) provided above, the units in the
            network are classified as either excitatory or inhibitory. Next,
            the weight matrix is initialized based on the connectivity
            probability (P_rec) provided above.
        """
        super(FR_RNN_dale, self).__init__()
        self.N = N
        self.P_inh = P_inh
        self.P_rec = P_rec
        self.w_in = torch.tensor(w_in, dtype=torch.float32, device=device)
        self.som_N = som_N
        self.w_dist = w_dist
        self.gain = gain
        self.apply_dale = apply_dale
        self.device = device

        # Assign each unit as excitatory or inhibitory
        inh, exc, NI, NE, som_inh = self.assign_exc_inh()
        self.inh = inh
        self.som_inh = som_inh
        self.exc = exc
        self.NI = NI
        self.NE = NE

        # Initialize the weight matrix
        W, mask, som_mask = self.initialize_W()

        # Create learnable parameters
        self.w = nn.Parameter(torch.tensor(W, dtype=torch.float32, device=device))
        self.mask = torch.tensor(mask, dtype=torch.float32, device=device)
        self.som_mask = torch.tensor(som_mask, dtype=torch.float32, device=device)
        self.w_out = nn.Parameter(torch.tensor(w_out, dtype=torch.float32, device=device))
        self.b_out = nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
    def assign_exc_inh(self) -> Tuple[np.ndarray, np.ndarray, int, int, Union[np.ndarray, int]]:
        """
        Randomly assign units as excitatory or inhibitory (Dale's principle).

        Returns:
            Tuple[np.ndarray, np.ndarray, int, int, Union[np.ndarray, int]]: Tuple containing:
                - inh: Boolean array marking which units are inhibitory
                - exc: Boolean array marking which units are excitatory
                - NI: Number of inhibitory units
                - NE: Number of excitatory units
                - som_inh: Indices of "inh" for SOM neurons
        """
        if self.apply_dale:
            # Apply Dale's principle
            inh = np.random.rand(self.N, 1) < self.P_inh
        else:
            # Do NOT apply Dale's principle: no separate inhibitory units
            inh = np.random.rand(self.N, 1) < 0
        exc = ~inh
        NI = len(np.where(inh == True)[0])
        NE = self.N - NI

        if self.som_N > 0:
            som_inh = np.where(inh == True)[0][:self.som_N]
        else:
            som_inh = 0

        return inh, exc, NI, NE, som_inh
    def initialize_W(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Generate and initialize the connectivity weight matrix, W.
        The weights are drawn from either a gaussian or a gamma distribution.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple containing:
                - w: NxN weights (all positive when Dale's principle is applied)
                - mask: NxN diagonal matrix with +1 for excitatory units and -1 for inhibitory units
                - som_mask: NxN mask for SOM connectivity constraints

        Note:
            To compute the "full" weight matrix, matrix-multiply w and mask
            (i.e. np.matmul(w, mask)), which flips the sign of the inhibitory columns.
        """
        # Weight matrix
        w = np.zeros((self.N, self.N), dtype=np.float32)
        idx = np.where(np.random.rand(self.N, self.N) < self.P_rec)
        if self.w_dist.lower() == 'gamma':
            w[idx[0], idx[1]] = np.random.gamma(2, 0.003, len(idx[0]))
        elif self.w_dist.lower() == 'gaus':
            w[idx[0], idx[1]] = np.random.normal(0, 1.0, len(idx[0]))
            w = w / np.sqrt(self.N * self.P_rec) * self.gain  # scale by a gain to make it chaotic

        if self.apply_dale:
            w = np.abs(w)

        # Mask matrix: +1 on the diagonal, with -1 on inhibitory diagonal entries
        mask = np.eye(self.N, dtype=np.float32)
        mask[np.where(self.inh == True)[0], np.where(self.inh == True)[0]] = -1

        # SOM mask matrix: SOM neurons do not project onto other inhibitory neurons
        som_mask = np.ones((self.N, self.N), dtype=np.float32)
        if self.som_N > 0:
            for i in self.som_inh:
                som_mask[i, np.where(self.inh == True)[0]] = 0

        return w, mask, som_mask
    def load_net(self, model_dir: str) -> 'FR_RNN_dale':
        """
        Load pre-configured network settings.

        Args:
            model_dir (str): Path to the .mat file containing the saved parameters.

        Returns:
            FR_RNN_dale: The loaded network instance.
        """
        settings = scipy.io.loadmat(model_dir)
        self.N = settings['N'][0][0]
        self.som_N = settings['som_N'][0][0]
        self.inh = settings['inh'] == 1
        self.exc = settings['exc'] == 1
        self.NI = len(np.where(settings['inh'] == True)[0])
        self.NE = len(np.where(settings['exc'] == True)[0])

        # Update parameters
        self.mask = torch.tensor(settings['m'], dtype=torch.float32, device=self.device)
        self.som_mask = torch.tensor(settings['som_m'], dtype=torch.float32, device=self.device)
        self.w.data = torch.tensor(settings['w'], dtype=torch.float32, device=self.device)
        self.w_in = torch.tensor(settings['w_in'], dtype=torch.float32, device=self.device)
        self.b_out.data = torch.tensor(settings['b_out'], dtype=torch.float32, device=self.device)
        self.w_out.data = torch.tensor(settings['w_out'], dtype=torch.float32, device=self.device)

        return self
    def display(self) -> None:
        """
        Print the network setup.
        """
        print('Network Settings')
        print('====================================')
        print('Number of Units: ', self.N)
        print('\t Number of Excitatory Units: ', self.NE)
        print('\t Number of Inhibitory Units: ', self.NI)
        print('Weight Matrix, W')
        # Full weight matrix (column signs set by the Dale mask, as in forward);
        # detach() is required before numpy() since self.w is a learnable parameter
        full_w = torch.matmul(self.w, self.mask).detach().cpu().numpy()
        zero_w = len(np.where(full_w == 0)[0])
        pos_w = len(np.where(full_w > 0)[0])
        neg_w = len(np.where(full_w < 0)[0])
        print('\t Zero Weights: %2.2f %%' % (zero_w / (self.N * self.N) * 100))
        print('\t Positive Weights: %2.2f %%' % (pos_w / (self.N * self.N) * 100))
        print('\t Negative Weights: %2.2f %%' % (neg_w / (self.N * self.N) * 100))
    def forward(self, stim: torch.Tensor, taus: List[float],
                training_params: Dict[str, Any], settings: Dict[str, Any]
                ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor],
                           List[torch.Tensor], torch.Tensor, torch.Tensor,
                           torch.Tensor, torch.Tensor, torch.Tensor,
                           torch.Tensor, torch.Tensor]:
        """
        Forward pass of the RNN.

        Args:
            stim (torch.Tensor): Input stimulus tensor of shape (input_dim, T).
            taus (List[float]): Time constants (either a single value or a [min, max] range).
            training_params (Dict[str, Any]): Training parameters, including the activation function.
            settings (Dict[str, Any]): Task settings, including T (trial duration) and DeltaT (sampling rate).

        Returns:
            Tuple containing:
                - stim: Input stimulus tensor
                - x: List of synaptic current tensors over time
                - r: List of firing-rate tensors over time
                - o: List of output tensors over time
                - w: Recurrent weight matrix
                - w_in: Input weight matrix
                - mask: Dale's principle mask matrix
                - som_mask: SOM connectivity mask matrix
                - w_out: Output weight matrix
                - b_out: Output bias
                - taus_gaus: Time-constant parameters
        """
        T = settings['T']
        DeltaT = settings['DeltaT']

        # Initialize taus_gaus for the (possibly trainable) time constants
        if len(taus) > 1:
            taus_gaus = torch.randn(self.N, 1, device=self.device, requires_grad=True)
        else:
            taus_gaus = torch.randn(self.N, 1, device=self.device, requires_grad=False)

        # Synaptic currents and firing rates
        x = []
        r = []
        x.append(torch.randn(self.N, 1, device=self.device) / 100)

        # Initial firing rate with the chosen activation function
        if training_params['activation'] == 'sigmoid':
            r.append(torch.sigmoid(x[0]))
        elif training_params['activation'] == 'clipped_relu':
            r.append(torch.clamp(F.relu(x[0]), 0, 20))
        elif training_params['activation'] == 'softplus':
            r.append(torch.clamp(F.softplus(x[0]), 0, 20))
        else:
            raise ValueError("Unknown activation: %s" % training_params['activation'])

        # Output list
        o = []

        # Forward pass through time
        for t in range(1, T):
            if self.apply_dale:
                # Parametrize the weight matrix to enforce exc/inh synaptic currents
                w_pos = F.relu(self.w)
            else:
                w_pos = self.w

            # Compute the effective weight matrix
            ww = torch.matmul(w_pos, self.mask)
            ww = ww * self.som_mask

            # Compute the time constants
            if len(taus) > 1:
                taus_sig = torch.sigmoid(taus_gaus) * (taus[1] - taus[0]) + taus[0]
            else:
                taus_sig = taus[0]

            # Update the synaptic currents
            next_x = (1 - DeltaT / taus_sig) * x[t-1] \
                + (DeltaT / taus_sig) * (torch.matmul(ww, r[t-1])
                                         + torch.matmul(self.w_in, stim[:, t-1:t])) \
                + torch.randn(self.N, 1, device=self.device) / 10
            x.append(next_x)

            # Apply the activation function
            if training_params['activation'] == 'sigmoid':
                r.append(torch.sigmoid(next_x))
            elif training_params['activation'] == 'clipped_relu':
                r.append(torch.clamp(F.relu(next_x), 0, 20))
            elif training_params['activation'] == 'softplus':
                r.append(torch.clamp(F.softplus(next_x), 0, 20))

            # Compute the output
            next_o = torch.matmul(self.w_out, r[t]) + self.b_out
            o.append(next_o)

        return stim, x, r, o, self.w, self.w_in, self.mask, self.som_mask, \
            self.w_out, self.b_out, taus_gaus
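# The following usage sketch is not part of the original module; it shows how
# the class above can be instantiated and run for one trial. All sizes and
# settings values here are illustrative assumptions, not defaults from the
# original spikeRNN configuration.
if __name__ == "__main__":
    device = torch.device('cpu')
    N = 200                                           # number of units (assumed)
    w_in = np.float32(np.random.randn(N, 1))          # N x 1 input weights (one input channel)
    w_out = np.float32(np.random.randn(1, N) / 100)   # 1 x N readout weights
    net = FR_RNN_dale(N, P_inh=0.20, P_rec=0.20, w_in=w_in, som_N=0,
                      w_dist='gaus', gain=1.5, apply_dale=True,
                      w_out=w_out, device=device)
    net.display()

    # One forward pass on a square-pulse stimulus (pulse timing assumed)
    settings = {'T': 200, 'DeltaT': 1}
    training_params = {'activation': 'sigmoid'}
    stim = torch.zeros(1, settings['T'], device=device)
    stim[0, 50:100] = 1.0
    outs = net.forward(stim, taus=[10], training_params=training_params, settings=settings)
    o = outs[3]                                       # per-timestep network outputs
    print('Output at the last step:', o[-1].item())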
'''
TASK-SPECIFIC INPUT SIGNALS
'''
def generate_input_stim_go_nogo(settings: Dict[str, Any]) -> Tuple[np.ndarray, int]:
    """
    Generate the input stimulus matrix for the Go-NoGo task.

    Args:
        settings (Dict[str, Any]): Dictionary containing the following keys:
            - T: Duration of a single trial (in steps)
            - stim_on: Stimulus onset time (in steps)
            - stim_dur: Stimulus duration (in steps)

    Returns:
        Tuple[np.ndarray, int]: Tuple containing:
            - u: 1xT stimulus matrix
            - label: Either 1 (Go trial) or 0 (NoGo trial)
    """
    T = settings['T']
    stim_on = settings['stim_on']
    stim_dur = settings['stim_dur']

    # A stimulus pulse is presented on half of the trials (Go trials)
    u = np.zeros((1, T))  #+ np.random.randn(1, T)
    if np.random.rand() <= 0.50:
        u[0, stim_on:stim_on+stim_dur] = 1
        label = 1
    else:
        label = 0

    return u, label
def generate_input_stim_xor(settings: Dict[str, Any]) -> Tuple[np.ndarray, str]:
    """
    Generate the input stimulus matrix for the XOR task.

    Args:
        settings (Dict[str, Any]): Dictionary containing the following keys:
            - T: Duration of a single trial (in steps)
            - stim_on: Stimulus onset time (in steps)
            - stim_dur: Stimulus duration (in steps)
            - delay: Delay between the two stimuli (in steps)

    Returns:
        Tuple[np.ndarray, str]: Tuple containing:
            - u: 2xT stimulus matrix
            - label: Either 'same' or 'diff'
    """
    T = settings['T']
    stim_on = settings['stim_on']
    stim_dur = settings['stim_dur']
    delay = settings['delay']

    # Initialize u
    u = np.zeros((2, T))

    # First stimulus: +1 or -1 with equal probability
    labs = []
    if np.random.rand() < 0.50:
        u[0, stim_on:stim_on+stim_dur] = 1
        labs.append(1)
    else:
        u[0, stim_on:stim_on+stim_dur] = -1
        labs.append(-1)

    # Second stimulus, presented after the delay
    if np.random.rand() < 0.50:
        u[1, stim_on+stim_dur+delay:stim_on+2*stim_dur+delay] = 1
        labs.append(1)
    else:
        u[1, stim_on+stim_dur+delay:stim_on+2*stim_dur+delay] = -1
        labs.append(-1)

    # 'same' if both stimuli share a sign, otherwise 'diff'
    if np.prod(labs) == 1:
        label = 'same'
    else:
        label = 'diff'

    return u, label
def generate_input_stim_mante(settings: Dict[str, Any]) -> Tuple[np.ndarray, int]:
    """
    Generate the input stimulus matrix for the sensory integration task
    from Mante et al. (2013).

    Args:
        settings (Dict[str, Any]): Dictionary containing the following keys:
            - T: Duration of a single trial (in steps)
            - stim_on: Stimulus onset time (in steps)
            - stim_dur: Stimulus duration (in steps)

    Returns:
        Tuple[np.ndarray, int]: Tuple containing:
            - u: 4xT stimulus matrix
            - label: Either +1 or -1 (the correct decision)
    """
    T = settings['T']
    stim_on = settings['stim_on']
    stim_dur = settings['stim_dur']

    # Initialize stimulus
    u = np.zeros((4, T))

    # Sensory inputs, drawn uniformly from [-1.25, 1.25]
    color_input = 2.5 * (np.random.rand() - 0.5)
    motion_input = 2.5 * (np.random.rand() - 0.5)

    # Context signal
    if np.random.rand() < 0.50:
        # Context = color task
        u[0, stim_on:stim_on+stim_dur] = 1             # context cue
        u[1, stim_on:stim_on+stim_dur] = color_input   # color input
        u[2, stim_on:stim_on+stim_dur] = motion_input  # motion input (irrelevant in this context)
        label = 1 if color_input > 0 else -1           # choose option 1 or option 2
    else:
        # Context = motion task
        u[0, stim_on:stim_on+stim_dur] = -1            # context cue
        u[1, stim_on:stim_on+stim_dur] = color_input   # color input (irrelevant in this context)
        u[2, stim_on:stim_on+stim_dur] = motion_input  # motion input
        label = 1 if motion_input > 0 else -1          # choose option 1 or option 2

    return u, label
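# Quick illustrative check of the XOR stimulus generator above; the settings
# values are assumptions chosen only so that both pulses fit within the trial.
if __name__ == "__main__":
    xor_settings = {'T': 300, 'stim_on': 50, 'stim_dur': 50, 'delay': 10}
    u_xor, xor_label = generate_input_stim_xor(xor_settings)
    print(u_xor.shape, xor_label)   # (2, 300) and 'same' or 'diff'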
def generate_target_continuous_go_nogo(settings: Dict[str, Any], label: int) -> np.ndarray:
    """
    Generate the target output signal for the Go-NoGo task.

    Args:
        settings (Dict[str, Any]): Dictionary containing the following keys:
            - T: Duration of a single trial (in steps)
            - stim_on: Stimulus onset time (in steps)
            - stim_dur: Stimulus duration (in steps)
        label (int): Either 1 (Go trial) or 0 (NoGo trial).

    Returns:
        np.ndarray: Target signal array of shape (T-1,).
    """
    T = settings['T']
    stim_on = settings['stim_on']
    stim_dur = settings['stim_dur']

    # The target is 1 after stimulus offset on Go trials, and 0 everywhere else
    target = np.zeros((T-1,))
    resp_onset = stim_on + stim_dur
    if label == 1:
        target[resp_onset:] = 1

    return target
def generate_target_continuous_xor(settings: Dict[str, Any], label: str) -> np.ndarray:
    """
    Generate the target output signal for the XOR task.

    Args:
        settings (Dict[str, Any]): Dictionary containing the following keys:
            - T: Duration of a single trial (in steps)
        label (str): Either 'same' or 'diff'.

    Returns:
        np.ndarray: Target signal array of shape (T-1,).
    """
    T = settings['T']

    target = np.zeros((T-1,))
    if label == 'same':
        target[200:] = 1
    elif label == 'diff':
        target[200:] = -1

    return target
def generate_target_continuous_mante(settings: Dict[str, Any], label: int) -> np.ndarray:
    """
    Generate the target output signal for the sensory integration task
    from Mante et al. (2013).

    Args:
        settings (Dict[str, Any]): Dictionary containing the following keys:
            - T: Duration of a single trial (in steps)
        label (int): Either +1 or -1 (the correct decision).

    Returns:
        np.ndarray: Target signal array of shape (T-1,).
    """
    T = settings['T']

    # The target takes the value of the correct decision over the last 200 steps
    target = np.zeros((T-1,))
    target[-200:] = label

    return target
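# Illustrative pairing of a stimulus with its target for the Go-NoGo task;
# the settings values below are assumptions.
if __name__ == "__main__":
    gng_settings = {'T': 200, 'stim_on': 30, 'stim_dur': 20}
    u_gng, gng_label = generate_input_stim_go_nogo(gng_settings)
    z_gng = generate_target_continuous_go_nogo(gng_settings, gng_label)
    # The target has T-1 entries because the forward pass emits T-1 outputs
    print(u_gng.shape, z_gng.shape, gng_label)   # (1, 200) (199,) 1 or 0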
def loss_op(o: List[torch.Tensor], z: np.ndarray, training_params: Dict[str, Any]) -> torch.Tensor:
    """
    Define the loss function for training.

    Args:
        o (List[torch.Tensor]): List of output values from the network.
        z (np.ndarray): Target values.
        training_params (Dict[str, Any]): Dictionary containing training
            parameters, including the 'loss_fn' key ('l1' or 'l2').

    Returns:
        torch.Tensor: Loss value.
    """
    loss_fn = training_params['loss_fn']
    # Create the accumulator and targets on the same device as the outputs
    loss = torch.tensor(0.0, device=o[0].device)
    z_tensor = torch.tensor(z, dtype=torch.float32, device=o[0].device)

    for i in range(len(o)):
        if loss_fn.lower() == 'l1':
            loss = loss + torch.norm(o[i].squeeze() - z_tensor[i], p=1)
        elif loss_fn.lower() == 'l2':
            loss = loss + (o[i].squeeze() - z_tensor[i])**2

    if loss_fn.lower() == 'l2':
        loss = torch.sqrt(loss)

    return loss
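# Minimal sanity check of loss_op on synthetic outputs and targets; the
# tensors below are placeholders, not real network outputs.
if __name__ == "__main__":
    o_fake = [torch.tensor([[0.5]]) for _ in range(199)]   # fake per-step outputs
    z_fake = np.ones(199)                                  # fake target signal
    l2_loss = loss_op(o_fake, z_fake, {'loss_fn': 'l2'})
    print('L2 loss on constant outputs:', l2_loss.item())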
def eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: np.ndarray,
             device: torch.device) -> Tuple[List[float], List[np.ndarray], List[np.ndarray]]:
    """
    Evaluate a trained PyTorch RNN.

    Args:
        net (FR_RNN_dale): Trained FR_RNN_dale model.
        settings (Dict[str, Any]): Dictionary containing task settings.
        u (np.ndarray): Stimulus matrix.
        device (torch.device): PyTorch device.

    Returns:
        Tuple[List[float], List[np.ndarray], List[np.ndarray]]: Tuple containing:
            - o: Output vector (list of floats)
            - r: Firing rates (list of numpy arrays)
            - x: Synaptic currents (list of numpy arrays)
    """
    T = settings['T']
    DeltaT = settings['DeltaT']
    taus = settings['taus']

    net.eval()
    with torch.no_grad():
        # Convert the input to a tensor
        u_tensor = torch.tensor(u, dtype=torch.float32, device=device)

        # Initialize synaptic currents and firing rates
        x = []
        r = []
        x.append(torch.randn(net.N, 1, device=device) / 100)
        r.append(torch.sigmoid(x[0]))  # sigmoid activation is assumed here

        o = []
        for t in range(1, T):
            if net.apply_dale:
                w_pos = F.relu(net.w)
            else:
                w_pos = net.w

            ww = torch.matmul(w_pos, net.mask)
            ww = ww * net.som_mask

            # The first tau is used for evaluation
            taus_sig = taus[0]

            next_x = (1 - DeltaT / taus_sig) * x[t-1] \
                + (DeltaT / taus_sig) * (torch.matmul(ww, r[t-1])
                                         + torch.matmul(net.w_in, u_tensor[:, t-1:t])) \
                + torch.randn(net.N, 1, device=device) / 10
            x.append(next_x)
            r.append(torch.sigmoid(next_x))  # sigmoid activation is assumed here

            next_o = torch.matmul(net.w_out, r[t]) + net.b_out
            o.append(next_o.item())

    return o, [r_t.cpu().numpy() for r_t in r], [x_t.cpu().numpy() for x_t in x]
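# End-to-end illustrative run of eval_rnn on an untrained network; every value
# here (network size, gain, task settings) is an assumption for demonstration
# purposes only, not a configuration from the original spikeRNN code.
if __name__ == "__main__":
    device = torch.device('cpu')
    N = 100
    w_in = np.float32(np.random.randn(N, 1))
    w_out = np.float32(np.random.randn(1, N) / 100)
    net = FR_RNN_dale(N, P_inh=0.20, P_rec=0.20, w_in=w_in, som_N=0,
                      w_dist='gaus', gain=1.5, apply_dale=True,
                      w_out=w_out, device=device)
    eval_settings = {'T': 200, 'stim_on': 30, 'stim_dur': 20,
                     'DeltaT': 1, 'taus': [10]}
    u_eval, eval_label = generate_input_stim_go_nogo(eval_settings)
    o_eval, r_eval, x_eval = eval_rnn(net, eval_settings, u_eval, device)
    print(len(o_eval), r_eval[0].shape)   # 199 and (100, 1)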