Rate RNN Model
The main FR_RNN_dale class and core functions for computing the training loss and evaluating trained networks.
Main Model Class
- class rate.model.FR_RNN_dale(N: int, P_inh: float, P_rec: float, w_in: ndarray, som_N: int, w_dist: str, gain: float, apply_dale: bool, w_out: ndarray, device: device)
Bases: torch.nn.Module
Firing-rate RNN model for excitatory and inhibitory neurons. The constructor initializes the firing-rate model with recurrent connections.
- assign_exc_inh() → Tuple[ndarray, ndarray, int, int, ndarray | int]
Method to randomly assign units as excitatory or inhibitory (Dale’s principle).
- Returns:
- Tuple containing:
inh: Boolean array marking which units are inhibitory
exc: Boolean array marking which units are excitatory
NI: Number of inhibitory units
NE: Number of excitatory units
som_inh: Indices of “inh” for SOM neurons
- Return type:
Tuple[np.ndarray, np.ndarray, int, int, Union[np.ndarray, int]]
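As an illustration of the random E/I assignment described above, here is a minimal NumPy sketch. It assumes each unit is independently labeled inhibitory with probability P_inh. The helper name assign_exc_inh_sketch is hypothetical, and SOM handling (som_inh) is left out because its rule is not documented here.

```python
import numpy as np


def assign_exc_inh_sketch(N: int, P_inh: float, rng: np.random.Generator):
    """Hypothetical sketch: mark each unit inhibitory with probability P_inh.

    Mirrors the first four return values of FR_RNN_dale.assign_exc_inh;
    SOM handling (som_inh) is deliberately omitted.
    """
    inh = rng.random(N) < P_inh      # Boolean mask of inhibitory units
    exc = ~inh                       # every other unit is excitatory
    NI = int(inh.sum())              # number of inhibitory units
    NE = N - NI                      # number of excitatory units
    return inh, exc, NI, NE


rng = np.random.default_rng(0)
inh, exc, NI, NE = assign_exc_inh_sketch(N=200, P_inh=0.2, rng=rng)
print(NI, NE)
```

Because each unit is drawn independently, NI only approximates P_inh * N; a fixed count would require sampling a set number of indices instead.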
- forward(stim: Tensor, taus: List[float], training_params: Dict[str, Any], settings: Dict[str, Any]) → Tuple[Tensor, List[Tensor], List[Tensor], List[Tensor], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
Forward pass of the RNN.
- Parameters:
stim (torch.Tensor) – Input stimulus tensor of shape (input_dim, T).
taus (List[float]) – Time constants (either a single value or a [min, max] range).
training_params (Dict[str, Any]) – Training parameters including activation function.
settings (Dict[str, Any]) – Task settings including T (trial duration) and DeltaT (sampling rate).
- Returns:
stim: Input stimulus tensor
x: List of synaptic current tensors over time
r: List of firing rate tensors over time
o: List of output tensors over time
w: Recurrent weight matrix
w_in: Input weight matrix
mask: Dale’s principle mask matrix
som_mask: SOM connectivity mask matrix
w_out: Output weight matrix
b_out: Output bias
taus_gaus: Time constant parameters
- Return type:
Tuple[Tensor, List[Tensor], List[Tensor], List[Tensor], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
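The dynamics computed by forward are not spelled out in this reference. The sketch below shows a generic leaky firing-rate RNN integrated with forward Euler, the kind of update that yields per-time-step x, r and o sequences like those listed above. Everything in it is an assumption: the update rule, the sigmoid activation (the real activation comes from training_params), the mapping of an unconstrained per-unit vector (named here after taus_gaus) onto a [min, max] time-constant range, and the hypothetical helper simulate_rate_rnn.

```python
import numpy as np


def sigmoid(z):
    """Logistic activation (an assumption; forward() takes it from training_params)."""
    return 1.0 / (1.0 + np.exp(-z))


def simulate_rate_rnn(u, w_eff, w_in, w_out, b_out, tau, dt):
    """Hypothetical forward-Euler integration of a leaky firing-rate RNN.

    u:     input stimulus, shape (input_dim, T), like `stim` in forward()
    w_eff: signed recurrent weights, shape (N, N)
    tau:   scalar or per-unit time constants, shape (N,)
    dt:    integration step (assumed to correspond to settings['DeltaT'])
    """
    N, T = w_eff.shape[0], u.shape[1]
    alpha = dt / np.asarray(tau, dtype=float)   # per-unit leak factor dt/tau
    x = [np.zeros(N)]                           # synaptic currents, one per step
    r = [sigmoid(x[0])]                         # firing rates, one per step
    o = [w_out @ r[0] + b_out]                  # readout, one per step
    for t in range(1, T):
        drive = w_eff @ r[-1] + w_in @ u[:, t - 1]
        x.append((1.0 - alpha) * x[-1] + alpha * drive)
        r.append(sigmoid(x[-1]))
        o.append(w_out @ r[-1] + b_out)
    return x, r, o


rng = np.random.default_rng(1)
N, input_dim, T = 50, 2, 100
u = rng.standard_normal((input_dim, T))
w_eff = rng.standard_normal((N, N)) / np.sqrt(N)
w_in = rng.standard_normal((N, input_dim))
w_out = rng.standard_normal((1, N)) / N
b_out = 0.0

# Map an unconstrained per-unit vector into a [min, max] time-constant range
tau_min, tau_max = 20.0, 100.0                  # hypothetical range
taus_gaus = rng.standard_normal(N)
tau = sigmoid(taus_gaus) * (tau_max - tau_min) + tau_min

x, r, o = simulate_rate_rnn(u, w_eff, w_in, w_out, b_out, tau, dt=1.0)
```

Because alpha = dt/tau is computed with NumPy broadcasting, the same function accepts either a single time constant or one per unit.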
- initialize_W() → Tuple[ndarray, ndarray, ndarray]
Method to generate and initialize the connectivity weight matrix, W. The weights are drawn from either a Gaussian or a gamma distribution.
- Returns:
- Tuple containing:
w: NxN weights (all positive)
mask: NxN matrix of 1s (for excitatory units) and -1s (for inhibitory units)
som_mask: NxN mask for SOM connectivity constraints
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
Note
To compute the “full” weight matrix, simply multiply w and mask (i.e. w*mask)
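The Note above leaves open whether w*mask is an elementwise or a matrix product, and the answer depends on how mask is laid out. This NumPy sketch assumes w[i, j] is the weight from presynaptic unit j to postsynaptic unit i. It shows that a diagonal mask requires the matrix product w @ mask, while an elementwise product gives the same signed matrix only if the mask repeats each unit's sign down its column. The gamma parameters, P_rec value and choice of inhibitory units are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
N, P_rec = 6, 0.5

# Positive weights (gamma-distributed here), each kept with probability P_rec
w = rng.gamma(shape=2.0, scale=0.1, size=(N, N))
w *= rng.random((N, N)) < P_rec

# Assumption: units 4 and 5 are inhibitory
inh = np.array([False, False, False, False, True, True])
sign = np.where(inh, -1.0, 1.0)

# Case 1: mask is diagonal, so the matrix product multiplies column j by sign[j]
mask_diag = np.diag(sign)
w_full = w @ mask_diag

# Case 2: mask repeats each presynaptic unit's sign down its column,
# so the elementwise product w * mask gives the same signed matrix
mask_cols = np.tile(sign, (N, 1))
w_full_elementwise = w * mask_cols
```

An elementwise product with the diagonal mask (w * mask_diag) would instead zero every off-diagonal weight, which is why the mask layout matters.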
- lesion_w(lesion_percentage: float, lesion_scale: float = 0.0) → None
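Applies lesions to the recurrent weight matrix by multiplying a fraction (lesion_percentage) of existing connections by lesion_scale. With the default lesion_scale of 0.0, the selected connections are removed.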
- Parameters:
lesion_percentage (float) – Fraction of existing connections to attenuate (0.0 to 1.0).
lesion_scale (float) – Multiplicative factor applied to selected weights (0.0 to 1.0).
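A minimal sketch of the lesioning idea, assuming that connections are chosen uniformly at random among the nonzero entries of the weight matrix. Unlike lesion_w, which returns None and presumably modifies the network in place, the hypothetical lesion_sketch works on a NumPy copy so the before and after matrices can be compared.

```python
import numpy as np


def lesion_sketch(w, lesion_percentage, lesion_scale=0.0, rng=None):
    """Hypothetical: scale a random fraction of the existing (nonzero) connections."""
    rng = np.random.default_rng() if rng is None else rng
    w = w.copy()
    rows, cols = np.nonzero(w)                        # existing connections
    n_lesion = int(round(lesion_percentage * rows.size))
    pick = rng.choice(rows.size, size=n_lesion, replace=False)
    w[rows[pick], cols[pick]] *= lesion_scale         # 0.0 removes a connection
    return w


rng = np.random.default_rng(3)
w = rng.random((10, 10)) * (rng.random((10, 10)) < 0.4)
w_lesioned = lesion_sketch(w, lesion_percentage=0.25, rng=rng)
```

Sampling only among nonzero entries means the fraction refers to existing connections, matching the parameter description above, rather than to all N x N matrix positions.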
- lesion_w_by_type(lesion_percentage: float, lesion_scale: float = 0.0) → None
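Applies lesions separately to each neuron-type pairing (E-E, E-I, I-E, I-I), multiplying the same fraction of existing connections within each pairing by lesion_scale. With the default lesion_scale of 0.0, the selected connections are removed.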
- Parameters:
lesion_percentage (float) – Fraction of existing connections (per type) to attenuate (0.0 to 1.0).
lesion_scale (float) – Multiplicative factor applied to selected weights (0.0 to 1.0).
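The per-type variant can be sketched by repeating the random selection inside each postsynaptic/presynaptic block. The sketch assumes w[i, j] connects presynaptic unit j to postsynaptic unit i and that inh is a Boolean vector like the one returned by assign_exc_inh; lesion_by_type_sketch is a hypothetical name.

```python
import numpy as np


def lesion_by_type_sketch(w, inh, lesion_percentage, lesion_scale=0.0, rng=None):
    """Hypothetical: lesion the same fraction of connections in each E/I block.

    Assumes w[i, j] is the weight from presynaptic unit j to postsynaptic unit i.
    """
    rng = np.random.default_rng() if rng is None else rng
    w = w.copy()
    exc = ~inh
    for post in (exc, inh):
        for pre in (exc, inh):
            # existing connections in this (postsynaptic, presynaptic) block
            block = post[:, None] & pre[None, :] & (w != 0)
            rows, cols = np.nonzero(block)
            n_lesion = int(round(lesion_percentage * rows.size))
            if n_lesion == 0:
                continue
            pick = rng.choice(rows.size, size=n_lesion, replace=False)
            w[rows[pick], cols[pick]] *= lesion_scale
    return w


rng = np.random.default_rng(4)
N = 20
inh = np.arange(N) >= 15                       # hypothetical: last 5 units inhibitory
w = rng.random((N, N)) * (rng.random((N, N)) < 0.5)
w_lesioned = lesion_by_type_sketch(w, inh, lesion_percentage=0.5, rng=rng)
```

Lesioning each block separately keeps the E-E, E-I, I-E and I-I connection counts in the same proportions, which a single global draw (as in lesion_w) only achieves on average.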
- load_net(model_dir: str) → FR_RNN_dale
Method to load pre-configured network settings.
- Parameters:
model_dir (str) – Path to the model directory containing saved parameters.
- Returns:
The loaded network instance.
- Return type:
FR_RNN_dale
Core Functions
- rate.model.loss_op(o: List[Tensor], z: ndarray | Tensor, training_params: Dict[str, Any]) → Tensor
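Computes the training loss between the network outputs and the targets, using the loss function specified by training_params['loss_fn'].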
- Parameters:
o (List[torch.Tensor]) – List of output values from the network.
z (Union[np.ndarray, torch.Tensor]) – Target values.
training_params (Dict[str, Any]) – Dictionary containing training parameters including ‘loss_fn’ key.
- Returns:
Loss function value.
- Return type:
torch.Tensor
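loss_op dispatches on training_params['loss_fn'], but the accepted values are not listed here. The sketch below assumes two hypothetical options, 'l2' (square root of the summed squared error) and 'mse' (mean squared error), and uses NumPy instead of torch so it runs on its own; the real function returns a torch.Tensor. The name loss_sketch is hypothetical.

```python
import numpy as np


def loss_sketch(o, z, training_params):
    """Hypothetical stand-in for loss_op; the 'loss_fn' names are assumptions."""
    o = np.concatenate([np.ravel(ot) for ot in o])    # per-step outputs -> shape (T,)
    z = np.ravel(np.asarray(z, dtype=float))          # targets -> shape (T,)
    loss_fn = training_params['loss_fn']
    if loss_fn == 'l2':
        return float(np.sqrt(np.sum((o - z) ** 2)))   # root of summed squared error
    if loss_fn == 'mse':
        return float(np.mean((o - z) ** 2))           # mean squared error
    raise ValueError(f"unknown loss_fn: {loss_fn!r}")


outputs = [np.array([1.0]), np.array([2.0]), np.array([3.0])]
targets = np.array([1.0, 2.0, 5.0])
print(loss_sketch(outputs, targets, {'loss_fn': 'l2'}))   # 2.0
```

Flattening the per-step output list first lets the same comparison work whether each entry is a scalar or a length-1 array.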
- rate.model.eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: ndarray, device: device) → Tuple[List[float], List[ndarray], List[ndarray]]
Evaluate a trained PyTorch RNN.
- Parameters:
net (FR_RNN_dale) – Trained FR_RNN_dale model.
settings (Dict[str, Any]) – Dictionary containing task settings.
u (np.ndarray) – Stimulus matrix.
device (torch.device) – PyTorch device.
- Returns:
- Tuple containing:
o: Output vector (list of floats)
r: Firing rates (list of numpy arrays)
x: Synaptic currents (list of numpy arrays)
- Return type:
Tuple[List[float], List[np.ndarray], List[np.ndarray]]
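eval_rnn returns one entry per time step, so a common first analysis step is stacking those lists into matrices. The sketch below uses synthetic stand-in lists (a real call would be o, r, x = eval_rnn(net, settings, u, device)) and assumes each rate and current entry has shape (N,) or (N, 1).

```python
import numpy as np

# Synthetic stand-ins for eval_rnn's outputs; the per-step shapes are assumptions.
N, T = 8, 50
rng = np.random.default_rng(5)
o = [float(v) for v in rng.standard_normal(T)]
r = [rng.random((N, 1)) for _ in range(T)]
x = [rng.standard_normal((N, 1)) for _ in range(T)]

# np.column_stack accepts 1-D (N,) or 2-D (N, 1) arrays and yields an (N, T) matrix
R = np.column_stack(r)           # firing rates, one column per time step
X = np.column_stack(x)           # synaptic currents, one column per time step
O = np.asarray(o)                # output trace, shape (T,)

mean_rate_per_unit = R.mean(axis=1)
```

Using np.column_stack rather than np.hstack keeps the result (N, T) even if the entries turn out to be 1-D.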
Overview
The model module provides the core rate-based RNN implementation with Dale’s principle (separate excitatory and inhibitory neurons).
Key Components:
FR_RNN_dale: Main rate RNN class with excitatory/inhibitory neuron types
loss_op: Loss function for training rate RNNs
eval_rnn: Evaluation function for running trained networks
Note: Task-specific stimulus and target generation functions have been moved to the Tasks module as part of the new task-based architecture. For creating stimuli and targets, use the task classes from rate.tasks instead.