mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-03 22:50:43 +00:00
104 lines
3.3 KiB
Python
104 lines
3.3 KiB
Python
import os
|
|
|
|
import gym
|
|
import torch
|
|
import retro
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
|
|
from stable_baselines3.common.env_util import make_vec_env
|
|
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
|
|
|
|
def import_rom(rom_directory):
|
|
retro.data.Integrations.add_custom_path(rom_directory)
|
|
os.system(f'python -m retro.import "{rom_directory}"')
|
|
|
|
class StageIncreaseCallback(BaseCallback):
|
|
def __init__(self, stages, stage_interval, save_dir, verbose=0):
|
|
super(StageIncreaseCallback, self).__init__(verbose)
|
|
self.stages = stages
|
|
self.stage_interval = stage_interval
|
|
self.save_dir = save_dir
|
|
self.current_stage = 0
|
|
|
|
def _on_step(self) -> bool:
|
|
if self.n_calls % self.stage_interval == 0 and self.current_stage < len(self.stages) - 1:
|
|
self.current_stage += 1
|
|
new_state = self.stages[self.current_stage]
|
|
self.training_env.env_method("load_state", new_state, indices=None)
|
|
self.model.save(os.path.join(self.save_dir, f"ppo_chunli_stage_{self.current_stage}.zip"))
|
|
return True
|
|
|
|
def make_env(game, state, seed=0):
|
|
def _init():
|
|
env = retro.RetroEnv(game=game, state=state)
|
|
env.seed(seed)
|
|
return env
|
|
return _init
|
|
|
|
def main():
|
|
n_envs = 4
|
|
|
|
# Set up the environment and model
|
|
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
|
state_stages = [
|
|
"Champion.Level1.ChunLiVsGuile",
|
|
"Champion.Level2.ChunLiVsKen",
|
|
"Champion.Level3.ChunLiVsChunLi",
|
|
"Champion.Level4.ChunLiVsZangief",
|
|
# Add other stages as necessary
|
|
]
|
|
|
|
rom_directory = "C:/Users/unitec/Documents/AIProjects/street-fighter-ai"
|
|
import_rom(rom_directory)
|
|
|
|
# Create the environment with the correct game ID and scenario
|
|
# env = retro.RetroEnv(game='StreetFighterIISpecialChampionEdition-Genesis', state='Champion.Level1.ChunLiVsGuile')
|
|
env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(n_envs)])
|
|
|
|
# env = DummyVecEnv([lambda: env])
|
|
|
|
# Define PPO model
|
|
model = PPO(
|
|
"CnnPolicy",
|
|
env,
|
|
verbose=1,
|
|
device="cuda",
|
|
learning_rate=2.5e-4,
|
|
n_steps=5400,
|
|
batch_size=96,
|
|
n_epochs=10,
|
|
gamma=0.99,
|
|
gae_lambda=0.95,
|
|
clip_range=0.2,
|
|
ent_coef=0.01,
|
|
vf_coef=0.5,
|
|
max_grad_norm=0.5,
|
|
seed=None,
|
|
)
|
|
|
|
|
|
checkpoint_path = None
|
|
if checkpoint_path is not None:
|
|
model = model.load(checkpoint_path, env)
|
|
|
|
|
|
# Set the save directory
|
|
save_dir = "trained_models"
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
|
|
# Set up callbacks
|
|
stage_interval = 540000 # Number of steps between increasing stages
|
|
checkpoint_callback = CheckpointCallback(save_freq=stage_interval, save_path=save_dir, name_prefix="ppo_chunli")
|
|
stage_increase_callback = StageIncreaseCallback(state_stages, stage_interval, save_dir)
|
|
|
|
# Train the model and save intermediate models
|
|
model.learn(total_timesteps=1620000, callback=[checkpoint_callback, stage_increase_callback])
|
|
|
|
# Save the final model
|
|
model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))
|
|
|
|
env.close()
|
|
|
|
if __name__ == "__main__":
|
|
main()
|