Rate RNN Model

The main FR_RNN_dale class and the core functions for computing the training loss and evaluating trained networks.

Main Model Class

class rate.model.FR_RNN_dale(N: int, P_inh: float, P_rec: float, w_in: ndarray, som_N: int, w_dist: str, gain: float, apply_dale: bool, w_out: ndarray, device: device)

Bases: Module

Firing-rate RNN model with separate excitatory and inhibitory neurons. The constructor initializes the firing-rate model and its recurrent connections.

assign_exc_inh() → Tuple[ndarray, ndarray, int, int, ndarray | int]

Method to randomly assign units as excitatory or inhibitory (Dale’s principle).

Returns:

Tuple containing:
  • inh: Boolean array marking which units are inhibitory

  • exc: Boolean array marking which units are excitatory

  • NI: Number of inhibitory units

  • NE: Number of excitatory units

  • som_inh: Indices of the inhibitory (“inh”) units that act as SOM (somatostatin-expressing) neurons

Return type:

Tuple[np.ndarray, np.ndarray, int, int, Union[np.ndarray, int]]
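A minimal NumPy sketch of the assignment step, assuming each unit is labeled inhibitory independently with probability P_inh and that the SOM neurons are the first som_N inhibitory units. The standalone function and its arguments are hypothetical; the documented method takes no arguments and reads these values from the model instance.

```python
import numpy as np

def assign_exc_inh(N, P_inh, som_N=0, rng=None):
    """Hypothetical sketch: label each unit as inhibitory (True) or excitatory."""
    rng = np.random.default_rng() if rng is None else rng
    inh = rng.random(N) < P_inh          # each unit is inhibitory with probability P_inh
    exc = ~inh                           # Dale's principle: every unit has exactly one type
    NI = int(inh.sum())
    NE = N - NI
    if som_N > 0:
        # Assumption: the first som_N inhibitory units are designated as SOM neurons
        som_inh = np.where(inh)[0][:som_N]
    else:
        som_inh = 0                      # mirrors the `ndarray | int` return annotation
    return inh, exc, NI, NE, som_inh

inh, exc, NI, NE, som_inh = assign_exc_inh(200, 0.2, som_N=5, rng=np.random.default_rng(0))
print(NI + NE)  # 200
```

Because labels are drawn independently, NI matches N * P_inh only on average; an implementation that needs an exact inhibitory count would sample a fixed number of indices instead.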

display() → None

Method to print the network setup.

forward(stim: Tensor, taus: List[float], training_params: Dict[str, Any], settings: Dict[str, Any]) → Tuple[Tensor, List[Tensor], List[Tensor], List[Tensor], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

Forward pass of the RNN.

Parameters:
  • stim (torch.Tensor) – Input stimulus tensor of shape (input_dim, T).

  • taus (List[float]) – Time constants: either a single value or a [min, max] range.

  • training_params (Dict[str, Any]) – Training parameters including activation function.

  • settings (Dict[str, Any]) – Task settings including T (trial duration) and DeltaT (sampling rate).

Returns:

  • stim: Input stimulus tensor

  • x: List of synaptic current tensors over time

  • r: List of firing rate tensors over time

  • o: List of output tensors over time

  • w: Recurrent weight matrix

  • w_in: Input weight matrix

  • mask: Dale’s principle mask matrix

  • som_mask: SOM connectivity mask matrix

  • w_out: Output weight matrix

  • b_out: Output bias

  • taus_gaus: Time constant parameters

Return type:

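Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]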
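The forward pass integrates the rate dynamics over the trial. The following is a minimal NumPy sketch of one common formulation. It assumes leaky dynamics τ dx/dt = -x + W r + W_in u discretized with the Euler method, a sigmoid activation, a single fixed time constant, and a scalar readout. The function simulate_rate_rnn and the signs vector (+1 excitatory, -1 inhibitory, applied to the columns of W so that each presynaptic unit has one sign) are hypothetical. The documented method additionally handles a [min, max] time-constant range, the activation function from training_params, and the SOM mask.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def simulate_rate_rnn(stim, w, signs, w_in, w_out, b_out, tau=10.0, dt=1.0,
                      noise_std=0.0, rng=None):
    """Hypothetical sketch: Euler-integrate a Dale-constrained rate RNN.

    stim: (input_dim, T) stimulus; w: (N, N) magnitudes, w[i, j] = weight j -> i.
    """
    rng = np.random.default_rng() if rng is None else rng
    N, T = w.shape[0], stim.shape[1]
    w_full = np.abs(w) * signs[None, :]   # column j takes the sign of presynaptic unit j
    alpha = dt / tau                      # fraction of the relaxation toward the drive per step
    x = [rng.standard_normal(N) / 100]    # small random initial synaptic current
    r = [sigmoid(x[0])]                   # firing rates
    o = [float(w_out @ r[0] + b_out)]     # linear readout
    for t in range(1, T):
        drive = w_full @ r[-1] + w_in @ stim[:, t - 1]
        x_t = (1 - alpha) * x[-1] + alpha * drive + noise_std * rng.standard_normal(N)
        r_t = sigmoid(x_t)
        x.append(x_t)
        r.append(r_t)
        o.append(float(w_out @ r_t + b_out))
    return x, r, o

rng = np.random.default_rng(1)
N, T = 50, 100
signs = np.where(rng.random(N) < 0.2, -1.0, 1.0)   # roughly 20% inhibitory units
w = rng.gamma(2.0, 0.003, size=(N, N))
w_in = rng.standard_normal((N, 1))
w_out = rng.standard_normal(N) / N
stim = np.zeros((1, T))
stim[0, 20:40] = 1.0                                # a single input pulse
x, r, o = simulate_rate_rnn(stim, w, signs, w_in, w_out, b_out=0.0, rng=rng)
```

The per-step lists x, r and o mirror the shape of the documented return values (one entry per time step), which is convenient for inspecting trajectories after a run.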

initialize_W() → Tuple[ndarray, ndarray, ndarray]

Method to generate and initialize the connectivity weight matrix, W. The weights are drawn from either a Gaussian or a gamma distribution.

Returns:

Tuple containing:
  • w: NxN weight matrix (all entries non-negative)

  • mask: NxN matrix of 1’s (for excitatory units) and -1’s (for inhibitory units)

  • som_mask: NxN mask for SOM connectivity constraints

Return type:

Tuple[np.ndarray, np.ndarray, np.ndarray]

Note

To compute the “full” weight matrix, simply multiply w and mask (i.e. w*mask)
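A minimal NumPy sketch of this initialization follows. It assumes that w[i, j] is the weight from unit j onto unit i, that "gamma" and "gaus" are the accepted w_dist values, and illustrative distribution parameters. The mask is built so that the elementwise product in the note gives the signed matrix. If an implementation stores the mask as a diagonal sign matrix instead, the equivalent operation is the matrix product w @ mask. The SOM mask is omitted, and the function signature is hypothetical.

```python
import numpy as np

def initialize_W(N, P_rec, inh, w_dist="gamma", gain=1.5, rng=None):
    """Hypothetical sketch: sparse random magnitudes plus a Dale's-principle sign mask."""
    rng = np.random.default_rng() if rng is None else rng
    w = np.zeros((N, N))
    conn = rng.random((N, N)) < P_rec              # each connection exists with probability P_rec
    n = int(conn.sum())
    if w_dist == "gamma":
        w[conn] = rng.gamma(2.0, 0.003, size=n)    # illustrative shape/scale values
    elif w_dist == "gaus":
        w[conn] = rng.standard_normal(n) * gain / np.sqrt(N * P_rec)
    else:
        raise ValueError(f"unknown w_dist: {w_dist!r}")
    w = np.abs(w)                                  # keep magnitudes only; signs live in the mask
    signs = np.where(inh, -1.0, 1.0)               # +1 excitatory, -1 inhibitory
    mask = np.tile(signs, (N, 1))                  # mask[i, j] = sign of presynaptic unit j
    return w, mask

rng = np.random.default_rng(2)
N = 100
inh = rng.random(N) < 0.2
w, mask = initialize_W(N, P_rec=0.2, inh=inh, w_dist="gaus", rng=rng)
w_full = w * mask                                  # the "full" weight matrix from the note
# Equivalent: right-multiplying by a diagonal sign matrix, w @ np.diag(signs)
```

Keeping magnitudes and signs apart means an optimizer can adjust w (for example, clipping it to stay non-negative) without ever flipping the sign of a unit's outgoing connections.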

lesion_w(lesion_percentage: float, lesion_scale: float = 0.0) → None

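Applies lesions to the recurrent weight matrix by multiplying a fraction of the existing connections by lesion_scale. With the default lesion_scale of 0.0, the selected connections are set to zero.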

Parameters:
  • lesion_percentage (float) – Fraction of existing connections to attenuate (0.0 to 1.0).

  • lesion_scale (float) – Multiplicative factor applied to selected weights (0.0 to 1.0).
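A minimal NumPy sketch of this kind of lesion, assuming the lesioned connections are chosen uniformly at random among the nonzero entries. The function signature is hypothetical: the documented method returns None and acts on the model, while this sketch returns a modified copy.

```python
import numpy as np

def lesion_w(w, lesion_percentage, lesion_scale=0.0, rng=None):
    """Hypothetical sketch: scale a random fraction of the nonzero entries of w."""
    if not 0.0 <= lesion_percentage <= 1.0:
        raise ValueError("lesion_percentage must be between 0.0 and 1.0")
    rng = np.random.default_rng() if rng is None else rng
    w = w.copy()                                   # leave the caller's matrix untouched
    rows, cols = np.nonzero(w)                     # only existing connections can be lesioned
    n_lesion = int(round(lesion_percentage * rows.size))
    if n_lesion > 0:
        pick = rng.choice(rows.size, size=n_lesion, replace=False)
        w[rows[pick], cols[pick]] *= lesion_scale  # 0.0 removes the connection entirely
    return w

w = np.ones((10, 10)) - np.eye(10)                 # 90 existing connections
lesioned = lesion_w(w, 0.5, rng=np.random.default_rng(3))
print(np.count_nonzero(lesioned))                  # 45
```

Setting lesion_scale between 0.0 and 1.0 weakens the selected connections instead of removing them.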

lesion_w_by_type(lesion_percentage: float, lesion_scale: float = 0.0) → None

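Applies lesions separately to each neuron-type pairing (E-E, E-I, I-E, I-I), multiplying the same fraction of existing connections within each pairing by lesion_scale. With the default lesion_scale of 0.0, the selected connections are set to zero.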

Parameters:
  • lesion_percentage (float) – Fraction of existing connections (per type) to attenuate (0.0 to 1.0).

  • lesion_scale (float) – Multiplicative factor applied to selected weights (0.0 to 1.0).
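A minimal NumPy sketch of a per-type lesion, assuming rows index the postsynaptic unit and columns the presynaptic unit, and that connections are picked uniformly at random within each pairing. The function signature and the inh argument are hypothetical; the documented method reads the unit types from the model and returns None.

```python
import numpy as np

def lesion_w_by_type(w, inh, lesion_percentage, lesion_scale=0.0, rng=None):
    """Hypothetical sketch: lesion the same fraction within each E/I type pairing."""
    rng = np.random.default_rng() if rng is None else rng
    w = w.copy()
    exc = ~inh
    for post in (exc, inh):                        # rows: postsynaptic type
        for pre in (exc, inh):                     # columns: presynaptic type
            block = post[:, None] & pre[None, :] & (w != 0)   # existing connections of this pairing
            rows, cols = np.nonzero(block)
            n_lesion = int(round(lesion_percentage * rows.size))
            if n_lesion > 0:
                pick = rng.choice(rows.size, size=n_lesion, replace=False)
                w[rows[pick], cols[pick]] *= lesion_scale
    return w

inh = np.array([True, True] + [False] * 8)         # 2 inhibitory, 8 excitatory units
w = np.ones((10, 10)) - np.eye(10)
lesioned = lesion_w_by_type(w, inh, 0.5, rng=np.random.default_rng(4))
```

A global random lesion removes the requested fraction only in expectation for each pairing, so small populations such as the I-I block can end up over- or under-lesioned; sampling within each pairing fixes the fraction per block.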

load_net(model_dir: str) → FR_RNN_dale

Method to load pre-configured network settings.

Parameters:

model_dir (str) – Path to the model directory containing saved parameters.

Returns:

The loaded network instance.

Return type:

FR_RNN_dale

Core Functions

rate.model.loss_op(o: List[Tensor], z: ndarray | Tensor, training_params: Dict[str, Any]) → Tensor

Computes the loss between the network outputs and the targets for training.

Parameters:
  • o (List[torch.Tensor]) – List of output values from the network.

  • z (Union[np.ndarray, torch.Tensor]) – Target values.

  • training_params (Dict[str, Any]) – Dictionary containing training parameters including ‘loss_fn’ key.

Returns:

Loss function value.

Return type:

torch.Tensor
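A minimal NumPy sketch of a per-time-step output loss, assuming "l1" and "l2" are the accepted values of the loss_fn key (both are assumptions, not confirmed by this page). The documented function returns a torch.Tensor so gradients can flow back through the network; this sketch returns a plain float.

```python
import numpy as np

def loss_op(o, z, training_params):
    """Hypothetical sketch: compare per-time-step outputs o with targets z."""
    o = np.asarray(o, dtype=float).ravel()         # one output value per time step
    z = np.asarray(z, dtype=float).ravel()
    if o.shape != z.shape:
        raise ValueError("outputs and targets must have the same number of time steps")
    loss_fn = training_params["loss_fn"].lower()
    if loss_fn == "l1":
        return float(np.sum(np.abs(o - z)))        # summed absolute error
    if loss_fn == "l2":
        return float(np.sqrt(np.sum((o - z) ** 2)))  # Euclidean norm of the error
    raise ValueError(f"unsupported loss_fn: {loss_fn!r}")

o = [1.0, 2.0, 3.0]
z = [0.0, 2.0, 5.0]
print(loss_op(o, z, {"loss_fn": "l1"}))  # 3.0
```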

rate.model.eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: ndarray, device: device) → Tuple[List[float], List[ndarray], List[ndarray]]

Evaluate a trained PyTorch RNN.

Parameters:
  • net (FR_RNN_dale) – Trained FR_RNN_dale model.

  • settings (Dict[str, Any]) – Dictionary containing task settings.

  • u (np.ndarray) – Stimulus matrix.

  • device (torch.device) – PyTorch device.

Returns:

Tuple containing:
  • o: Output vector (list of floats)

  • r: Firing rates (list of numpy arrays)

  • x: Synaptic currents (list of numpy arrays)

Return type:

Tuple[List[float], List[np.ndarray], List[np.ndarray]]

Overview

The model module provides the core rate-based RNN implementation with Dale’s principle (separate excitatory and inhibitory neurons).

Key Components:

  • FR_RNN_dale: Main rate RNN class with excitatory/inhibitory neuron types

  • loss_op: Loss function for training rate RNNs

  • eval_rnn: Evaluation function for running trained networks

Note: Task-specific stimulus and target generation functions have been moved to the Tasks module as part of the new task-based architecture. For creating stimuli and targets, use the task classes from rate.tasks instead.