2023-04-05 15:21:32 +00:00
import os
2023-04-05 13:39:08 +00:00
import time
import retro
from stable_baselines3 import PPO
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
2023-04-06 15:35:08 +00:00
# Playback configuration: round handling, rendering, and which trained model to load.
RESET_ROUND = True # Whether to reset the round when the fight is over.
RENDERING = True # Whether to render the game screen.
# NOTE(review): the literal below carries a leading and a trailing space, and the r prefix is separated from the quote — confirm against the real model file name.
MODEL_NAME = r " ppo_ryu_2500000_steps_updated " # Specify the model file to load. Model "ppo_ryu_2500000_steps_updated" is capable of beating the final stage (Bison) of the game.
# Model notes:
# ppo_ryu_2000000_steps_updated: Just beginning to overfit; generalizable but not quite capable.
# ppo_ryu_2500000_steps_updated: Approaching the final overfitted state; cannot dominate the first round but is partially generalizable. High chance of beating the final stage.
# ppo_ryu_3000000_steps_updated: Near the final overfitted state; almost dominates the first round but is barely generalizable.
# ppo_ryu_7000000_steps_updated: Overfitted; dominates the first round but is not generalizable.
2023-04-05 13:39:08 +00:00
# When True, the playback loop steps the environment with env.action_space.sample() and no model is loaded;
# when False, actions come from the PPO model loaded from MODEL_DIR/MODEL_NAME.
RANDOM_ACTION = False
2023-04-06 15:35:08 +00:00
# Number of episodes the playback loop runs; also the divisor for the reported winning rate and average reward.
NUM_EPISODES = 30 # Keep NUM_EPISODES >= 3 when RESET_ROUND is False in order to see the whole final-stage game.
2023-04-05 15:21:32 +00:00
# Directory that is joined with MODEL_NAME (os.path.join) to build the path passed to PPO.load.
# NOTE(review): the literal carries a leading and a trailing space — confirm this is intended.
MODEL_DIR = r " trained_models/ "
2023-04-05 13:39:08 +00:00
def make_env(game, state):
    """Return a zero-argument factory that builds the wrapped game environment.

    Nothing is created when ``make_env`` itself is called; the emulator is only
    started once the returned callable is invoked.

    Args:
        game: Game identifier forwarded to ``retro.make``.
        state: Save-state identifier forwarded to ``retro.make``.

    Returns:
        A callable taking no arguments. Each call creates a retro environment
        with the filtered action set and image observations, wraps it in
        ``StreetFighterCustomWrapper`` and returns the wrapper. The module-level
        ``RESET_ROUND`` and ``RENDERING`` flags are read when the callable runs.
    """

    def _init():
        base_env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(
            base_env,
            reset_round=RESET_ROUND,
            rendering=RENDERING,
        )

    return _init
# Playback script: build the environment, optionally load the trained PPO
# model, play NUM_EPISODES episodes and report the winning rate and the
# average reward.
#
# NOTE(review): every string literal below (game/state names, info keys,
# printed messages) keeps the leading/trailing spaces found in the source;
# confirm they are intended before relying on them.
game = " StreetFighterIISpecialChampionEdition-Genesis "
env = make_env(game, state=" Champion.Level12.RyuVsBison ")()

num_episodes = NUM_EPISODES
episode_reward_sum = 0
num_victory = 0

# Everything that touches the environment runs inside try/finally so the
# emulator is closed even if model loading or a step raises.
try:
    # A trained policy is only needed when actions are not sampled at random.
    if not RANDOM_ACTION:
        model = PPO.load(os.path.join(MODEL_DIR, MODEL_NAME), env=env)

    obs = env.reset()

    print(" \n Fighting Begins! \n ")

    for _ in range(num_episodes):
        done = False

        # With RESET_ROUND disabled the game is left running between episodes
        # and only the initial reset above applies.
        if RESET_ROUND:
            obs = env.reset()

        total_reward = 0

        while not done:
            if RANDOM_ACTION:
                action = env.action_space.sample()
            else:
                action, _ = model.predict(obs)
            obs, reward, done, info = env.step(action)

            # Only non-zero rewards are accumulated and logged.
            if reward != 0:
                total_reward += reward
                print(" Reward: {:.3f} , playerHP: {} , enemyHP: {} ".format(reward, info[' agent_hp '], info[' enemy_hp ']))

            # A negative HP on either side ends the episode regardless of the
            # done flag returned by the environment.
            if info[' enemy_hp '] < 0 or info[' agent_hp '] < 0:
                done = True

        if info[' enemy_hp '] < 0:
            print(" Victory! ")
            num_victory += 1

        print(" Total reward: {} \n ".format(total_reward))
        episode_reward_sum += total_reward

        if not RESET_ROUND:
            while info[' enemy_hp '] < 0 or info[' agent_hp '] < 0:
                # Inter scene transition. Do nothing: step with an all-zero
                # 12-entry action until both HP values are non-negative again.
                obs, reward, done, info = env.step([0] * 12)
                env.render()
finally:
    env.close()

# Guard the averages so that NUM_EPISODES == 0 reports 0.0 instead of raising
# ZeroDivisionError after the run has already finished.
episode_divisor = max(num_episodes, 1)

print(" Winning rate: {} ".format(num_victory / episode_divisor))
if RANDOM_ACTION:
    print(" Average reward for random action: {} ".format(episode_reward_sum / episode_divisor))
else:
    print(" Average reward for {} : {} ".format(MODEL_NAME, episode_reward_sum / episode_divisor))