Rate RNN Model

This module provides the main FR_RNN_dale class, task-specific functions for creating stimuli and targets, and the loss function used for training.

Main Model Class

class rate.model.FR_RNN_dale(N: int, P_inh: float, P_rec: float, w_in: ndarray, som_N: int, w_dist: str, gain: float, apply_dale: bool, w_out: ndarray, device: device)

Bases: Module

Firing-rate RNN model for excitatory and inhibitory neurons. The constructor initializes the firing-rate model with recurrent connections.

assign_exc_inh() → Tuple[ndarray, ndarray, int, int, ndarray | int]

Method to randomly assign units as excitatory or inhibitory (Dale’s principle).

Returns:

Tuple containing:
  • inh: Boolean array marking which units are inhibitory

  • exc: Boolean array marking which units are excitatory

  • NI: Number of inhibitory units

  • NE: Number of excitatory units

  • som_inh: Indices of the inhibitory units (entries of inh) designated as SOM neurons

Return type:

Tuple[np.ndarray, np.ndarray, int, int, Union[np.ndarray, int]]
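As a rough illustration of the return values listed above, the following NumPy sketch assigns units at random. It is not the library's implementation: the helper name sketch_assign_exc_inh, the sampling rule rng.random(N) < P_inh, the behavior when apply_dale is False, and the omission of SOM handling (som_inh) are all assumptions made here.

```python
import numpy as np

def sketch_assign_exc_inh(N, P_inh, apply_dale=True, seed=0):
    """Hypothetical helper: randomly mark units as inhibitory or excitatory."""
    rng = np.random.default_rng(seed)
    if apply_dale:
        # Each unit is inhibitory with probability P_inh (assumed sampling rule).
        inh = rng.random(N) < P_inh
    else:
        # Assumption: without Dale's principle, no unit is marked inhibitory.
        inh = np.zeros(N, dtype=bool)
    exc = ~inh                      # every non-inhibitory unit is excitatory
    NI = int(inh.sum())             # number of inhibitory units
    NE = int(exc.sum())             # number of excitatory units
    return inh, exc, NI, NE

inh, exc, NI, NE = sketch_assign_exc_inh(N=200, P_inh=0.2)
print(NI, NE, NI + NE)
```

Because exc is the logical complement of inh, the two counts always sum to N.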

display() → None

Method to print the network setup.

forward(stim: Tensor, taus: List[float], training_params: Dict[str, Any], settings: Dict[str, Any]) → Tuple[Tensor, List[Tensor], List[Tensor], List[Tensor], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

Forward pass of the RNN.

Parameters:
  • stim (torch.Tensor) – Input stimulus tensor of shape (input_dim, T).

  • taus (List[float]) – Time constants (either single value or [min, max] range).

  • training_params (Dict[str, Any]) – Training parameters including activation function.

  • settings (Dict[str, Any]) – Task settings including T (trial duration) and DeltaT (sampling rate).

Returns:

Tuple containing:
  • stim: Input stimulus tensor

  • x: List of synaptic current tensors over time

  • r: List of firing rate tensors over time

  • o: List of output tensors over time

  • w: Recurrent weight matrix

  • w_in: Input weight matrix

  • mask: Dale’s principle mask matrix

  • som_mask: SOM connectivity mask matrix

  • w_out: Output weight matrix

  • b_out: Output bias

  • taus_gaus: Time constant parameters

Return type:

Tuple[Tensor, List[Tensor], List[Tensor], List[Tensor], Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
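To make the roles of x, r and o more concrete, here is a minimal NumPy sketch of a generic leaky firing-rate rollout. The update rule, the sigmoid activation, the single scalar tau, and the function name sketch_forward are textbook-style assumptions; the source does not state the equations that forward actually integrates, and the real method works on torch tensors.

```python
import numpy as np

def sketch_forward(stim, w_eff, w_in, w_out, b_out, tau, delta_t=1.0):
    """Generic leaky firing-rate rollout (assumed form, not the library's code)."""
    N = w_eff.shape[0]
    T = stim.shape[1]
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    x = [np.zeros((N, 1))]          # synaptic currents over time
    r = [sigmoid(x[0])]             # firing rates over time
    o = []                          # outputs over time
    alpha = delta_t / tau           # step size relative to the time constant
    for t in range(1, T):
        # Recurrent drive plus input drive at the previous time step.
        drive = w_eff @ r[-1] + w_in @ stim[:, [t - 1]]
        x_next = (1.0 - alpha) * x[-1] + alpha * drive
        x.append(x_next)
        r.append(sigmoid(x_next))
        o.append(w_out @ r[-1] + b_out)
    return x, r, o

rng = np.random.default_rng(0)
N, T = 20, 50
x, r, o = sketch_forward(
    stim=np.zeros((1, T)),                                  # input_dim = 1
    w_eff=rng.normal(scale=1.5 / np.sqrt(N), size=(N, N)),  # signed recurrent weights
    w_in=rng.normal(size=(N, 1)),
    w_out=rng.normal(scale=0.1, size=(1, N)),
    b_out=0.0,
    tau=10.0,
)
print(len(x), len(r), len(o))
```

In this sketch w_eff stands for the signed recurrent matrix, i.e. the combination of w and mask described under initialize_W.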

initialize_W() → Tuple[ndarray, ndarray, ndarray]

Method to generate and initialize the connectivity weight matrix, W. The weights are drawn from either a Gaussian or a gamma distribution.

Returns:

Tuple containing:
  • w: NxN weights (all positive)

  • mask: NxN matrix of 1’s (for excitatory units) and -1’s (for inhibitory units)

  • som_mask: NxN mask for SOM connectivity constraints

Return type:

Tuple[np.ndarray, np.ndarray, np.ndarray]

Note

To compute the “full” signed weight matrix, multiply w and mask (i.e. w*mask).
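The note above can be sketched in NumPy as follows. The source does not say how mask is laid out or whether w*mask denotes an elementwise or a matrix product, so the sketch shows two assumed layouts that give the same result when columns index the presynaptic unit. The array inh and all values here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 6
inh = np.array([False, True, False, False, True, False])  # hypothetical assignment

w = np.abs(rng.normal(size=(N, N)))     # non-negative weight magnitudes
sign = np.where(inh, -1.0, 1.0)         # -1 for inhibitory, +1 for excitatory

# Layout A (assumed): full sign matrix whose column j carries the sign of unit j,
# combined with w by an elementwise product.
mask_full = np.tile(sign, (N, 1))
w_full_a = w * mask_full

# Layout B (assumed): diagonal sign matrix combined with w by a matrix product.
mask_diag = np.diag(sign)
w_full_b = w @ mask_diag

# Both layouts yield the same signed matrix.
assert np.allclose(w_full_a, w_full_b)
# Dale's principle: every outgoing weight of an inhibitory unit is non-positive.
assert np.all(w_full_a[:, inh] <= 0)
```

Either way, each column of the signed matrix has a single sign, which is what Dale's principle requires of a unit's outgoing connections.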

load_net(model_dir: str) → FR_RNN_dale

Method to load pre-configured network settings.

Parameters:

model_dir (str) – Path to the model directory containing saved parameters.

Returns:

The loaded network instance.

Return type:

FR_RNN_dale

Task Functions

XOR Task

rate.model.generate_input_stim_xor(settings: Dict[str, Any]) → Tuple[ndarray, str]

Generate the input stimulus matrix for the XOR task.

Parameters:

settings (Dict[str, Any]) – Dictionary containing the following keys:
  • T: Duration of a single trial (in steps)

  • stim_on: Stimulus starting time (in steps)

  • stim_dur: Stimulus duration (in steps)

  • delay: Delay between two stimuli (in steps)

  • taus: Time-constants (in steps)

  • DeltaT: Sampling rate

Returns:

Tuple containing:
  • u: 2xT stimulus matrix

  • label: Either ‘same’ or ‘diff’

Return type:

Tuple[np.ndarray, str]

rate.model.generate_target_continuous_xor(settings: Dict[str, Any], label: str) → ndarray

Generate the target output signal for the XOR task.

Parameters:
  • settings (Dict[str, Any]) – Dictionary containing the key T: duration of a single trial (in steps).

  • label (str) – Either ‘same’ or ‘diff’.

Returns:

1xT target signal array.

Return type:

np.ndarray
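The shapes and labels documented for the two XOR functions can be illustrated with a standalone NumPy sketch. The pulse amplitudes of +1 or -1, the placement of the second pulse on row 1 after the delay, the target values of +1 for 'same' and -1 for 'diff', and the helper name sketch_xor_trial are assumptions; the source only fixes the 2xT and 1xT shapes, the settings keys, and the two labels.

```python
import numpy as np

def sketch_xor_trial(settings, seed=None):
    """Hypothetical XOR trial: two signed pulses and a continuous target."""
    rng = np.random.default_rng(seed)
    T = settings["T"]
    on, dur, delay = settings["stim_on"], settings["stim_dur"], settings["delay"]

    u = np.zeros((2, T))                         # 2xT stimulus matrix
    signs = rng.choice([-1.0, 1.0], size=2)      # sign of each pulse
    u[0, on:on + dur] = signs[0]                 # first pulse on row 0
    second_on = on + dur + delay                 # second pulse starts after the delay
    u[1, second_on:second_on + dur] = signs[1]   # second pulse on row 1 (assumed)
    label = "same" if signs[0] == signs[1] else "diff"

    z = np.zeros((1, T))                         # 1xT target signal
    # Assumed convention: hold +1 ('same') or -1 ('diff') once both pulses ended.
    z[0, second_on + dur:] = 1.0 if label == "same" else -1.0
    return u, label, z

settings = {"T": 300, "stim_on": 50, "stim_dur": 50, "delay": 10}
u, label, z = sketch_xor_trial(settings, seed=0)
print(u.shape, label, z.shape)
```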

Mante Task

rate.model.generate_input_stim_mante(settings: Dict[str, Any]) → Tuple[ndarray, int]

Generate the input stimulus matrix for the sensory integration task from Mante et al. (2013).

Parameters:

settings (Dict[str, Any]) – Dictionary containing the following keys:
  • T: Duration of a single trial (in steps)

  • stim_on: Stimulus starting time (in steps)

  • stim_dur: Stimulus duration (in steps)

  • DeltaT: Sampling rate

Returns:

Tuple containing:
  • u: 4xT stimulus matrix

  • label: Either +1 or -1 (the correct decision)

Return type:

Tuple[np.ndarray, int]

rate.model.generate_target_continuous_mante(settings: Dict[str, Any], label: int) → ndarray

Generate the target output signal for the sensory integration task from Mante et al. (2013).

Parameters:
  • settings (Dict[str, Any]) – Dictionary containing the key T: duration of a single trial (in steps).

  • label (int) – Either +1 or -1 (the correct decision).

Returns:

1xT target signal array.

Return type:

np.ndarray
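A context-dependent integration trial with the documented 4xT stimulus and ±1 label might be sketched as below. The row layout (two noisy sensory streams followed by two one-hot context rows), the offset magnitude of 0.5, the unit-variance noise, the target timing, and the helper name sketch_mante_trial are assumptions made for illustration and are not taken from the library.

```python
import numpy as np

def sketch_mante_trial(settings, seed=None):
    """Hypothetical context-dependent integration trial (assumed row layout)."""
    rng = np.random.default_rng(seed)
    T, on, dur = settings["T"], settings["stim_on"], settings["stim_dur"]
    window = slice(on, on + dur)

    u = np.zeros((4, T))                          # 4xT stimulus matrix
    # Rows 0-1: two noisy sensory streams, each with a random signed offset.
    offsets = rng.choice([-0.5, 0.5], size=2)
    u[0, window] = offsets[0] + rng.normal(size=dur)
    u[1, window] = offsets[1] + rng.normal(size=dur)
    # Rows 2-3: one-hot context cue selecting which stream is relevant.
    context = int(rng.integers(2))
    u[2 + context, :] = 1.0
    # The correct decision is the sign of the offset in the cued stream.
    label = 1 if offsets[context] > 0 else -1

    z = np.zeros((1, T))                          # 1xT target signal
    z[0, on + dur:] = label                       # assumed: report after the stimulus
    return u, label, z

settings = {"T": 500, "stim_on": 50, "stim_dur": 200}
u, label, z = sketch_mante_trial(settings, seed=0)
print(u.shape, label, z.shape)
```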

Go-NoGo Task

rate.model.generate_input_stim_go_nogo(settings: Dict[str, Any]) → Tuple[ndarray, int]

Generate the input stimulus matrix for the Go-NoGo task.

Parameters:

settings (Dict[str, Any]) – Dictionary containing the following keys:
  • T: Duration of a single trial (in steps)

  • stim_on: Stimulus starting time (in steps)

  • stim_dur: Stimulus duration (in steps)

  • taus: Time-constants (in steps)

  • DeltaT: Sampling rate

Returns:

Tuple containing:
  • u: 1xT stimulus matrix

  • label: Either 1 (Go trial) or 0 (NoGo trial)

Return type:

Tuple[np.ndarray, int]

rate.model.generate_target_continuous_go_nogo(settings: Dict[str, Any], label: int) → ndarray

Generate the target output signal for the Go-NoGo task.

Parameters:
  • settings (Dict[str, Any]) – Dictionary containing the keys T (duration of a single trial, in steps), stim_on (stimulus starting time, in steps), and stim_dur (stimulus duration, in steps).

  • label (int) – Either 1 (Go trial) or 0 (NoGo trial).

Returns:

1xT target signal array.

Return type:

np.ndarray
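For the Go-NoGo task, the documented shapes and labels can be illustrated with the sketch below. The equal trial probabilities, the unit pulse amplitude, the choice to hold the target at 1 from the end of the stimulus onward, and the helper name sketch_go_nogo_trial are assumptions; only the 1xT shapes, the settings keys, and the 1/0 label convention come from the descriptions above.

```python
import numpy as np

def sketch_go_nogo_trial(settings, seed=None):
    """Hypothetical Go-NoGo trial: a pulse on Go trials, silence on NoGo trials."""
    rng = np.random.default_rng(seed)
    T, on, dur = settings["T"], settings["stim_on"], settings["stim_dur"]

    u = np.zeros((1, T))                 # 1xT stimulus matrix
    z = np.zeros((1, T))                 # 1xT target signal
    label = int(rng.random() < 0.5)      # 1 = Go, 0 = NoGo (assumed 50/50 split)
    if label == 1:
        u[0, on:on + dur] = 1.0          # present the Go pulse
        z[0, on + dur:] = 1.0            # assumed: respond after the pulse ends
    return u, label, z

settings = {"T": 200, "stim_on": 50, "stim_dur": 25}
u, label, z = sketch_go_nogo_trial(settings, seed=0)
print(u.shape, label, z.shape)
```

On NoGo trials both arrays stay at zero, so the network is expected to remain silent.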

Loss Function

rate.model.loss_op(o: List[Tensor], z: ndarray, training_params: Dict[str, Any]) → Tensor

Defines the loss function used for training.

Parameters:
  • o (List[torch.Tensor]) – List of output values from the network.

  • z (np.ndarray) – Target values.

  • training_params (Dict[str, Any]) – Dictionary containing training parameters including ‘loss_fn’ key.

Returns:

Loss function value.

Return type:

torch.Tensor
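The idea of comparing a list of per-step outputs against a target array can be sketched in NumPy as follows. The accepted 'loss_fn' values ('l1' and 'l2'), the exact formulas, and the helper name sketch_loss are assumptions; the source only states that training_params carries a 'loss_fn' key and that the real function returns a torch.Tensor.

```python
import numpy as np

def sketch_loss(o, z, training_params):
    """Hypothetical NumPy analogue of a loss over per-step outputs."""
    # Collapse the list of per-step outputs into a flat array.
    o_arr = np.array([float(np.squeeze(o_t)) for o_t in o])
    # Align the target with the number of available output steps.
    z_arr = np.asarray(z, dtype=float).reshape(-1)[: len(o_arr)]
    err = o_arr - z_arr
    loss_fn = training_params["loss_fn"].lower()
    if loss_fn == "l1":
        return float(np.sum(np.abs(err)))        # sum of absolute errors
    if loss_fn == "l2":
        return float(np.sqrt(np.sum(err ** 2)))  # root of summed squared errors
    raise ValueError(f"unsupported loss_fn: {loss_fn!r}")

o = [np.array([[1.0]]), np.array([[2.0]])]       # two output steps
z = np.array([[0.0, 0.0]])                       # 1xT target with T = 2
print(sketch_loss(o, z, {"loss_fn": "l1"}))
print(sketch_loss(o, z, {"loss_fn": "l2"}))
```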