Source code for marlax.tracers
"""
Tracer for logging environment and agent data in MARLAX.
Handles buffering of frame-wise metrics and exports to Parquet files,
plus serialization of agent Q-tables.
"""
import shutil
import os
import pyarrow as pa
import pyarrow.parquet as pq
import pandas as pd
import pickle
[docs]
class Tracer:
    """
    Records and manages simulation logs and agent exports.

    Frame records are buffered in memory and appended in chunks to one
    Parquet file per logging session; agent Q-tables are pickled to and
    restored from a ``qvals`` sub-directory of the log directory.

    Attributes:
        log_path (str): Base directory for log storage.
        logger_active (bool): Flag indicating logging activity.
        log_buffer (list): Buffered rows awaiting flush.
        flush_every (int): Number of frames before auto-flush.
        regime_idx (int): Identifier for current regime.
        log_filename (str): Path to the active Parquet file.
        parquet_writer (pq.ParquetWriter): Writer for Parquet output.
    """

    def __init__(self, log_path):
        """
        Initialize the tracer.

        WARNING: if ``log_path`` already exists it is deleted recursively,
        so every tracer starts from an empty log directory.

        Args:
            log_path (str): Path to the log directory.
        """
        # Logger attributes
        self.logger_active = False
        self.log_buffer = []
        self.flush_every = None
        self.regime_idx = None
        self.log_filename = None
        self.parquet_writer = None
        self.log_path = log_path

        # Remove folder if it exists
        if os.path.exists(self.log_path):
            shutil.rmtree(self.log_path)

    def _init_logger(self, flush_every, regime_idx, who="training"):
        """
        Initialize the logger that appends rows to a single Parquet file.

        If a previous logging session is still active it is flushed and
        closed first, so its buffered rows are written out and its writer
        is not left open.

        Args:
            flush_every (int): Number of frames to buffer before writing to disk.
            regime_idx (int): Regime identifier that is recorded with each row.
            who (str): Label to differentiate logs (e.g. 'training' or 'test').
        """
        # Finalize any session that is still open before its state is reset.
        if self.logger_active:
            self._flush_logger()

        # Ensure the log directory exists.
        logs_dir = os.path.join(self.log_path, "logs")
        os.makedirs(logs_dir, exist_ok=True)

        # Create the log filename.
        self.log_filename = os.path.join(logs_dir, f"{who}_{regime_idx}.parquet")
        self.flush_every = flush_every
        self.regime_idx = regime_idx
        self.log_buffer = []
        self.parquet_writer = None  # Created lazily on the first flush.
        self.logger_active = True

    def _flush_buffer(self):
        """
        Write buffered log rows to the Parquet file.

        Uses PyArrow to append table chunks. Does nothing when the buffer
        is empty. The writer is created on the first non-empty flush, with
        the schema of that first chunk.
        """
        if not self.log_buffer:
            return

        # Convert the buffer to a pandas DataFrame.
        df = pd.DataFrame(self.log_buffer)
        # Convert the DataFrame to a PyArrow Table.
        table = pa.Table.from_pandas(df)

        if self.parquet_writer is None:
            # Create a ParquetWriter with the schema inferred from the first flush.
            self.parquet_writer = pq.ParquetWriter(self.log_filename, table.schema)

        self.parquet_writer.write_table(table)
        self.log_buffer = []

    def _log_frame(self, step, next_state, rewards, info):
        """
        Buffer a new frame record and auto-flush if needed.

        ``_init_logger`` must have been called beforehand: ``flush_every``
        is ``None`` until then and the threshold comparison would fail.

        Args:
            step (int): Frame index.
            next_state (tuple): (agent_positions, reward_loc).
            rewards (list[float]): Reward values per agent.
            info (dict): Diagnostic info, expects keys 'activated',
                'collected', 'terminated', 'steps_without_reward'.
        """
        # Expect next_state to be a tuple: (agent_states, reward_loc)
        agent_states, reward_loc = next_state

        row = {
            "regime_idx": self.regime_idx,
            "frame_idx": step,
            "reward_loc": reward_loc,
            "activated": info["activated"],
            "collected": info["collected"],
            "terminated": info["terminated"],
            "steps_without_reward": info["steps_without_reward"],
        }

        # Add agent coordinates as 1-based columns a1x, a1y, a2x, ...
        for i, (x, y) in enumerate(agent_states):
            row[f"a{i+1}x"] = x
            row[f"a{i+1}y"] = y

        # Add rewards as 1-based columns r1, r2, ...
        for i, reward in enumerate(rewards):
            row[f"r{i+1}"] = reward

        self.log_buffer.append(row)
        if len(self.log_buffer) >= self.flush_every:
            self._flush_buffer()

    def _flush_logger(self):
        """
        Flush any remaining buffered rows and close the Parquet writer.

        Does nothing when no logging session is active. The session is
        marked inactive and the writer is closed even if the final flush
        raises.
        """
        if not self.logger_active:
            return

        try:
            self._flush_buffer()
        finally:
            # Detach the writer before closing so no closed writer is kept.
            writer, self.parquet_writer = self.parquet_writer, None
            self.logger_active = False
            if writer is not None:
                writer.close()

    def export_agents(self, env):
        """
        Serialize each agent's Q-table to individual pickle files.

        Files are written to ``<log_path>/qvals/agent_<idx>.pkl`` where
        ``idx`` is the agent's position in ``env.agents``.

        Args:
            env (Environment): Environment whose agents will be dumped.
        """
        qvals_dir = os.path.join(self.log_path, "qvals")
        os.makedirs(qvals_dir, exist_ok=True)

        for idx, agent in enumerate(env.agents):
            filename = os.path.join(qvals_dir, f"agent_{idx}.pkl")
            with open(filename, "wb") as file:
                pickle.dump(agent.q_table, file)

    @staticmethod
    def _agent_file_sort_key(filename):
        """
        Build a sort key that orders ``agent_<idx>.pkl`` files by ``idx``.

        Files whose name does not end in ``_<integer>.pkl`` sort after the
        indexed ones, alphabetically.

        Args:
            filename (str): File name ending in ``.pkl``.

        Returns:
            tuple: Key suitable for ``sorted(..., key=...)``.
        """
        stem = filename[: -len(".pkl")]
        index = stem.rpartition("_")[2]
        try:
            return (0, int(index), filename)
        except ValueError:
            return (1, 0, filename)

    def import_agents(self, agent):
        """
        Load agent Q-tables from pickle files and reinstantiate agents.

        Files are read from ``<log_path>/qvals`` in ascending numeric index
        order, so the returned list lines up with the order used by
        ``export_agents`` rather than with the arbitrary order of
        ``os.listdir``.

        Args:
            agent (callable): Zero-argument constructor for the agent class;
                called once per pickle file.

        Returns:
            list: Agent instances with loaded Q-tables.

        Raises:
            FileNotFoundError: If the ``qvals`` directory does not exist.
        """
        qvals_dir = os.path.join(self.log_path, "qvals")
        pkl_files = sorted(
            (name for name in os.listdir(qvals_dir) if name.endswith(".pkl")),
            key=self._agent_file_sort_key,
        )

        all_agents = []
        for filename in pkl_files:
            with open(os.path.join(qvals_dir, filename), "rb") as file:
                a = agent()
                # SECURITY: pickle.load can execute arbitrary code; only load
                # Q-table files that were produced by a trusted run.
                a.q_table = pickle.load(file)
            all_agents.append(a)
        return all_agents