From 45f8e4649a6dc2671110f4d0ee09d82955d3702b Mon Sep 17 00:00:00 2001 From: linyiLYi <48440925+linyiLYi@users.noreply.github.com> Date: Thu, 6 Apr 2023 09:43:00 +0800 Subject: [PATCH] requirements.txt --- main/requirements.txt | 5 ++++ main/tune_ppo.py | 68 ------------------------------------------- 2 files changed, 5 insertions(+), 68 deletions(-) create mode 100644 main/requirements.txt delete mode 100644 main/tune_ppo.py diff --git a/main/requirements.txt b/main/requirements.txt new file mode 100644 index 0000000..d2d0ea6 --- /dev/null +++ b/main/requirements.txt @@ -0,0 +1,5 @@ +gym==0.18.3 +gym-retro==0.8.0 +opencv-python==4.7.0.72 +stable-baselines3==1.1.0 +tensorboard==2.12.1 diff --git a/main/tune_ppo.py b/main/tune_ppo.py deleted file mode 100644 index f2d3184..0000000 --- a/main/tune_ppo.py +++ /dev/null @@ -1,68 +0,0 @@ -import os - -import retro -import optuna -from stable_baselines3 import PPO -from stable_baselines3.common.monitor import Monitor -from stable_baselines3.common.evaluation import evaluate_policy - -from street_fighter_custom_wrapper import StreetFighterCustomWrapper - -LOG_DIR = 'logs/' -OPT_DIR = 'optuna/' -os.makedirs(LOG_DIR, exist_ok=True) -os.makedirs(OPT_DIR, exist_ok=True) - -def optimize_ppo(trial): - return { - 'n_steps':trial.suggest_int('n_steps', 512, 2048, log=True), - 'gamma':trial.suggest_float('gamma', 0.9, 0.9999), - 'learning_rate':trial.suggest_float('learning_rate', 5e-5, 5e-4, log=True), - 'gae_lambda':trial.suggest_float('gae_lambda', 0.8, 0.9999) - } - -def make_env(game, state): - def _init(): - env = retro.make( - game=game, - state=state, - use_restricted_actions=retro.Actions.FILTERED, - obs_type=retro.Observations.IMAGE - ) - env = StreetFighterCustomWrapper(env) - return env - return _init - -def optimize_agent(trial): - game = "StreetFighterIISpecialChampionEdition-Genesis" - state = "Champion.Level1.ChunLiVsGuile"#"ChampionX.Level1.ChunLiVsKen" - - try: - model_params = optimize_ppo(trial) - - # Create environment - env = make_env(game, state)() - env = Monitor(env, LOG_DIR) - - # Create algo - model = PPO('CnnPolicy', env, verbose=1, **model_params) - model.learn(total_timesteps=500000) - - # Evaluate model - mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=30, deterministic=False) - env.close() - - SAVE_PATH = os.path.join(OPT_DIR, 'trial_{}_best_model'.format(trial.number)) - model.save(SAVE_PATH) - - return mean_reward - - except Exception as e: - return -1 - -# Creating the experiment -study = optuna.create_study(direction='maximize') -study.optimize(optimize_agent, n_trials=10, n_jobs=1) - -print(study.best_params) -print(study.best_trial)