# street-fighter-ai/000_image_stack_ram_based_reward/train.py
import os
import random
import retro
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from rmsprop_optim import RMSpropTF
from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
class RandomOpponentChangeCallback(BaseCallback):
    """Periodically switch all training environments to a random saved state.

    Every ``opponent_interval`` invocations of the callback, one entry of
    ``stages`` is drawn with ``random.choice`` and ``load_state`` is invoked
    with it on every sub-environment via ``training_env.env_method``.

    Args:
        stages: Non-empty sequence of state names to sample from.
        opponent_interval: Positive number of callback calls between two
            state switches.
        verbose: Verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: If ``stages`` is empty or ``opponent_interval`` is not
            a positive number.
    """

    def __init__(self, stages, opponent_interval, verbose=0):
        super().__init__(verbose)
        # Fail at construction time instead of deep inside training, where
        # the same mistakes would surface as IndexError / ZeroDivisionError.
        if len(stages) == 0:
            raise ValueError("stages must contain at least one state")
        if opponent_interval <= 0:
            raise ValueError("opponent_interval must be a positive integer")
        self.stages = stages
        self.opponent_interval = opponent_interval

    def _on_step(self) -> bool:
        """Swap the loaded state when the interval elapses.

        Returns:
            Always ``True``, so this callback never asks training to stop.
        """
        # ``n_calls`` is the call counter maintained by ``BaseCallback``.
        if self.n_calls % self.opponent_interval == 0:
            new_state = random.choice(self.stages)
            print("\nCurrent state:", new_state)
            # ``indices=None`` addresses every sub-environment at once.
            self.training_env.env_method("load_state", new_state, indices=None)
        return True
def make_env(game, state, seed=0):
    """Build a zero-argument factory that creates one wrapped retro env.

    The environment itself is only constructed when the returned callable is
    invoked, which lets a vectorised environment create each instance inside
    its own worker.

    Args:
        game: Name of the retro game to load.
        state: Name of the saved state the environment starts from.
        seed: Seed passed to ``env.seed`` of the wrapped environment.

    Returns:
        A callable taking no arguments that returns the seeded environment
        wrapped in ``StreetFighterCustomWrapper``.
    """

    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        env = StreetFighterCustomWrapper(env)
        env.seed(seed)
        return env

    return _init
def main():
    """Train a PPO agent on Street Fighter II and save the resulting model.

    Creates ``num_envs`` sub-process environments, builds a ``CnnPolicy`` PPO
    model that uses ``CustomCNN`` as feature extractor, trains it while
    periodically checkpointing and switching the opponent state, and finally
    writes the trained model into ``trained_models``.
    """
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"

    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.
    state_stages = [
        "ChampionX.Level1.ChunLiVsKen",
        "ChampionX.Level2.ChunLiVsChunLi",
        "ChampionX.Level3.ChunLiVsZangief",
        "ChampionX.Level4.ChunLiVsDhalsim",
        "ChampionX.Level5.ChunLiVsRyu",
        "ChampionX.Level6.ChunLiVsEHonda",
        "ChampionX.Level7.ChunLiVsBlanka",
        "ChampionX.Level8.ChunLiVsGuile",
        "ChampionX.Level9.ChunLiVsBalrog",
        "ChampionX.Level10.ChunLiVsVega",
        "ChampionX.Level11.ChunLiVsSagat",
        "ChampionX.Level12.ChunLiVsBison",
        # Add other stages as necessary
    ]

    num_envs = 8
    # Every worker starts on the first stage and gets its own seed.
    env = SubprocVecEnv(
        [make_env(game, state_stages[0], seed=i) for i in range(num_envs)]
    )

    # Using CustomCNN as the feature extractor
    policy_kwargs = {
        "features_extractor_class": CustomCNN,
    }

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=5400,
        batch_size=64,
        learning_rate=0.0001,
        ent_coef=0.01,
        clip_range=0.2,
        gamma=0.99,
        gae_lambda=0.95,
        tensorboard_log="logs/",
    )

    # Set the save directory
    save_dir = "trained_models"
    os.makedirs(save_dir, exist_ok=True)

    # Optional: resume from a checkpoint instead of training from scratch,
    # optionally overriding hyper-parameters through ``custom_objects``.
    # model_path = "trained_models/ppo_chunli_1296000_steps.zip"
    # custom_objects = {
    #     "learning_rate": 0.0002
    # }
    # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)

    # Set up callbacks.
    # Both intervals below are passed as "number of callback calls"; the
    # original notes multiply them by num_envs to get environment timesteps.
    opponent_interval = 5400  # opponent_interval * num_envs = total_steps_per_stage
    checkpoint_interval = 54000  # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
    checkpoint_callback = CheckpointCallback(
        save_freq=checkpoint_interval,
        save_path=save_dir,
        name_prefix="ppo_chunli",
    )
    # NOTE: ``save_dir`` used to be passed here as a third positional
    # argument, where it was silently bound to ``verbose``.
    opponent_change_callback = RandomOpponentChangeCallback(
        state_stages, opponent_interval
    )

    # Optional alternative: A2C with the TF-style RMSprop optimiser.
    # model_params = {
    #     'n_steps': 5,
    #     'gamma': 0.99,
    #     'gae_lambda':1,
    #     'learning_rate': 7e-4,
    #     'vf_coef': 0.5,
    #     'ent_coef': 0.0,
    #     'max_grad_norm':0.5,
    #     'rms_prop_eps':1e-05
    # }
    # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))

    try:
        model.learn(
            total_timesteps=int(6048000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
            callback=[checkpoint_callback, opponent_change_callback],
        )
    finally:
        # Always shut the worker processes down, even if training fails.
        env.close()

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()