mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-03 22:50:43 +00:00
128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
# Copyright 2023 LIN Yi. All Rights Reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
import os
|
|
import sys
|
|
|
|
import retro
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.monitor import Monitor
|
|
from stable_baselines3.common.callbacks import CheckpointCallback
|
|
from stable_baselines3.common.vec_env import SubprocVecEnv
|
|
|
|
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
|
|
|
# Output directory for logs; created eagerly when this module is imported.
LOG_DIR = "logs"
os.makedirs(name=LOG_DIR, exist_ok=True)

# Number of parallel emulator environments created in main().
NUM_ENV = 16
|
|
|
|
# Linear scheduler
|
|
def linear_schedule(initial_value, final_value=0.0):
|
|
|
|
if isinstance(initial_value, str):
|
|
initial_value = float(initial_value)
|
|
final_value = float(final_value)
|
|
assert (initial_value > 0.0)
|
|
|
|
def scheduler(progress):
|
|
return final_value + progress * (initial_value - final_value)
|
|
|
|
return scheduler
|
|
|
|
def make_env(game, state, seed=0):
    """Return a zero-argument factory that builds one wrapped retro environment.

    The factory is passed (one per worker) to SubprocVecEnv in main(). Each
    worker calls it to construct its own emulator instance.

    Args:
        game: Game id forwarded to ``retro.make``.
        state: Save-state name forwarded to ``retro.make``.
        seed: Seed applied to the fully wrapped environment.

    Returns:
        A callable that creates the retro env, wraps it in
        StreetFighterCustomWrapper and then Monitor, seeds it, and returns it.
    """

    def _build_env():
        raw_env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        # Project wrapper on the inside, Monitor as the outermost wrapper.
        monitored_env = Monitor(StreetFighterCustomWrapper(raw_env))
        monitored_env.seed(seed)
        return monitored_env

    return _build_env
|
|
|
|
def main():
    """Train a PPO agent on Street Fighter II and save the resulting model.

    Creates NUM_ENV parallel emulator environments and trains a CnnPolicy PPO
    model with linearly decaying learning rate and clip range. It writes
    periodic checkpoints to ``trained_models/`` and redirects stdout
    (including the ``verbose=1`` training output) to
    ``trained_models/training_log.txt``. Finally it saves the model to
    ``trained_models/ppo_sf2_ryu_final.zip``.

    The vectorized environment is closed even if training fails. stdout is
    always restored, even if an exception is raised.
    """
    from contextlib import redirect_stdout

    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state = "Champion.Level12.RyuVsBison"

    save_dir = "trained_models"
    os.makedirs(save_dir, exist_ok=True)

    # One environment factory per worker process, each with a distinct seed.
    env = SubprocVecEnv([make_env(game, state=state, seed=i) for i in range(NUM_ENV)])
    try:
        # Learning-rate schedule.
        # Training from scratch:
        lr_schedule = linear_schedule(2.5e-4, 2.5e-6)
        # Fine-tuning a loaded model:
        # lr_schedule = linear_schedule(5.0e-5, 2.5e-6)

        # PPO clip-range schedule.
        # Training from scratch:
        clip_range_schedule = linear_schedule(0.15, 0.025)
        # Fine-tuning a loaded model:
        # clip_range_schedule = linear_schedule(0.075, 0.025)

        model = PPO(
            "CnnPolicy",
            env,
            device="cuda",
            verbose=1,
            n_steps=512,
            batch_size=512,
            n_epochs=4,
            gamma=0.94,
            learning_rate=lr_schedule,
            clip_range=clip_range_schedule,
            tensorboard_log=LOG_DIR,
        )

        # To resume from a checkpoint instead, load it and override the
        # learning-rate schedule, clip-range schedule and n_steps:
        # model_path = "trained_models/ppo_ryu_7000000_steps.zip"
        # custom_objects = {
        #     "learning_rate": lr_schedule,
        #     "clip_range": clip_range_schedule,
        #     "n_steps": 512,
        # }
        # model = PPO.load(model_path, env=env, device="cuda", custom_objects=custom_objects)

        # Note: 1 timestep = 6 emulator frames.
        # save_freq counts vectorized-env steps, so a checkpoint is written
        # every checkpoint_interval * NUM_ENV total timesteps.
        checkpoint_interval = 31250
        checkpoint_callback = CheckpointCallback(
            save_freq=checkpoint_interval,
            save_path=save_dir,
            name_prefix="ppo_ryu",
        )

        # Send the training logs from stdout to a file. redirect_stdout
        # restores sys.stdout even if learn() raises.
        log_file_path = os.path.join(save_dir, "training_log.txt")
        with open(log_file_path, "w", encoding="utf-8") as log_file, redirect_stdout(log_file):
            model.learn(
                total_timesteps=100000000,
                callback=[checkpoint_callback],
            )
    finally:
        # Shut down the worker processes whether or not training succeeded.
        env.close()

    # Save the final model.
    model.save(os.path.join(save_dir, "ppo_sf2_ryu_final.zip"))


if __name__ == "__main__":
    main()
|