Rate RNN Model
This module provides the main FR_RNN_dale class and the task-specific functions for creating stimuli and targets.
Main Model Class
- class rate.model.FR_RNN_dale(N: int, P_inh: float, P_rec: float, w_in: ndarray, som_N: int, w_dist: str, gain: float, apply_dale: bool, w_out: ndarray, device: device)
Bases: torch.nn.Module
Firing-rate RNN model with excitatory and inhibitory neurons. The constructor initializes the firing-rate model and its recurrent connections.
- assign_exc_inh() → Tuple[ndarray, ndarray, int, int, ndarray | int]
Method to randomly assign units as excitatory or inhibitory (Dale’s principle).
- Returns:
- Tuple containing:
inh: Boolean array marking which units are inhibitory
exc: Boolean array marking which units are excitatory
NI: Number of inhibitory units
NE: Number of excitatory units
som_inh: Indices of the inhibitory (“inh”) units assigned as SOM neurons
- Return type:
Tuple[np.ndarray, np.ndarray, int, int, Union[np.ndarray, int]]
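The assignment can be pictured with the NumPy sketch below. It is not the library's code: the function name assign_exc_inh_sketch and its seed argument are hypothetical, and it assumes that P_inh is the per-unit probability of being inhibitory and that the som_N SOM units are drawn without replacement from the inhibitory pool.

```python
import numpy as np


def assign_exc_inh_sketch(N, P_inh, som_N, seed=0):
    """Hypothetical stand-in for FR_RNN_dale.assign_exc_inh (not the library code)."""
    rng = np.random.default_rng(seed)
    # Assumed semantics: each unit is inhibitory with probability P_inh.
    inh = rng.random(N) < P_inh
    exc = ~inh
    NI = int(inh.sum())
    NE = int(exc.sum())
    if som_N > 0:
        # Assumed semantics: draw som_N distinct unit indices from the inhibitory pool.
        inh_idx = np.flatnonzero(inh)
        som_inh = rng.choice(inh_idx, size=min(som_N, NI), replace=False)
    else:
        # Assumed: a plain 0 when no SOM units are requested (cf. `ndarray | int`).
        som_inh = 0
    return inh, exc, NI, NE, som_inh


inh, exc, NI, NE, som_inh = assign_exc_inh_sketch(N=200, P_inh=0.2, som_N=10)
print(f"{NI} inhibitory, {NE} excitatory, SOM units: {sorted(som_inh.tolist())}")
```

Returning boolean masks (rather than index lists) makes it cheap to build sign masks later, since inh can index rows or columns directly.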
- display() → None
Method to print the network setup.
- forward(stim: Tensor, taus: List[float], training_params: Dict[str, Any], settings: Dict[str, Any]) → Tuple[Tensor, List[Tensor], List[Tensor], List[Tensor], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
Forward pass of the RNN.
- Parameters:
stim (torch.Tensor) – Input stimulus tensor of shape (input_dim, T).
taus (List[float]) – Time constants (either a single value or a [min, max] range).
training_params (Dict[str, Any]) – Training parameters including activation function.
settings (Dict[str, Any]) – Task settings including T (trial duration) and DeltaT (sampling rate).
- Returns:
stim: Input stimulus tensor
x: List of synaptic current tensors over time
r: List of firing rate tensors over time
o: List of output tensors over time
w: Recurrent weight matrix
w_in: Input weight matrix
mask: Dale’s principle mask matrix
som_mask: SOM connectivity mask matrix
w_out: Output weight matrix
b_out: Output bias
taus_gaus: Time constant parameters
- Return type:
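Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]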
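The core of a forward pass of this kind can be sketched as a forward-Euler integration of tau * dx/dt = -x + W r + W_in u. This is a hypothetical NumPy illustration, not FR_RNN_dale.forward: it assumes a single fixed time constant, a sigmoid activation, a precomputed signed weight matrix w_full, and a linear readout o = w_out r + b_out, and it ignores the [min, max] time-constant range and the SOM mask.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def simulate_rate_rnn_sketch(stim, w_full, w_in, w_out, b_out, tau, delta_t, seed=0):
    """Hypothetical Euler-integrated rate RNN; not FR_RNN_dale.forward itself.

    stim: (input_dim, T), w_full: (N, N) signed recurrent weights,
    w_in: (N, input_dim), w_out: (1, N), b_out: scalar bias.
    """
    rng = np.random.default_rng(seed)
    n_units, T = w_full.shape[0], stim.shape[1]
    alpha = delta_t / tau                     # step size relative to the time constant
    x = [rng.standard_normal(n_units) / 100]  # small random initial synaptic current
    r = [sigmoid(x[0])]                       # assumed sigmoid activation
    o = [w_out @ r[0] + b_out]
    for t in range(1, T):
        # tau * dx/dt = -x + W r + W_in u, discretized with a forward Euler step
        drive = w_full @ r[-1] + w_in @ stim[:, t - 1]
        x.append((1 - alpha) * x[-1] + alpha * drive)
        r.append(sigmoid(x[-1]))
        o.append(w_out @ r[-1] + b_out)
    return x, r, o


N, T = 50, 100
rng = np.random.default_rng(1)
stim = np.zeros((1, T))
stim[0, 20:40] = 1.0                          # a single input pulse
w_full = rng.standard_normal((N, N)) / np.sqrt(N)
w_in = rng.standard_normal((N, 1))
w_out = rng.standard_normal((1, N)) / np.sqrt(N)
x, r, o = simulate_rate_rnn_sketch(stim, w_full, w_in, w_out, b_out=0.0, tau=10.0, delta_t=1.0)
```

Keeping x, r, and o as per-step lists mirrors the List[Tensor] return types above; a real implementation would typically do this in torch so gradients flow through every step.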
- initialize_W() → Tuple[ndarray, ndarray, ndarray]
Generate and initialize the connectivity weight matrix W. The weights are drawn from either a Gaussian or a gamma distribution.
- Returns:
- Tuple containing:
w: NxN weights (all positive)
mask: NxN matrix of 1’s (excitatory units) and -1’s (inhibitory units)
som_mask: NxN mask for SOM connectivity constraints
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
Note
To compute the “full” weight matrix, simply multiply w and mask (i.e. w*mask)
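With NumPy arrays, * is an elementwise product. If the mask is stored as a diagonal matrix of ±1 entries, the full matrix is the matrix product w @ mask; assuming rows index postsynaptic and columns presynaptic units, this flips the sign of every column belonging to an inhibitory unit and equals an elementwise product with a column-wise sign pattern. The hypothetical sketch below illustrates both forms; the connection probability P_rec, the gamma parameters, and the gain scaling are assumptions, and it does not build a SOM mask.

```python
import numpy as np


def initialize_w_sketch(N, P_rec, inh, w_dist="gaus", gain=1.5, seed=0):
    """Hypothetical stand-in for FR_RNN_dale.initialize_W (no SOM mask)."""
    rng = np.random.default_rng(seed)
    w = np.zeros((N, N))
    connected = rng.random((N, N)) < P_rec            # assumed: P_rec = connection probability
    n_conn = int(connected.sum())
    if w_dist == "gamma":
        w[connected] = rng.gamma(shape=2.0, scale=0.003, size=n_conn)  # assumed parameters
    elif w_dist == "gaus":
        w[connected] = rng.normal(0.0, 1.0, size=n_conn)
        w = w / np.sqrt(N * P_rec) * gain             # assumed variance scaling by gain
    else:
        raise ValueError(f"unknown w_dist: {w_dist!r}")
    w = np.abs(w)                                     # magnitudes only; signs live in the mask
    signs = np.where(inh, -1.0, 1.0)                  # -1 for inhibitory presynaptic units
    mask = np.diag(signs)                             # diagonal +/-1 form of the sign mask
    return w, mask, signs


inh = np.zeros(100, dtype=bool)
inh[:20] = True                                       # first 20 units inhibitory
w, mask, signs = initialize_w_sketch(100, 0.2, inh)
w_full_matmul = w @ mask                              # scales column j by the sign of unit j
w_full_broadcast = w * signs[np.newaxis, :]           # same matrix via broadcasting
```

Storing non-negative magnitudes separately from the signs is what lets Dale's principle survive training: as long as w stays non-negative, each unit's outgoing weights keep a single sign.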
- load_net(model_dir: str) → FR_RNN_dale
Method to load pre-configured network settings.
- Parameters:
model_dir (str) – Path to the model directory containing saved parameters.
- Returns:
The loaded network instance.
- Return type:
FR_RNN_dale
Task Functions
XOR Task
- rate.model.generate_input_stim_xor(settings: Dict[str, Any]) → Tuple[ndarray, str]
Generate the input stimulus matrix for the XOR task.
- Parameters:
settings (Dict[str, Any]) – Dictionary containing the following keys:
  - T: Duration of a single trial (in steps)
  - stim_on: Stimulus starting time (in steps)
  - stim_dur: Stimulus duration (in steps)
  - delay: Delay between two stimuli (in steps)
  - taus: Time constants (in steps)
  - DeltaT: Sampling rate
- Returns:
- Tuple containing:
u: 2xT stimulus matrix
label: Either ‘same’ or ‘diff’
- Return type:
Tuple[np.ndarray, str]
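A minimal sketch of an XOR-style stimulus, assuming each of the two input rows carries one pulse of randomly chosen sign (+1 or -1), the second starting delay steps after the first ends, and the label being 'same' when the signs match. The function name xor_stim_sketch and the pulse encoding are hypothetical, not the library's implementation.

```python
import numpy as np


def xor_stim_sketch(settings, rng=None):
    """Hypothetical XOR stimulus: two sequential +/-1 pulses on separate channels."""
    rng = np.random.default_rng() if rng is None else rng
    T, on, dur, delay = (settings[k] for k in ("T", "stim_on", "stim_dur", "delay"))
    u = np.zeros((2, T))
    s1, s2 = rng.choice([-1.0, 1.0], size=2)   # random sign for each pulse
    u[0, on:on + dur] = s1                     # first pulse on channel 0
    second = on + dur + delay
    u[1, second:second + dur] = s2             # second pulse on channel 1 after the delay
    label = "same" if s1 == s2 else "diff"
    return u, label


settings = {"T": 300, "stim_on": 50, "stim_dur": 50, "delay": 10}
u, label = xor_stim_sketch(settings, np.random.default_rng(0))
```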
- rate.model.generate_target_continuous_xor(settings: Dict[str, Any], label: str) → ndarray
Generate the target output signal for the XOR task.
- Parameters:
settings (Dict[str, Any]) – Dictionary containing the following key:
  - T: Duration of a single trial (in steps)
label (str) – Either ‘same’ or ‘diff’.
- Returns:
1xT target signal array.
- Return type:
np.ndarray
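Since only T is listed as a required key, a target of this kind can be sketched as a 1xT array that is zero except in a response window near the end of the trial. In the hypothetical sketch below, the window (the last fifth of the trial) and the sign convention (+1 for 'same', -1 for 'diff') are assumptions, not the library's definition.

```python
import numpy as np


def xor_target_sketch(settings, label):
    """Hypothetical 1xT XOR target; the response window and sign convention are assumed."""
    if label not in ("same", "diff"):
        raise ValueError("label must be 'same' or 'diff'")
    T = settings["T"]
    z = np.zeros((1, T))
    resp_start = T - T // 5                  # assumed: respond during the last fifth of the trial
    z[0, resp_start:] = 1.0 if label == "same" else -1.0
    return z


z_same = xor_target_sketch({"T": 300}, "same")
z_diff = xor_target_sketch({"T": 300}, "diff")
```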
Mante Task
- rate.model.generate_input_stim_mante(settings: Dict[str, Any]) → Tuple[ndarray, int]
Generate the input stimulus matrix for the sensory integration task from Mante et al. (2013).
- Parameters:
settings (Dict[str, Any]) – Dictionary containing the following keys:
  - T: Duration of a single trial (in steps)
  - stim_on: Stimulus starting time (in steps)
  - stim_dur: Stimulus duration (in steps)
  - DeltaT: Sampling rate
- Returns:
- Tuple containing:
u: 4xT stimulus matrix
label: Either +1 or -1 (the correct decision)
- Return type:
Tuple[np.ndarray, int]
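In the context-dependent integration task, the network sees two noisy sensory streams and a cue saying which one to report. A hypothetical 4xT sketch follows; the row layout (rows 0-1 noisy evidence, rows 2-3 a one-hot context cue), the evidence magnitude of 0.5, the unit noise level, and the cue being on only during the stimulus period are all assumptions rather than the library's encoding.

```python
import numpy as np


def mante_stim_sketch(settings, rng=None, noise_sd=1.0):
    """Hypothetical 4xT input: rows 0-1 noisy sensory evidence, rows 2-3 context cue."""
    rng = np.random.default_rng() if rng is None else rng
    T, on, dur = settings["T"], settings["stim_on"], settings["stim_dur"]
    u = np.zeros((4, T))
    # Signed mean evidence for each sensory modality (assumed magnitude 0.5).
    means = rng.choice([-0.5, 0.5], size=2)
    for ch in range(2):
        u[ch, on:on + dur] = means[ch] + noise_sd * rng.standard_normal(dur)
    context = int(rng.integers(2))             # which modality is relevant on this trial
    u[2 + context, on:on + dur] = 1.0          # one-hot context cue during the stimulus (assumed)
    label = 1 if means[context] > 0 else -1    # correct decision follows the cued modality
    return u, label


u, label = mante_stim_sketch({"T": 500, "stim_on": 50, "stim_dur": 200}, np.random.default_rng(0))
```

Because the label depends on the cued channel's mean rather than its noisy samples, the network has to integrate over the stimulus period to answer reliably.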
- rate.model.generate_target_continuous_mante(settings: Dict[str, Any], label: int) → ndarray
Generate the target output signal for the sensory integration task from Mante et al. (2013).
- Parameters:
settings (Dict[str, Any]) – Dictionary containing the following key:
  - T: Duration of a single trial (in steps)
label (int) – Either +1 or -1 (the correct decision).
- Returns:
1xT target signal array.
- Return type:
np.ndarray
Go-NoGo Task
- rate.model.generate_input_stim_go_nogo(settings: Dict[str, Any]) → Tuple[ndarray, int]
Generate the input stimulus matrix for the Go-NoGo task.
- Parameters:
settings (Dict[str, Any]) – Dictionary containing the following keys:
  - T: Duration of a single trial (in steps)
  - stim_on: Stimulus starting time (in steps)
  - stim_dur: Stimulus duration (in steps)
  - taus: Time constants (in steps)
  - DeltaT: Sampling rate
- Returns:
- Tuple containing:
u: 1xT stimulus matrix
label: Either 1 (Go trial) or 0 (NoGo trial)
- Return type:
Tuple[np.ndarray, int]
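A minimal sketch of a Go-NoGo input, assuming a Go trial is a unit pulse from stim_on for stim_dur steps and a NoGo trial is silent. The function name go_nogo_stim_sketch and the p_go parameter (default 0.5) are hypothetical.

```python
import numpy as np


def go_nogo_stim_sketch(settings, rng=None, p_go=0.5):
    """Hypothetical 1xT Go-NoGo input: a unit pulse on Go trials, silence on NoGo trials."""
    rng = np.random.default_rng() if rng is None else rng
    T, on, dur = settings["T"], settings["stim_on"], settings["stim_dur"]
    u = np.zeros((1, T))
    label = 1 if rng.random() < p_go else 0   # assumed Go probability p_go
    if label == 1:
        u[0, on:on + dur] = 1.0               # stimulus pulse only on Go trials
    return u, label


settings = {"T": 200, "stim_on": 50, "stim_dur": 25}
u, label = go_nogo_stim_sketch(settings, np.random.default_rng(0))
```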
- rate.model.generate_target_continuous_go_nogo(settings: Dict[str, Any], label: int) → ndarray
Generate the target output signal for the Go-NoGo task.
- Parameters:
settings (Dict[str, Any]) – Dictionary containing the following keys:
  - T: Duration of a single trial (in steps)
  - stim_on: Stimulus starting time (in steps)
  - stim_dur: Stimulus duration (in steps)
label (int) – Either 1 (Go trial) or 0 (NoGo trial).
- Returns:
1xT target signal array.
- Return type:
np.ndarray
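The keys above suggest a target that is tied to stimulus offset. A hypothetical sketch, assuming the desired output stays at 0 until stim_on + stim_dur and then holds at 1 to the end of the trial on Go trials, and stays at 0 throughout on NoGo trials:

```python
import numpy as np


def go_nogo_target_sketch(settings, label):
    """Hypothetical Go-NoGo target: respond (1) after stimulus offset on Go trials only."""
    T, on, dur = settings["T"], settings["stim_on"], settings["stim_dur"]
    z = np.zeros((1, T))
    if label == 1:
        z[0, on + dur:] = 1.0   # assumed: sustained response from stimulus offset to trial end
    return z


settings = {"T": 200, "stim_on": 50, "stim_dur": 25}
z_go = go_nogo_target_sketch(settings, 1)
z_nogo = go_nogo_target_sketch(settings, 0)
```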
Loss Function
- rate.model.loss_op(o: List[Tensor], z: ndarray, training_params: Dict[str, Any]) → Tensor
Compute the training loss between the network outputs and the target values.
- Parameters:
o (List[torch.Tensor]) – List of output values from the network.
z (np.ndarray) – Target values.
training_params (Dict[str, Any]) – Dictionary containing training parameters, including the ‘loss_fn’ key.
- Returns:
Loss function value.
- Return type:
torch.Tensor
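A NumPy analogue can make the shapes concrete. In this hypothetical sketch, the per-step outputs are flattened into one vector and compared with the flattened target; the supported ‘loss_fn’ values ("l1" as the sum of absolute errors, "l2" as the Euclidean norm of the error) are assumptions, and unlike loss_op it returns a Python float rather than a torch.Tensor.

```python
import numpy as np


def loss_sketch(o, z, training_params):
    """Hypothetical NumPy analogue of loss_op; the library version returns a torch.Tensor."""
    out = np.concatenate([np.ravel(oi) for oi in o])   # one scalar output per time step
    target = np.ravel(z)
    if out.shape != target.shape:
        raise ValueError(f"{out.size} outputs vs {target.size} targets")
    err = out - target
    loss_fn = training_params["loss_fn"].lower()
    if loss_fn == "l1":
        return float(np.sum(np.abs(err)))             # sum of absolute errors
    if loss_fn == "l2":
        return float(np.sqrt(np.sum(err ** 2)))       # Euclidean norm of the error
    raise ValueError(f"unsupported loss_fn: {loss_fn!r}")


o = [np.array([0.0]), np.array([3.0]), np.array([4.0])]
z = np.zeros(3)
print(loss_sketch(o, z, {"loss_fn": "l2"}))   # 5.0
```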