"""
Base classes and interfaces for spiking neural network tasks.
This module defines abstract base classes and concrete implementations for various
cognitive tasks that can be evaluated using spiking neural networks. Each task
provides methods for stimulus generation, evaluation, and performance analysis.
Classes:
    AbstractSpikingTask: Abstract base class for all spiking tasks.
    GoNogoSpikingTask: Go/NoGo impulse-control task for spiking networks.
    XORSpikingTask: Temporal XOR task for spiking networks.
    ManteSpikingTask: Context-dependent sensory integration task for spiking networks.
    SpikingTaskFactory: Factory for creating (and registering) task classes by name.
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Tuple, Optional, List
import numpy as np
import matplotlib.pyplot as plt
import os
from .abstract import AbstractSpikingRNN
class AbstractSpikingTask(ABC):
    """
    Abstract base class for spiking neural network tasks.

    This class defines the interface for evaluating spiking networks on cognitive
    tasks. Each concrete task is responsible for generating stimuli, running
    evaluations, and analyzing performance metrics specific to spiking
    implementations.
    """

    def __init__(self, settings: Optional[Dict[str, Any]] = None):
        """
        Initialize the spiking task with settings.

        Keys missing from ``settings`` are filled in from
        :meth:`get_default_settings`, so callers may override only the values
        they care about. The merged result is a new dictionary; the caller's
        dictionary is neither stored nor mutated.

        Args:
            settings (Optional[Dict[str, Any]]): Task-specific settings. ``None``
                or an empty dict selects the defaults unchanged.

        Raises:
            ValueError: If the merged settings fail :meth:`validate_settings`.
        """
        # Merge user overrides on top of the defaults (previously a partial dict
        # replaced the defaults wholesale and then failed validation).
        self.settings = {**self.get_default_settings(), **(settings or {})}
        self.validate_settings()

    @abstractmethod
    def get_default_settings(self) -> Dict[str, Any]:
        """
        Get default settings for the task.

        Returns:
            Dict[str, Any]: Default task settings.
        """
        pass

    @abstractmethod
    def validate_settings(self) -> None:
        """
        Validate that all required settings are present and valid.

        Raises:
            ValueError: If required settings are missing or invalid.
        """
        pass

    @abstractmethod
    def generate_stimulus(self, trial_type: Optional[str] = None) -> Tuple[np.ndarray, Any]:
        """
        Generate input stimulus for the task.

        Args:
            trial_type (Optional[str]): Specific trial type to generate.

        Returns:
            Tuple[np.ndarray, Any]: Input stimulus array and label/condition.
        """
        pass

    @abstractmethod
    def evaluate_trial(self, spiking_rnn: 'AbstractSpikingRNN',
                       stimulus: np.ndarray, label: Any) -> Dict[str, Any]:
        """
        Evaluate a single trial on the spiking network.

        Args:
            spiking_rnn (AbstractSpikingRNN): Spiking network to evaluate.
            stimulus (np.ndarray): Input stimulus.
            label (Any): Expected label/condition.

        Returns:
            Dict[str, Any]: Trial evaluation results.
        """
        pass

    def create_plots_directory(self, base_dir: str) -> str:
        """
        Create (if needed) the ``plots`` sub-directory of ``base_dir``.

        Intermediate directories are created as required. Calling this more than
        once, or concurrently, is safe.

        Args:
            base_dir (str): Base directory path.

        Returns:
            str: Path to the plots directory.
        """
        plot_dir = os.path.join(base_dir, 'plots')
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(plot_dir, exist_ok=True)
        return plot_dir

    def get_sample_trial_types(self) -> List[str]:
        """
        Get sample trial types for visualization.

        Concrete task classes should override this to specify which trial types
        are used when generating sample visualizations.

        Returns:
            List[str]: Trial type identifiers for this task (empty by default).
        """
        return []
class GoNogoSpikingTask(AbstractSpikingTask):
    """
    Go/NoGo impulse control task for spiking neural networks.

    Evaluates the network's ability to respond to "Go" stimuli and withhold
    responses to "NoGo" stimuli using spiking implementations.
    """

    # Valid trial-type identifiers accepted by generate_stimulus.
    _TRIAL_TYPES = ('go', 'nogo')

    def get_default_settings(self) -> Dict[str, Any]:
        """Get default Go/NoGo task settings."""
        return {
            'T': 201,
            'stim_on': 30,
            'stim_dur': 20,
            'delay': 10
        }

    def validate_settings(self) -> None:
        """
        Validate Go/NoGo task settings.

        Raises:
            ValueError: If 'T', 'stim_on' or 'stim_dur' is missing, or the
                stimulus window does not end before ``T``.
        """
        required_keys = ['T', 'stim_on', 'stim_dur']
        for key in required_keys:
            if key not in self.settings:
                raise ValueError(f"Missing required setting: {key}")
        if self.settings['stim_on'] + self.settings['stim_dur'] >= self.settings['T']:
            raise ValueError("Stimulus extends beyond trial duration")

    def get_sample_trial_types(self) -> List[str]:
        """Get sample trial types for Go/NoGo visualization."""
        return ['go', 'nogo']

    def generate_stimulus(self, trial_type: Optional[str] = None) -> Tuple[np.ndarray, str]:
        """
        Generate stimulus for Go/NoGo task.

        A Go trial has a unit pulse on the single input channel during
        ``[stim_on, stim_on + stim_dur)``; a NoGo trial is all zeros.

        Args:
            trial_type (Optional[str]): 'go' or 'nogo'. If None, one is chosen
                uniformly at random.

        Returns:
            Tuple[np.ndarray, str]: Stimulus of shape ``(1, T)`` and the trial
            type ('go' or 'nogo').

        Raises:
            ValueError: If ``trial_type`` is not None, 'go' or 'nogo'.
        """
        if trial_type is None:
            trial_type = 'go' if np.random.rand() < 0.5 else 'nogo'
        elif trial_type not in self._TRIAL_TYPES:
            # Previously any unknown value silently produced a NoGo stimulus
            # while returning the unrecognised label.
            raise ValueError(
                f"Unknown Go/NoGo trial type: {trial_type!r} (expected 'go' or 'nogo')")

        T = self.settings['T']
        stim_on = self.settings['stim_on']
        stim_dur = self.settings['stim_dur']
        u = np.zeros((1, T))
        if trial_type == 'go':
            u[0, stim_on:stim_on + stim_dur] = 1
        return u, trial_type

    def evaluate_trial(self, spiking_rnn: AbstractSpikingRNN,
                       stimulus: np.ndarray, label: str) -> Dict[str, Any]:
        """
        Evaluate a single Go/NoGo trial.

        The network is simulated with ``stims={'mode': 'none'}``. A 'go' trial is
        correct when the mean output over the whole trial exceeds 0.5; any other
        label is scored as NoGo (correct when the mean output is <= 0.5).

        Args:
            spiking_rnn (AbstractSpikingRNN): Spiking network to evaluate.
            stimulus (np.ndarray): Input stimulus.
            label (str): 'go' or 'nogo'.

        Returns:
            Dict[str, Any]: Trial results including spikes, voltages, output,
            mean output, correctness flag and simulation params.
        """
        stims = {'mode': 'none'}
        spikes, voltages, output, params = spiking_rnn.simulate(stimulus, stims)

        output_mean = np.mean(output)
        if label == 'go':
            correct = output_mean > 0.5   # Should respond
        else:
            correct = output_mean <= 0.5  # Should withhold
        return {
            'stimulus': stimulus,
            'label': label,
            'spikes': spikes,
            'voltages': voltages,
            'output': output,
            'output_mean': output_mean,
            'correct': bool(correct),
            'params': params
        }

    def create_visualization(self, results: List[Dict[str, Any]], save_dir: str) -> None:
        """
        Create visualization plots for Go/NoGo task results.

        Writes a spike raster for the first Go and first NoGo result (if any)
        and, when both exist, an output comparison plot into ``save_dir/plots``.

        Args:
            results (List[Dict[str, Any]]): List of trial results.
            save_dir (str): Directory to save plots.
        """
        plot_dir = self.create_plots_directory(save_dir)
        go_results = [r for r in results if r['label'] == 'go']
        nogo_results = [r for r in results if r['label'] == 'nogo']
        if go_results:
            self._plot_spike_raster(go_results[0], 'Go',
                                    os.path.join(plot_dir, 'go_spike_raster.png'))
        if nogo_results:
            self._plot_spike_raster(nogo_results[0], 'NoGo',
                                    os.path.join(plot_dir, 'nogo_spike_raster.png'))
        if go_results and nogo_results:
            self._plot_output_comparison(go_results[0], nogo_results[0],
                                         os.path.join(plot_dir, 'network_output.png'))

    def _plot_spike_raster(self, result: Dict[str, Any], trial_type: str, filename: str) -> None:
        """Plot and save a spike raster for one trial; the figure is always closed."""
        spikes = result['spikes']
        params = result['params']
        dt = params['dt']
        nt = params['nt']
        # Float arange can overshoot by one element, hence the [:nt] trim.
        t = np.arange(dt, dt * (nt + 1), dt)[:nt]
        fig = plt.figure(figsize=(8, 6))
        try:
            N = spikes.shape[0]
            for neuron_idx in range(N):
                curr_spk = spikes[neuron_idx, 9:]  # Skip first 9 timepoints
                spike_indices = np.where(curr_spk > 0)[0]
                if len(spike_indices) > 0:
                    spike_times = t[9:][spike_indices]
                    plt.plot(spike_times, np.ones(len(spike_times)) * neuron_idx,
                             'r.', markersize=4)
            # Fixed 0-1 s window: spikes after 1 s are not shown.
            plt.xlim([0, 1])
            plt.ylim([-5, N + 5])
            plt.xlabel('Time (s)')
            plt.ylabel('Neuron Index')
            plt.title(f'{trial_type} Spike Raster')
            plt.tight_layout()
            plt.savefig(filename)
        finally:
            plt.close(fig)

    def _plot_output_comparison(self, go_result: Dict[str, Any],
                                nogo_result: Dict[str, Any], filename: str) -> None:
        """Plot and save Go vs NoGo network output traces; the figure is always closed."""
        params = go_result['params']
        dt = params['dt']
        nt = params['nt']
        t = np.arange(dt, dt * (nt + 1), dt)[:nt]
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.plot(t, nogo_result['output'].flatten(), 'm', linewidth=2, label='NoGo')
            plt.plot(t, go_result['output'].flatten(), 'g', linewidth=2, label='Go')
            plt.xlabel('Time (s)')
            plt.ylabel('Network Output')
            plt.legend()
            plt.title('Network Output Comparison')
            plt.tight_layout()
            plt.savefig(filename)
        finally:
            plt.close(fig)
class XORSpikingTask(AbstractSpikingTask):
    """
    XOR temporal logic task for spiking neural networks.

    Evaluates the network's ability to perform XOR logic on temporal sequences
    using spiking implementations.
    """

    # Valid two-character sign patterns accepted by generate_stimulus.
    _PATTERNS = ('++', '+-', '-+', '--')

    def get_default_settings(self) -> Dict[str, Any]:
        """Get default XOR task settings."""
        return {
            'T': 400,
            'stim_on': 50,
            'stim_dur': 50,
            'delay': 20
        }

    def validate_settings(self) -> None:
        """
        Validate XOR task settings.

        Raises:
            ValueError: If 'T', 'stim_on', 'stim_dur' or 'delay' is missing, or
                both stimuli plus the delay do not end before ``T``.
        """
        required_keys = ['T', 'stim_on', 'stim_dur', 'delay']
        for key in required_keys:
            if key not in self.settings:
                raise ValueError(f"Missing required setting: {key}")
        total_stim_time = self.settings['stim_on'] + 2 * self.settings['stim_dur'] + self.settings['delay']
        if total_stim_time >= self.settings['T']:
            raise ValueError("Stimuli extend beyond trial duration")

    def get_sample_trial_types(self) -> List[str]:
        """Get sample trial types for XOR visualization."""
        return ['++', '+-', '-+', '--']

    def generate_stimulus(self, trial_type: Optional[str] = None) -> Tuple[np.ndarray, str]:
        """
        Generate stimulus for XOR task.

        Channel 0 carries the first pulse (+1 or -1) during
        ``[stim_on, stim_on + stim_dur)``. Channel 1 carries the second pulse
        after a gap of ``delay`` steps.

        Args:
            trial_type (Optional[str]): Pattern '++', '+-', '-+' or '--'. If
                None, one is chosen uniformly at random.

        Returns:
            Tuple[np.ndarray, str]: Stimulus of shape ``(2, T)`` and the expected
            output ('same' if both pulses share a sign, else 'diff').

        Raises:
            ValueError: If ``trial_type`` is not None or one of the four patterns.
        """
        if trial_type is None:
            trial_type = np.random.choice(list(self._PATTERNS))
        elif trial_type not in self._PATTERNS:
            # Previously a short string raised IndexError and any other
            # character was silently read as '-'.
            raise ValueError(
                f"Unknown XOR trial type: {trial_type!r} (expected one of {list(self._PATTERNS)})")

        T = self.settings['T']
        stim_on = self.settings['stim_on']
        stim_dur = self.settings['stim_dur']
        delay = self.settings['delay']
        u = np.zeros((2, T))

        first_stim = 1 if trial_type[0] == '+' else -1
        second_stim = 1 if trial_type[1] == '+' else -1
        u[0, stim_on:stim_on + stim_dur] = first_stim
        u[1, stim_on + stim_dur + delay:stim_on + 2 * stim_dur + delay] = second_stim

        expected = 'same' if first_stim == second_stim else 'diff'
        return u, expected

    def evaluate_trial(self, spiking_rnn: AbstractSpikingRNN,
                       stimulus: np.ndarray, label: str) -> Dict[str, Any]:
        """
        Evaluate a single XOR trial.

        The decision is the mean output over a 100-step window starting 10 steps
        after the second stimulus ends. If the output is too short for that
        window, the last 50 samples are used instead. A positive decision means
        'same'; otherwise 'diff'.

        Args:
            spiking_rnn (AbstractSpikingRNN): Spiking network to evaluate.
            stimulus (np.ndarray): Input stimulus.
            label (str): Expected output ('same' or 'diff').

        Returns:
            Dict[str, Any]: Trial results.
        """
        stims = {'mode': 'none'}
        spikes, voltages, output, params = spiking_rnn.simulate(stimulus, stims)

        stim_on = self.settings['stim_on']
        stim_dur = self.settings['stim_dur']
        delay = self.settings['delay']
        task_end_T = stim_on + (2 * stim_dur) + delay
        target_onset = 10 + task_end_T
        target_offset = target_onset + 100

        # Flatten so a (1, nt) or (nt, 1) output is indexed along time; with a
        # (1, nt) array, len() was 1 and the window was never used.
        out = np.ravel(output)
        if target_offset <= out.size:
            decision_output = np.mean(out[target_onset:target_offset])
        else:
            decision_output = np.mean(out[-50:])  # Use last 50 time points

        predicted = 'same' if decision_output > 0 else 'diff'
        correct = predicted == label
        return {
            'stimulus': stimulus,
            'label': label,
            'predicted': predicted,
            'spikes': spikes,
            'voltages': voltages,
            'output': output,
            'decision_output': decision_output,
            'correct': correct,
            'params': params
        }

    def _get_pattern_from_stimulus(self, stimulus: np.ndarray) -> str:
        """Recover the '+'/'-' pattern from the mean of each stimulus window."""
        stim_on = self.settings['stim_on']
        stim_dur = self.settings['stim_dur']
        delay = self.settings['delay']
        first_val = np.mean(stimulus[0, stim_on:stim_on + stim_dur])
        second_val = np.mean(stimulus[1, stim_on + stim_dur + delay:stim_on + 2 * stim_dur + delay])
        first_char = '+' if first_val > 0 else '-'
        second_char = '+' if second_val > 0 else '-'
        return first_char + second_char
class ManteSpikingTask(AbstractSpikingTask):
    """
    Context-dependent sensory integration task for spiking neural networks.

    Evaluates the network's ability to perform context-dependent decision making
    using spiking implementations.
    """

    # Valid context identifiers accepted by generate_stimulus.
    _CONTEXTS = ('color', 'motion')

    def get_default_settings(self) -> Dict[str, Any]:
        """Get default Mante task settings."""
        return {
            'T': 300,
            'stim_on': 50,
            'stim_dur': 100
        }

    def validate_settings(self) -> None:
        """
        Validate Mante task settings.

        Raises:
            ValueError: If 'T', 'stim_on' or 'stim_dur' is missing, or the
                stimulus window does not end before ``T``.
        """
        required_keys = ['T', 'stim_on', 'stim_dur']
        for key in required_keys:
            if key not in self.settings:
                raise ValueError(f"Missing required setting: {key}")
        if self.settings['stim_on'] + self.settings['stim_dur'] >= self.settings['T']:
            raise ValueError("Stimulus extends beyond trial duration")

    def get_sample_trial_types(self) -> List[str]:
        """Get sample trial types for Mante visualization."""
        return ['color', 'motion']

    def generate_stimulus(self, trial_type: Optional[str] = None) -> Tuple[np.ndarray, int]:
        """
        Generate stimulus for Mante task.

        During ``[stim_on, stim_on + stim_dur)``, the channels are filled as follows:

        - row 0: context cue (+1 for color, -1 for motion);
        - row 1: color evidence;
        - row 2: motion evidence.

        Each evidence value is a constant drawn uniformly from [-1.25, 1.25).

        Args:
            trial_type (Optional[str]): 'color' or 'motion'. If None, one is
                chosen uniformly at random.

        Returns:
            Tuple[np.ndarray, int]: Stimulus of shape ``(4, T)`` and the expected
            decision: +1 if the context-relevant input is positive, else -1.

        Raises:
            ValueError: If ``trial_type`` is not None, 'color' or 'motion'.
        """
        # Validate before drawing random numbers; the draw order below is
        # unchanged so seeded runs reproduce the same stimuli.
        if trial_type is not None and trial_type not in self._CONTEXTS:
            raise ValueError(
                f"Unknown Mante trial type: {trial_type!r} (expected 'color' or 'motion')")

        T = self.settings['T']
        stim_on = self.settings['stim_on']
        stim_dur = self.settings['stim_dur']
        # NOTE(review): row 3 is allocated but never written here; confirm the
        # input layout expected by the network.
        u = np.zeros((4, T))

        color_input = 2.5 * (np.random.rand() - 0.5)   # [-1.25, 1.25)
        motion_input = 2.5 * (np.random.rand() - 0.5)  # [-1.25, 1.25)
        if trial_type is None:
            trial_type = 'color' if np.random.rand() < 0.5 else 'motion'

        if trial_type == 'color':
            context_cue, relevant_input = 1, color_input
        else:
            context_cue, relevant_input = -1, motion_input

        window = slice(stim_on, stim_on + stim_dur)
        u[0, window] = context_cue
        u[1, window] = color_input
        u[2, window] = motion_input
        label = 1 if relevant_input > 0 else -1
        return u, label

    def evaluate_trial(self, spiking_rnn: AbstractSpikingRNN,
                       stimulus: np.ndarray, label: int) -> Dict[str, Any]:
        """
        Evaluate a single Mante task trial.

        The decision is the mean output from the end of the stimulus window to
        the end of the trial. If the output is not longer than that start index,
        the last 50 samples are used instead of averaging an empty slice (which
        would give NaN). A positive decision maps to +1; otherwise -1.

        Args:
            spiking_rnn (AbstractSpikingRNN): Spiking network to evaluate.
            stimulus (np.ndarray): Input stimulus.
            label (int): Expected decision (+1 or -1).

        Returns:
            Dict[str, Any]: Trial results.
        """
        stims = {'mode': 'none'}
        spikes, voltages, output, params = spiking_rnn.simulate(stimulus, stims)

        stim_on = self.settings['stim_on']
        stim_dur = self.settings['stim_dur']
        decision_start = stim_on + stim_dur

        # Flatten so a (1, nt) output is sliced along time rather than rows.
        out = np.ravel(output)
        if decision_start < out.size:
            decision_output = np.mean(out[decision_start:])
        else:
            decision_output = np.mean(out[-50:])

        predicted = 1 if decision_output > 0 else -1
        correct = predicted == label
        return {
            'stimulus': stimulus,
            'label': label,
            'predicted': predicted,
            'spikes': spikes,
            'voltages': voltages,
            'output': output,
            'decision_output': decision_output,
            'correct': correct,
            'params': params
        }
# Task factory for spiking tasks
class SpikingTaskFactory:
    """Factory class for creating spiking task instances by name."""

    # Shared, class-level registry mapping task names to task classes;
    # register_task mutates it in place.
    _registry = {
        'go_nogo': GoNogoSpikingTask,
        'xor': XORSpikingTask,
        'mante': ManteSpikingTask
    }

    @classmethod
    def create_task(cls, task_name: str, settings: Optional[Dict[str, Any]] = None) -> AbstractSpikingTask:
        """
        Create a spiking task instance by name.

        Args:
            task_name (str): Registered task name (built-ins: 'go_nogo', 'xor', 'mante').
            settings (Optional[Dict[str, Any]]): Task settings passed to the
                task constructor.

        Returns:
            AbstractSpikingTask: Created task instance.

        Raises:
            ValueError: If ``task_name`` is not registered.
        """
        if task_name not in cls._registry:
            available = list(cls._registry.keys())
            raise ValueError(f"Task type '{task_name}' not found. Available types: {available}")
        task_class = cls._registry[task_name]
        return task_class(settings)

    @classmethod
    def register_task(cls, task_name: str, task_class: type) -> None:
        """
        Register a custom task class with the factory.

        An existing entry with the same name is overwritten.

        Args:
            task_name (str): Name to register the task under.
            task_class (type): Class that inherits from AbstractSpikingTask.

        Raises:
            ValueError: If ``task_class`` is not a class inheriting from
                AbstractSpikingTask.
        """
        # issubclass() raises TypeError for non-classes; check the type first so
        # callers always get the documented ValueError.
        if not (isinstance(task_class, type) and issubclass(task_class, AbstractSpikingTask)):
            name = getattr(task_class, '__name__', repr(task_class))
            raise ValueError(f"Task class {name} must inherit from AbstractSpikingTask")
        cls._registry[task_name] = task_class

    @classmethod
    def list_available_tasks(cls) -> List[str]:
        """List all registered spiking task names."""
        return list(cls._registry.keys())