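# Train a PPO agent (stable-baselines3) to play Chun-Li in
# "Street Fighter II: Special Champion Edition" (Sega Genesis) through gym-retro.
# The raw RetroEnv is wrapped by StreetFighterCustomWrapper, the policy uses a
# custom feature extractor, and training periodically switches the opponent stage.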
import os
import random

import gym
import cv2
import retro
import numpy as np

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback

from custom_cnn import CustomCNN
from mobilenet_extractor import MobileNetV3Extractor
from custom_sf2_cv_env import StreetFighterCustomWrapper


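# Callback that swaps the loaded save state (and therefore the opponent) across all
# vectorized sub-environments every `opponent_interval` callback calls, picking a
# random entry from `stages`. The `save_dir` argument is accepted but not used here.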
class RandomOpponentChangeCallback(BaseCallback):
    def __init__(self, stages, opponent_interval, save_dir, verbose=0):
        super(RandomOpponentChangeCallback, self).__init__(verbose)
        self.stages = stages
        self.opponent_interval = opponent_interval

    def _on_step(self) -> bool:
        if self.n_calls % self.opponent_interval == 0:
            new_state = random.choice(self.stages)
            print("\nCurrent state:", new_state)
            self.training_env.env_method("load_state", new_state, indices=None)
        return True


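# Factory returning a zero-argument initializer so SubprocVecEnv can build each
# worker's environment in its own process. The grayscale win/lose templates are
# loaded here and handed to StreetFighterCustomWrapper (presumably for
# template-matching the round-result screens).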
def make_env(game, state, seed=0):
    def _init():
        win_template = cv2.imread('images/pattern_win_gray.png', cv2.IMREAD_GRAYSCALE)
        lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE)
        env = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env, win_template, lose_template)
        # env.seed(seed)
        return env
    return _init


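# Entry point: builds the vectorized environments, configures PPO, and trains with
# periodic checkpointing and random opponent changes before saving the final model.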
def main():
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
        "ChampionX.Level1.ChunLiVsKen",
        "ChampionX.Level2.ChunLiVsChunLi",
        "ChampionX.Level3.ChunLiVsZangief",
        "ChampionX.Level4.ChunLiVsDhalsim",
        "ChampionX.Level5.ChunLiVsRyu",
        "ChampionX.Level6.ChunLiVsEHonda",
        "ChampionX.Level7.ChunLiVsBlanka",
        "ChampionX.Level8.ChunLiVsGuile",
        "ChampionX.Level9.ChunLiVsBalrog",
        "ChampionX.Level10.ChunLiVsVega",
        "ChampionX.Level11.ChunLiVsSagat",
        "ChampionX.Level12.ChunLiVsBison"
        # Add other stages as necessary
    ]
    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.

    num_envs = 8

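    # Run num_envs copies of the game in parallel worker processes; every copy starts
    # on the first stage, and RandomOpponentChangeCallback re-loads random stages later.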
    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])

    # Using CustomCNN as the feature extractor
    # policy_kwargs = {
    #     'features_extractor_class': CustomCNN
    # }

    # Using MobileNetV3 as the feature extractor
    policy_kwargs = {
        'features_extractor_class': MobileNetV3Extractor
    }

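    # PPO rollout settings: n_steps is the number of steps collected *per environment*
    # before each update, so one update consumes n_steps * num_envs = 43200 transitions,
    # split into minibatches of batch_size=64 for n_epochs=10 passes.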
    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=5400,
        batch_size=64,
        n_epochs=10,
        learning_rate=0.0003,
        ent_coef=0.01,
        clip_range=0.2,
        clip_range_vf=None,
        gamma=0.99,
        gae_lambda=0.95,
        max_grad_norm=0.5,
        use_sde=False,
        sde_sample_freq=-1
    )

    # Set the save directory
    save_dir = "trained_models_cv_customcnn_time_penalty"
    os.makedirs(save_dir, exist_ok=True)

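    # Note: callback frequencies below count callback calls (one per vectorized step),
    # so each interval corresponds to interval * num_envs individual environment frames.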
    # Set up callbacks
    opponent_interval = 5400 # stage_interval * num_envs = total_steps_per_stage
    checkpoint_interval = 54000 # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
    stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)

    model.learn(
        total_timesteps=int(6048000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
        callback=[checkpoint_callback, stage_increase_callback]
    )

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))


if __name__ == "__main__":
    main()