Source code for marlax.agents.qvagent
"""
QValueAgent implementation using Q-value functions for MARLAX.
This agent maintains a mapping from global states to scalar Q-values
and selects actions by inferring the best next position.
"""
from marlax.abstracts import Agent
import random
import numpy as np
from collections import defaultdict
from functools import partial
class QValueAgent(Agent):
    """
    Agent that selects actions based on scalar Q-values per global state.

    Attributes:
        position (tuple): Agent's (x, y) position in the grid.
        actions (list of str): Available actions.
        q_table (defaultdict): Maps a state key to a scalar Q-value.
        action_map (dict): Maps (dx, dy) offsets to action names.
    """
    def __init__(self, init_position=None, actions=['stay', 'up', 'down', 'left', 'right']):
        """
        Initialize the agent with a starting position and its available actions.

        Args:
            init_position (tuple, optional): The (x, y) starting coordinates. Defaults to None.
            actions (list of str, optional): Available actions.
                Defaults to ['stay', 'up', 'down', 'left', 'right'].
        """
        self.position = init_position
        self.actions = actions
        # Q-values default to 0.0 for states that have never been visited.
        self.q_table = defaultdict(float)
        # (dx, dy) grid offsets mapped to action names; y grows downwards,
        # so 'up' decreases y and 'down' increases it.
        self.action_map = {
            (0, 0): 'stay',
            (0, -1): 'up',
            (0, 1): 'down',
            (-1, 0): 'left',
            (1, 0): 'right',
        }
    def choose(self, possible_states, epsilon=0.1, agent_id=0):
        """
        Select an action with epsilon-greedy exploration.

        Args:
            possible_states (list): List of global state keys to evaluate.
            epsilon (float): Exploration probability. Defaults to 0.1.
            agent_id (int): Identifier for this agent. Defaults to 0.

        Returns:
            str: The action corresponding to the move towards the best state.
        """
        if random.random() < epsilon:
            # Explore: pick a uniformly random action.
            return random.choice(self.actions)
        else:
            # Exploit: move towards the reachable state with the highest Q-value.
            max_state = self.get_max_state(possible_states)
            current_pos = self.position
            # The first element of the state key holds each agent's position.
            next_pos = max_state[0][agent_id]
            dx = next_pos[0] - current_pos[0]
            dy = next_pos[1] - current_pos[1]
            action = self.action_map[(dx, dy)]
            return action
    def get_max_state(self, possible_states):
        """
        Identify the state with the highest scalar Q-value.

        Args:
            possible_states (list): List of global state keys.

        Returns:
            any: The state key with the maximal Q-value.
        """
        # Look up the Q-value of every candidate state and keep the argmax
        # (ties resolve to the first maximal entry).
        max_state_id = np.argmax([self.q_table[state_key] for state_key in possible_states])
        max_state = possible_states[max_state_id]
        return max_state
    def update(self, state_key, action, reward, next_state_key, alpha=0.1, gamma=0.99):
        r"""
        Update the scalar Q-value of a state with a simple bootstrapped rule:

        $$Q(s) \leftarrow (1 - \alpha)\, Q(s) + \alpha \bigl(r + \gamma\, Q(s')\bigr)$$

        where:

        - $Q(s)$ is the current Q-value of the state,
        - $Q(s')$ is the Q-value of the next state,
        - $r$ is the received reward,
        - $\alpha$ is the learning rate,
        - $\gamma$ is the discount factor.

        Args:
            state_key (any): Current global state key.
            action (str): Action taken (unused, since the value is state-based).
            reward (float): Reward received after the action.
            next_state_key (any): Next global state key.
            alpha (float): Learning rate. Defaults to 0.1.
            gamma (float): Discount factor. Defaults to 0.99.
        """
        # Exponential moving average towards the target reward + gamma * Q(s').
        self.q_table[state_key] = (
            (1 - alpha) * self.q_table[state_key]
            + alpha * (reward + gamma * self.q_table[next_state_key])
        )
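
A minimal usage sketch follows, assuming a caller (normally the surrounding MARLAX environment) that supplies hashable global state keys whose first element is the tuple of agent positions; the `possible_states`, `reward`, and state keys below are hypothetical placeholders and are not part of the original module.

# Illustrative only: hand-built state keys of the form ((pos_agent_0,), extra).
agent = QValueAgent(init_position=(1, 1))

possible_states = [
    (((1, 1),), 0),   # stay
    (((1, 0),), 0),   # up
    (((1, 2),), 0),   # down
    (((0, 1),), 0),   # left
    (((2, 1),), 0),   # right
]

# Greedy choice (epsilon=0 disables exploration); with an empty q_table all
# values are 0.0 and np.argmax keeps the first entry, so 'stay' is returned.
action = agent.choose(possible_states, epsilon=0.0, agent_id=0)

# After the environment applies the move and returns a reward, pull the
# acted-from state towards the bootstrapped target reward + gamma * Q(s').
state_key, next_state_key, reward = possible_states[0], possible_states[2], 1.0
agent.update(state_key, action, reward, next_state_key, alpha=0.1, gamma=0.99)
print(action, agent.q_table[state_key])  # stay 0.1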