# street-fighter-ai/000_image_stack_ram_based_reward/tune.py
import gym
import retro
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy
from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper


def make_env(game, state, seed=0):
    # Factory returning a thunk that builds a monitored, custom-wrapped retro env.
    def _init():
        env = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        env = Monitor(env)
        env.seed(seed)
        return env
    return _init


def objective(trial):
    # One Optuna trial: sample PPO hyperparameters, train, and return the mean reward.
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    env = make_env(game, state="ChampionX.Level1.ChunLiVsKen")()

    # Suggest hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 5e-5, 1e-3, log=True)
    n_steps = trial.suggest_int("n_steps", 256, 8192, log=True)
    batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
    gamma = trial.suggest_float("gamma", 0.9, 0.9999)
    gae_lambda = trial.suggest_float("gae_lambda", 0.9, 1.0)
    clip_range = trial.suggest_float("clip_range", 0.1, 0.4)
    ent_coef = trial.suggest_float("ent_coef", 1e-4, 1e-2, log=True)
    vf_coef = trial.suggest_float("vf_coef", 0.1, 1.0)

    # Use CustomCNN as the feature extractor
    policy_kwargs = {
        'features_extractor_class': CustomCNN
    }

    # Train the model in 10 chunks of 100k timesteps, evaluating after each chunk
    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=n_steps,
        batch_size=batch_size,
        learning_rate=learning_rate,
        ent_coef=ent_coef,
        clip_range=clip_range,
        vf_coef=vf_coef,
        gamma=gamma,
        gae_lambda=gae_lambda
    )

    for iteration in range(10):
        model.learn(total_timesteps=100000)
        mean_reward, _std_reward = evaluate_policy(model, env, n_eval_episodes=10)
        # Report intermediate results so Optuna's pruner can stop weak trials early
        trial.report(mean_reward, iteration)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return mean_reward


if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    # Run optimization for 100 trials or 2 hours, whichever comes first
    study.optimize(objective, n_trials=100, timeout=7200)

    print("Best trial:")
    trial = study.best_trial
    print("  Value: ", trial.value)
    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")