Mirror of https://github.com/linyiLYi/street-fighter-ai.git (synced 2025-04-03 22:50:43 +00:00)
learn from level 1
This commit is contained in: parent 02e39f0a52, commit 16c80d5fba
@@ -1,7 +1,6 @@
 import time

 import retro
 from stable_baselines3 import PPO
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
29 binary files changed (contents not shown)
File diff suppressed because it is too large
@@ -19,7 +19,7 @@ def make_env(game, state):

 game = "StreetFighterIISpecialChampionEdition-Genesis"
 state_stages = [
-    "Champion.Level1.ChunLiVsGuile",   # Average reward for random strategy: -102.3
+    "Champion.Level1.ChunLiVsGuile",   # Average reward for random strategy: -102.3 | -20.4
     "ChampionX.Level1.ChunLiVsKen",    # Average reward for random strategy: -247.6
     "Champion.Level2.ChunLiVsKen",
     "Champion.Level3.ChunLiVsChunLi",

@@ -42,8 +42,12 @@ model = PPO(
     env,
     verbose=1
 )
-model_path = r"optuna/trial_1_best_model"   # Average reward for optuna/trial_1_best_model: -82.3
+model_path = r"trained_models_level_1/ppo_chunli_1075200_steps"
 model.load(model_path)
+# Average reward for optuna/trial_1_best_model: -82.3
+# Average reward for optuna/trial_9_best_model: 36.7 | -86.23
+# Average reward for trained_models/ppo_chunli_5376000_steps: -77.8
+

 obs = env.reset()
 done = False
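The "Average reward for random strategy" figures quoted in these comments are baseline scores for an agent that presses random buttons. A minimal sketch of how such a baseline can be measured; the episode count of 30 is an assumption, and this uses the bare retro env rather than the project's reward-shaping wrapper, so the exact numbers would differ:

    import retro

    env = retro.make(
        game="StreetFighterIISpecialChampionEdition-Genesis",
        state="Champion.Level1.ChunLiVsGuile",
        obs_type=retro.Observations.IMAGE,
    )
    episode_rewards = []
    for _ in range(30):  # number of episodes is an assumption
        env.reset()
        done, total = False, 0.0
        while not done:
            # Sample a random button combination every frame
            _, reward, done, _ = env.step(env.action_space.sample())
            total += reward
        episode_rewards.append(total)
    print(sum(episode_rewards) / len(episode_rewards))  # baseline average reward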
@@ -3,13 +3,15 @@ import random

 import retro
 from stable_baselines3 import PPO
 from stable_baselines3.common.vec_env import SubprocVecEnv
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback

+from rmsprop_optim import RMSpropTF
+from custom_cnn import CustomCNN
 from street_fighter_custom_wrapper import StreetFighterCustomWrapper

 LOG_DIR = 'logs/'

 class RandomOpponentChangeCallback(BaseCallback):
     def __init__(self, stages, opponent_interval, verbose=0):
         super(RandomOpponentChangeCallback, self).__init__(verbose)
@@ -23,7 +25,23 @@ class RandomOpponentChangeCallback(BaseCallback):
         self.training_env.env_method("load_state", new_state, indices=None)
         return True

-def make_env(game, state, seed=0):
+# class StageIncreaseCallback(BaseCallback):
+#     def __init__(self, stages, stage_interval, save_dir, verbose=0):
+#         super(StageIncreaseCallback, self).__init__(verbose)
+#         self.stages = stages
+#         self.stage_interval = stage_interval
+#         self.save_dir = save_dir
+#         self.current_stage = 0
+
+#     def _on_step(self) -> bool:
+#         if self.n_calls % self.stage_interval == 0 and self.current_stage < len(self.stages) - 1:
+#             self.current_stage += 1
+#             new_state = self.stages[self.current_stage]
+#             self.training_env.env_method("load_state", new_state, indices=None)
+#             self.model.save(os.path.join(self.save_dir, f"ppo_chunli_stage_{self.current_stage}.zip"))
+#         return True
+
+def make_env(game, state):
     def _init():
         env = retro.make(
             game=game,
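Only the tail of RandomOpponentChangeCallback._on_step appears in this hunk. From the constructor and the env_method call above, the full callback plausibly looks like the sketch below (a reconstruction under those assumptions, not the committed code): every opponent_interval steps it picks a random stage and hot-swaps it into the running environment.

    import random
    from stable_baselines3.common.callbacks import BaseCallback

    class RandomOpponentChangeCallback(BaseCallback):
        def __init__(self, stages, opponent_interval, verbose=0):
            super(RandomOpponentChangeCallback, self).__init__(verbose)
            self.stages = stages
            self.opponent_interval = opponent_interval

        def _on_step(self) -> bool:
            # Reconstruction: swap in a randomly chosen opponent stage on the interval
            if self.n_calls % self.opponent_interval == 0:
                new_state = random.choice(self.stages)
                self.training_env.env_method("load_state", new_state, indices=None)
            return True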
@@ -32,57 +50,66 @@ def make_env(game, state, seed=0):
             obs_type=retro.Observations.IMAGE
         )
         env = StreetFighterCustomWrapper(env)
-        env.seed(seed)
         return env
     return _init

 def main():
     # Set up the environment and model
     game = "StreetFighterIISpecialChampionEdition-Genesis"

     state_stages = [
-        "ChampionX.Level1.ChunLiVsKen",
-        "ChampionX.Level2.ChunLiVsChunLi",
-        "ChampionX.Level3.ChunLiVsZangief",
-        "ChampionX.Level4.ChunLiVsDhalsim",
-        "ChampionX.Level5.ChunLiVsRyu",
-        "ChampionX.Level6.ChunLiVsEHonda",
-        "ChampionX.Level7.ChunLiVsBlanka",
-        "ChampionX.Level8.ChunLiVsGuile",
-        "ChampionX.Level9.ChunLiVsBalrog",
-        "ChampionX.Level10.ChunLiVsVega",
-        "ChampionX.Level11.ChunLiVsSagat",
-        "ChampionX.Level12.ChunLiVsBison"
+        "Champion.Level1.ChunLiVsGuile",  # Average reward for random strategy: -102.3
+        "Champion.Level2.ChunLiVsKen",
+        "Champion.Level3.ChunLiVsChunLi",
+        "Champion.Level4.ChunLiVsZangief",
+        "Champion.Level5.ChunLiVsDhalsim",
+        "Champion.Level6.ChunLiVsRyu",
+        "Champion.Level7.ChunLiVsEHonda",
+        "Champion.Level8.ChunLiVsBlanka",
+        "Champion.Level9.ChunLiVsBalrog",
+        "Champion.Level10.ChunLiVsVega",
+        "Champion.Level11.ChunLiVsSagat",
+        "Champion.Level12.ChunLiVsBison"
         # Add other stages as necessary
     ]

+    # state_stages = [
+    #     "ChampionX.Level1.ChunLiVsKen",  # Average reward for random strategy: -247.6
+    #     "ChampionX.Level2.ChunLiVsChunLi",
+    #     "ChampionX.Level3.ChunLiVsZangief",
+    #     "ChampionX.Level4.ChunLiVsDhalsim",
+    #     "ChampionX.Level5.ChunLiVsRyu",
+    #     "ChampionX.Level6.ChunLiVsEHonda",
+    #     "ChampionX.Level7.ChunLiVsBlanka",
+    #     "ChampionX.Level8.ChunLiVsGuile",
+    #     "ChampionX.Level9.ChunLiVsBalrog",
+    #     "ChampionX.Level10.ChunLiVsVega",
+    #     "ChampionX.Level11.ChunLiVsSagat",
+    #     "ChampionX.Level12.ChunLiVsBison"
+    #     # Add other stages as necessary
+    # ]
+    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.

     num_envs = 8

-    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])
-
+    # Using CustomCNN as the feature extractor
+    policy_kwargs = {
+        'features_extractor_class': CustomCNN
+    }
+    env = make_env(game, state_stages[0])()
+    env = Monitor(env, LOG_DIR)

     model = PPO(
         "CnnPolicy",
         env,
         device="cuda",
+        policy_kwargs=policy_kwargs,
         verbose=1,
-        n_steps=5400,
+        n_steps=35840,  # 64 * 560
         batch_size=64,
-        learning_rate=0.0001,
+        learning_rate=6e-5,
         ent_coef=0.01,
-        clip_range=0.2,
-        gamma=0.99,
-        gae_lambda=0.95,
+        clip_range=0.15487,
+        gamma=0.9483,
+        gae_lambda=0.81322,
         tensorboard_log="logs/"
     )

     # Set the save directory
-    save_dir = "trained_models"
+    save_dir = "trained_models_level_1"
     os.makedirs(save_dir, exist_ok=True)

     # Load the model from file
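For scale: with n_steps=35840 and batch_size=64, each PPO update collects one 35,840-step rollout and splits it into 560 minibatches per optimization epoch, which is what the "64 * 560" comment records. A quick arithmetic check:

    n_steps = 35840
    batch_size = 64
    assert n_steps == batch_size * 560  # 560 minibatches per epoch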
@@ -95,10 +122,10 @@ def main():
     # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)

     # Set up callbacks
-    opponent_interval = 5400  # stage_interval * num_envs = total_steps_per_stage
-    checkpoint_interval = 54000  # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
+    # opponent_interval = 35840  # stage_interval * num_envs = total_steps_per_stage
+    checkpoint_interval = 358400  # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
     checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
-    stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)
+    # stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)

     # model_params = {
     #     'n_steps': 5,
@@ -113,8 +140,8 @@ def main():
     # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))

     model.learn(
-        total_timesteps=int(6048000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
-        callback=[checkpoint_callback, stage_increase_callback]
+        total_timesteps=int(5376000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
+        callback=[checkpoint_callback]#, stage_increase_callback]
     )
     env.close()
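The new budget numbers in this commit line up cleanly (the "rounds" figures come from the author's comments; the derived counts below are plain arithmetic):

    n_steps = 35840               # one PPO rollout per update
    checkpoint_interval = 358400  # save every 10 rollouts
    total_timesteps = 5376000

    assert checkpoint_interval == 10 * n_steps
    assert total_timesteps == 15 * checkpoint_interval  # 15 checkpoints saved
    print(total_timesteps // n_steps)                   # 150 PPO updates overall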
8951  000_image_stack_ram_based_reward/trained_models/training_logs.txt  (new file)
File diff suppressed because it is too large
35  004_custom_policy/custom_cnn.py  (new file)
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

def conv2d_custom_init(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    nn.init.xavier_uniform_(conv.weight)
    return conv

def custom_conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    return nn.Sequential(
        conv2d_custom_init(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        nn.ReLU(),
        nn.MaxPool2d((2, 2))
    )

# Custom feature extractor (CNN)
class CustomCNN(nn.Module):
    def __init__(self, num_frames, num_moves, num_attacks):
        super(CustomCNN, self).__init__()
        self.num_moves = num_moves
        self.num_attacks = num_attacks
        self.features_dim = 512  # output width of the extractor; not set in the commit, 512 assumed
        self.cnn = nn.Sequential(
            nn.Conv2d(num_frames, 32, kernel_size=5, stride=2, padding=0),  # input channels = stacked frames
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16384, self.features_dim),  # 16384 matches the wrapper's fixed observation size
            nn.ReLU()
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.cnn(observations)
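As committed, CustomCNN subclasses nn.Module, but stable-baselines3 instantiates features_extractor_class with (observation_space, **features_extractor_kwargs) and expects a BaseFeaturesExtractor exposing features_dim, so the policy_kwargs usage in the training script would not construct this class as written. A minimal SB3-compatible variant would look like the sketch below; features_dim=512 and the 16384 flatten size are assumptions carried over from the committed code:

    import gym
    import torch
    import torch.nn as nn
    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

    class CustomCNN(BaseFeaturesExtractor):
        def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
            super().__init__(observation_space, features_dim)
            n_input_channels = observation_space.shape[0]  # assumes channels-first stacked frames
            self.cnn = nn.Sequential(
                nn.Conv2d(n_input_channels, 32, kernel_size=5, stride=2, padding=0),
                nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
                nn.ReLU(),
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(16384, features_dim),  # 16384 assumes the wrapper's fixed input size
                nn.ReLU(),
            )

        def forward(self, observations: torch.Tensor) -> torch.Tensor:
            return self.cnn(observations)

With this variant, the training script's policy_kwargs could also pass features_extractor_kwargs={'features_dim': 512} explicitly.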