import gymnasium as gym
from ns_gym.wrappers import (
NSClassicControlWrapper,
NSBridgeWrapper,
NSCliffWalkingWrapper,
NSFrozenLakeWrapper,
)
from ns_gym.schedulers import ContinuousScheduler
from ns_gym.update_functions import StepWiseUpdate, DistributionStepWiseUpdate
import numpy as np
from stable_baselines3 import PPO
import matplotlib.pyplot as plt
import argparse
import os
from datetime import datetime
def make_env_with_context(
env_name, context_value, context_parameter_name="masscart", seed=None
):
"""Creates a ns_gym environment with a specified context parameter value.
Args:
env_name (str): Gymnasium environement name.
context_value (float): The value for the context parameter.
context_parameter_name (str): The name of the parameter to tune.
seed (Optional[int]): Seed for the environment.
Returns:
gym.Env: The configured environment.
Caution:
Not implemented for all environments. Currently supports:
- Classic Control: CartPole-v1, Acrobot-v1, MountainCarContinuous-v0, MountainCar-v0, Pendulum-v1
- Gridworlds: CliffWalking-v0, FrozenLake-v1, ns_gym/Bridge-v0
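
    Example:
        A minimal usage sketch (values are illustrative)::

            env = make_env_with_context("CartPole-v1", context_value=2.0, seed=0)
            obs, info = env.reset()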
"""
if env_name in [
"CartPole-v1",
"Acrobot-v1",
"MountainCarContinuous-v0",
"MountainCar-v0",
"Pendulum-v1",
]:
env = gym.make(env_name)
scheduler = ContinuousScheduler(start=0, end=0)
update_fn = StepWiseUpdate(
scheduler, [context_value]
) # Provides the context value
tunable_params = {context_parameter_name: update_fn}
ns_env = NSClassicControlWrapper(
env=env,
tunable_params=tunable_params,
change_notification=True,
delta_change_notification=True,
)
elif env_name in ["CliffWalking-v0", "FrozenLake-v1", "ns_gym/Bridge-v0"]:
env = gym.make(env_name)
scheduler = ContinuousScheduler(start=0, end=0)
update_fn = DistributionStepWiseUpdate(
scheduler, [context_value]
) # Provides the context value
tunable_params = {context_parameter_name: update_fn}
if env_name == "FrozenLake-v1":
ns_env = NSFrozenLakeWrapper(
env,
tunable_params,
change_notification=True,
delta_change_notification=True,
initial_prob_dist=[
context_value,
(1 - context_value) / 2,
(1 - context_value) / 2,
],
)
elif env_name == "CliffWalking-v0":
ns_env = NSCliffWalkingWrapper(
env,
tunable_params,
change_notification=True,
delta_change_notification=True,
initial_prob_dist=[
context_value,
(1 - context_value) / 3,
(1 - context_value) / 3,
(1 - context_value) / 3,
],
)
elif env_name == "ns_gym/Bridge-v0":
ns_env = NSBridgeWrapper(
env,
tunable_params,
change_notification=True,
delta_change_notification=True,
initial_prob_dist=[
context_value,
(1 - context_value) / 2,
(1 - context_value) / 2,
],
)
else:
raise ValueError("Invalid environment")
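    # Unwrap ns_gym's structured observation and reward objects so downstream
    # consumers (e.g., SB3 policies) see plain arrays and scalar rewards.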
ns_env = gym.wrappers.TransformObservation(ns_env, lambda obs: obs["state"], ns_env.unwrapped.observation_space)
ns_env = gym.wrappers.TransformReward(ns_env, lambda rew: rew.reward)
if seed is not None:
ns_env.reset(seed=seed)
return ns_env
def run_context_episode(agent, ns_env_instance, num_episodes):
"""Runs an StableBaselines3 policy in a given ns_gym environment for a number of episodes.
Args:
agent (StableBaselines3 Policy): The trained agent/policy to evaluate.
ns_env_instance (gym.Env): The ns_gym environment instance.
num_episodes (int): Number of episodes to run.
"""
reward_list = []
for ep in range(num_episodes):
ep_reward = 0.0
done = False
truncated = False
obs, info = ns_env_instance.reset()
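        # Coerce ns_gym's structured observation into a flat float32 array
        # before handing it to the policy.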
if not isinstance(obs, np.ndarray) and hasattr(obs, "state"):
obs = np.array(obs.state, dtype=np.float32)
elif not isinstance(obs, np.ndarray):
obs = np.array(obs, dtype=np.float32)
while not (done or truncated):
action, _states = agent.predict(obs, deterministic=True)
obs, reward, done, truncated, current_info = ns_env_instance.step(action)
if not isinstance(obs, np.ndarray) and hasattr(obs, "state"):
obs = np.array(obs.state, dtype=np.float32)
elif not isinstance(obs, np.ndarray):
obs = np.array(obs, dtype=np.float32)
if not isinstance(reward, (float, int)) and hasattr(reward, "reward"):
reward = float(reward.reward)
elif not isinstance(reward, (float, int)):
reward = float(reward)
ep_reward += reward
reward_list.append(ep_reward)
return np.mean(reward_list), np.std(reward_list)
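
# Example usage (a minimal sketch, assuming `model` is a trained SB3 agent):
#
#     env = make_env_with_context("CartPole-v1", context_value=2.0)
#     mean_r, std_r = run_context_episode(model, env, num_episodes=10)
#     env.close()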
def eval_target_contexts(
policy, make_env_func_partial, num_episodes_per_context, target_context_range
):
"""Evaluates a given policy across a range of target contexts.
Args:
        policy: The trained Stable-Baselines3 agent/policy to evaluate (any object exposing an SB3-style ``predict`` method).
        make_env_func_partial: A partial of make_env_with_context with env_name and context_parameter_name fixed.
num_episodes_per_context: How many episodes to run for each target context.
target_context_range: Array of target context values to evaluate on.
Returns:
        Tuple[np.ndarray, np.ndarray]: Mean reward and standard deviation for each target context.
"""
mean_rewards = np.zeros(len(target_context_range))
std_rewards = np.zeros(len(target_context_range))
for i, target_ctx_val in enumerate(target_context_range):
eval_env = make_env_func_partial(
context_value=target_ctx_val
) # Pass only context_value
mean_rewards[i], std_rewards[i] = run_context_episode(
policy, eval_env, num_episodes_per_context
)
eval_env.close()
return mean_rewards, std_rewards
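
# Example usage (a minimal sketch; `model` is a trained SB3 agent and
# "masscart" is assumed to be a valid tunable CartPole parameter):
#
#     from functools import partial
#     make_env = partial(make_env_with_context, env_name="CartPole-v1",
#                        context_parameter_name="masscart")
#     means, stds = eval_target_contexts(model, make_env, 20,
#                                        np.linspace(0.5, 5.0, 50))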
def normalize_rewards_matrix(U_matrix_raw):
"""Normalizes a reward matrix using min-max scaling (0-1 range)."""
min_val = np.min(U_matrix_raw)
max_val = np.max(U_matrix_raw)
    if max_val == min_val:
        # Degenerate case: every entry is identical, so min-max scaling would
        # divide by zero. Map to a constant instead.
        U_matrix_normalized = np.full_like(U_matrix_raw, 0.5 if min_val != 0 else 0.0)
        return U_matrix_normalized, min_val, max_val
U_matrix_normalized = (U_matrix_raw - min_val) / (max_val - min_val)
return U_matrix_normalized, min_val, max_val
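
# Example: normalize_rewards_matrix(np.array([[0.0, 50.0], [100.0, 200.0]]))
# returns (array([[0.0, 0.25], [0.5, 1.0]]), 0.0, 200.0).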
def calculate_sem(data_array):
"""Calculates the standard error of the mean for a 1D array."""
if len(data_array) < 2:
return 0.0
return np.std(data_array, ddof=1) / np.sqrt(len(data_array))
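
# Example: for [1.0, 2.0, 3.0] the sample standard deviation (ddof=1) is 1.0,
# so calculate_sem returns 1.0 / sqrt(3) ≈ 0.577.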
def save_metrics_to_file(
filename,
U_matrix_raw,
peak_performances_raw,
overall_system_performance_raw,
sem_overall_raw,
U_matrix_normalized,
peak_performances_normalized,
overall_system_performance_normalized,
sem_overall_norm,
normalization_params,
source_contexts,
context_range,
args_dict,
):
"""
Saves both raw and normalized metrics, including SEM, to a text file.
"""
try:
with open(filename, "w") as f:
f.write("--- Experiment Configuration ---\n")
for key, value in args_dict.items():
f.write(f"{key}: {value}\n")
f.write("\n--- Source Contexts for Training ---\n")
f.write(f"{source_contexts.tolist()}\n")
f.write("\n--- Target Contexts Evaluated ---\n")
f.write(f"{np.round(context_range, 3).tolist()}\n")
# --- Raw Metrics ---
f.write("\n\n--- RAW METRICS ---\n")
f.write("Overall System Generalized Performance (Raw):\n")
f.write(
f"Mean: {overall_system_performance_raw:.4f}, SEM: {sem_overall_raw:.4f}\n"
)
f.write("\nPeak Performance for each Policy (Raw):\n")
for i, peak_p in enumerate(peak_performances_raw):
f.write(
f" Agent trained on Ctx {source_contexts[i]:.2f}: Max Reward = {peak_p:.4f}\n"
)
f.write("\nPerformance Matrix U (Raw):\n")
for i in range(U_matrix_raw.shape[0]):
row_str = ", ".join([f"{val:.4f}" for val in U_matrix_raw[i, :]])
f.write(f"Agent {i} (Ctx {source_contexts[i]:.2f}): [{row_str}]\n")
# --- Normalized Metrics ---
if (
U_matrix_normalized is not None
and overall_system_performance_normalized is not None
and sem_overall_norm is not None
):
f.write("\n\n--- NORMALIZED METRICS ---\n")
                min_obs = normalization_params.get("min_reward_observed", float("nan"))
                max_obs = normalization_params.get("max_reward_observed", float("nan"))
                f.write(
                    f"Normalization Parameters: Min Observed = {min_obs:.4f}, Max Observed = {max_obs:.4f}\n"
                )
f.write("Overall System Generalized Performance (Normalized):\n")
f.write(
f"Mean: {overall_system_performance_normalized:.4f}, SEM: {sem_overall_norm:.4f}\n"
)
f.write("\nPeak Performance for each Policy (Normalized):\n")
for i, peak_p in enumerate(peak_performances_normalized): # type: ignore
f.write(
f" Agent trained on Ctx {source_contexts[i]:.2f}: Max Reward = {peak_p:.4f}\n"
)
f.write("\nPerformance Matrix U (Normalized):\n")
for i in range(U_matrix_normalized.shape[0]):
row_str = ", ".join(
[f"{val:.4f}" for val in U_matrix_normalized[i, :]]
)
f.write(f"Agent {i} (Ctx {source_contexts[i]:.2f}): [{row_str}]\n")
print(f"Metrics saved to {filename}")
except Exception as e:
print(f"Error saving metrics to {filename}: {e}")
if __name__ == "__main__":
# Example test context switching experiment code.
    parser = argparse.ArgumentParser(
        description="Run a model-based transfer learning evaluation across environment contexts."
    )
parser.add_argument(
"--timesteps_train",
type=int,
default=30000,
help="Total timesteps to train each agent.",
)
parser.add_argument(
"--episodes_eval",
type=int,
default=20,
help="Number of episodes for evaluation on each target context.",
)
parser.add_argument(
"--context_param",
type=str,
default="masscart",
help="Environment parameter to modify (e.g., 'masscart', 'length').",
)
parser.add_argument(
"--num_target_contexts",
type=int,
default=100,
help="Number of points in the target context range for evaluation.",
)
parser.add_argument(
"--target_context_min",
type=float,
default=0.1,
help="Minimum value for the target context range.",
)
parser.add_argument(
"--target_context_max",
type=float,
default=10.0,
help="Maximum value for the target context range.",
)
parser.add_argument("--env_name", type=str, help="Environment name")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
default_output_dir = f"experiment_results_{timestamp}"
parser.add_argument(
"--output_dir",
type=str,
default=default_output_dir,
help="Directory to save plot and metrics files.",
)
parser.add_argument(
"--plot_file",
type=str,
default="performance_plot.png",
help="Filename for the saved plot (relative to output_dir).",
)
parser.add_argument(
"--metrics_file",
type=str,
default="performance_metrics.txt",
help="Filename for the saved metrics (relative to output_dir).",
)
    parser.add_argument(
        "--normalize_rewards",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Normalize generalized performance (disable with --no-normalize_rewards).",
    )
args = parser.parse_args()
from stable_baselines3.common.env_util import make_vec_env
    # SOURCE_CONTEXTS is sampled from the generated target context range below.
    # (Alternative fixed grid: np.linspace(0.0025 - 0.001, 0.05, 9).)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
print(f"Created output directory: {args.output_dir}")
plot_filepath = os.path.join(args.output_dir, args.plot_file)
metrics_filepath = os.path.join(args.output_dir, args.metrics_file)
# Generate context_range based on args
current_context_range = np.linspace(
args.target_context_min, args.target_context_max, args.num_target_contexts
)
    # Sample the source (training) contexts uniformly, without replacement,
    # from the target context range. Note: this sampling is unseeded, so the
    # source contexts differ between runs.
    context_size = len(current_context_range)
    source_indices = np.random.choice(context_size, size=9, replace=False)
    SOURCE_CONTEXTS = current_context_range[source_indices]
print("--- Starting Experiment: Model-Based Transfer Learning Evaluation ---")
print("Using arguments:")
for arg, value in vars(args).items():
print(f" {arg}: {value}")
trained_agents = []
# Use functools.partial to pass the fixed context_parameter_name to make_env_with_context
from functools import partial
make_env_partial_fn = partial(
make_env_with_context,
env_name=args.env_name,
context_parameter_name=args.context_param,
)
print(f"\n--- Training {len(SOURCE_CONTEXTS)} Agents ---")
for i, s_ctx in enumerate(SOURCE_CONTEXTS):
print(
f"Training Agent {i + 1}/{len(SOURCE_CONTEXTS)} for source context ({args.context_param}): {s_ctx:.2f}"
)
        # Build the training env via the partial factory; loop variables are
        # bound as lambda defaults so the closure captures their current values.
        train_env = make_vec_env(
            lambda ctx=s_ctx, sd=i: make_env_partial_fn(context_value=ctx, seed=sd),
            n_envs=1,
        )
model = PPO(
"MlpPolicy",
train_env,
verbose=0,
tensorboard_log=None,
device="auto",
seed=i,
)
        # Alternative for discrete-action environments (requires
        # `from stable_baselines3 import DQN`):
        # model = DQN(
# "MlpPolicy",
# train_env,
# verbose=0,
# tensorboard_log=None,
# device="auto",
# seed=i
# )
model.learn(total_timesteps=args.timesteps_train, progress_bar=True)
trained_agents.append(model)
train_env.close()
print("--- All Agents Trained ---")
(
U_raw,
peaks_raw,
overall_raw,
sem_raw,
U_norm,
peaks_norm,
overall_norm,
sem_norm,
norm_params,
) = calculate_generalized_performance(
trained_agents,
SOURCE_CONTEXTS,
make_env_partial_fn,
args.episodes_eval,
current_context_range,
normalize=args.normalize_rewards,
)
print("\n--- Results Summary ---")
print(f"Shape of Performance Matrix U (Raw): {U_raw.shape}")
if U_norm is not None:
print(f"Shape of Performance Matrix U (Normalized): {U_norm.shape}")
print("\nPeak Performance for each Policy (Raw):")
for i, peak_p in enumerate(peaks_raw):
print(
f" Agent trained on Ctx {SOURCE_CONTEXTS[i]:.2f}: Max Reward = {peak_p:.2f}"
)
if peaks_norm is not None:
print("\nPeak Performance for each Policy (Normalized):")
for i, peak_p in enumerate(peaks_norm):
print(
f" Agent trained on Ctx {SOURCE_CONTEXTS[i]:.2f}: Max Reward = {peak_p:.2f}"
)
print(
f"\nOverall System Generalized Performance (Paper's V-metric, Raw): Mean = {overall_raw:.2f}, SEM = {sem_raw:.4f}"
)
if overall_norm is not None and sem_norm is not None:
print(
f"Overall System Generalized Performance (Paper's V-metric, Normalized): Mean = {overall_norm:.2f}, SEM = {sem_norm:.4f}"
)
print("\nSaving metrics and plotting performance curves...")
save_metrics_to_file(
metrics_filepath,
U_raw,
peaks_raw,
overall_raw,
sem_raw,
U_norm,
peaks_norm,
overall_norm,
sem_norm,
norm_params,
SOURCE_CONTEXTS,
current_context_range,
vars(args),
)
# Decide which matrix to plot based on normalization flag
if args.normalize_rewards and U_norm is not None:
plot_performance_curves(
U_norm,
current_context_range,
SOURCE_CONTEXTS,
overall_norm,
args.context_param,
True,
plot_filepath,
)
else:
plot_performance_curves(
U_raw,
current_context_range,
SOURCE_CONTEXTS,
overall_raw,
args.context_param,
False,
plot_filepath,
)
print("\n--- Experiment Finished ---")