diff --git a/.gitignore b/.gitignore
index 777e661..fc998d5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,7 @@
 archives/
 images/
 data/
-main/logs/monitoring/
\ No newline at end of file
+main/logs/monitoring/
+recordings/
+
+007*
\ No newline at end of file
diff --git a/__pycache__/custom_cnn.cpython-38.pyc b/__pycache__/custom_cnn.cpython-38.pyc
deleted file mode 100644
index cf6425e..0000000
Binary files a/__pycache__/custom_cnn.cpython-38.pyc and /dev/null differ
diff --git a/__pycache__/custom_sf2_cv_env.cpython-38.pyc b/__pycache__/custom_sf2_cv_env.cpython-38.pyc
deleted file mode 100644
index 8085902..0000000
Binary files a/__pycache__/custom_sf2_cv_env.cpython-38.pyc and /dev/null differ
diff --git a/__pycache__/mobilenet_extractor.cpython-38.pyc b/__pycache__/mobilenet_extractor.cpython-38.pyc
deleted file mode 100644
index da42d5b..0000000
Binary files a/__pycache__/mobilenet_extractor.cpython-38.pyc and /dev/null differ
diff --git a/main/test.py b/main/test.py
index 09e5c5f..53a79fc 100644
--- a/main/test.py
+++ b/main/test.py
@@ -1,3 +1,4 @@
+import os
 import time
 
 import retro
@@ -7,9 +8,13 @@ from street_fighter_custom_wrapper import StreetFighterCustomWrapper
 RESET_ROUND = False # Reset the round when fight is over.
 RENDERING = True
+RECORDING = True
 
 RANDOM_ACTION = False
 
-MODEL_PATH = r"trained_models/ppo_ryu_7000000_steps"
+MODEL_DIR = r"trained_models/"
+MOVIE_DIR = r"recordings"
+
+MODEL_NAME = r"ppo_ryu_7000000_steps"
 
 def make_env(game, state):
     def _init():
@@ -28,8 +33,7 @@ env = make_env(game, state="Champion.Level12.RyuVsBison")()
 # model = PPO("CnnPolicy", env)
 
 if not RANDOM_ACTION:
-    # model.load(MODEL_PATH)
-    model = PPO.load(MODEL_PATH, env=env)
+    model = PPO.load(os.path.join(MODEL_DIR, MODEL_NAME), env=env)
 
 # obs = env.reset()
 done = False
@@ -40,6 +44,13 @@ num_victory = 0
 for _ in range(num_episodes):
     done = False
     obs = env.reset()
+
+    if RECORDING:
+        # Start recording
+        movie_path = os.path.join(MOVIE_DIR, "{}.bk2".format(MODEL_NAME))
+        env.unwrapped.movie = retro.Movie(movie_path, retro.MovieMode.RECORD)
+        env.unwrapped.movie.step()
+
     total_reward = 0
 
     while not done:
@@ -50,11 +61,20 @@ for _ in range(num_episodes):
         else:
             action, _states = model.predict(obs)
             obs, reward, done, info = env.step(action)
+
+            if RECORDING:
+                # Record the step
+                env.unwrapped.movie.step()
 
             if reward != 0:
                 total_reward += reward
                 print("Reward: {:.3f}, playerHP: {}, enemyHP:{}".format(reward, info['agent_hp'], info['enemy_hp']))
-    
+
+    if RECORDING:
+        # Stop recording
+        env.unwrapped.movie.close()
+        del env.unwrapped.movie
+
     if info['enemy_hp'] < 0:
         print("Victory!")
         num_victory += 1
@@ -66,4 +86,4 @@ print("Winning rate: {}".format(1.0 * num_victory / num_episodes))
 if RANDOM_ACTION:
     print("Average reward for random action: {}".format(episode_reward_sum/num_episodes))
 else:
-    print("Average reward for {}: {}".format(MODEL_PATH, episode_reward_sum/num_episodes))
\ No newline at end of file
+    print("Average reward for {}: {}".format(MODEL_NAME, episode_reward_sum/num_episodes))
\ No newline at end of file
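
The test.py changes above save each evaluation episode as a .bk2 movie under recordings/. A minimal playback sketch for such a file, adapted from the gym-retro documentation's movie-replay example (retro.Movie, movie.get_game(), movie.get_state(), movie.get_key(), and env.num_buttons are standard gym-retro API; the .bk2 path is the one this diff writes; this is an illustration, not code from this repository):

    import retro

    # Load the recorded movie and advance past its initial (reset) frame.
    movie = retro.Movie('recordings/ppo_ryu_7000000_steps.bk2')
    movie.step()

    env = retro.make(
        game=movie.get_game(),
        state=None,
        # .bk2 files can contain any button presses, so allow all actions.
        use_restricted_actions=retro.Actions.ALL,
        players=movie.players,
    )
    env.initial_state = movie.get_state()
    env.reset()

    # Step the emulator with the exact key presses stored in the movie.
    while movie.step():
        keys = []
        for p in range(movie.players):
            for i in range(env.num_buttons):
                keys.append(movie.get_key(i, p))
        env.step(keys)
        env.render()

Note that playback drives a bare retro environment; the StreetFighterCustomWrapper used during testing is not needed to replay the recorded inputs.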