mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-04 23:20:43 +00:00
Merge e00678ebee
into 78cb661e6c
This commit is contained in:
commit
ca41251a68
159
main/general/street_fighter_custom_wrapper_general.py
Normal file
159
main/general/street_fighter_custom_wrapper_general.py
Normal file
@ -0,0 +1,159 @@
|
||||
# Copyright 2024 WANG Jing. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import math
|
||||
import time
|
||||
import collections
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
# Custom environment wrapper
|
||||
class StreetFighterCustomWrapper(gym.Wrapper):
|
||||
|
||||
def __init__(self, env, reset_round=False, rendering=False):
|
||||
super(StreetFighterCustomWrapper, self).__init__(env)
|
||||
self.env = env
|
||||
|
||||
self.win = 0
|
||||
self.loss = 0
|
||||
self.round = 0
|
||||
self.round_end = False
|
||||
self.jump = False
|
||||
|
||||
# Use a deque to store the last 9 frames
|
||||
self.num_frames = 9
|
||||
self.frame_stack = collections.deque(maxlen=self.num_frames)
|
||||
|
||||
self.num_step_frames = 6
|
||||
|
||||
self.reward_coeff = 3.0
|
||||
|
||||
self.total_timesteps = 0
|
||||
self.match1_reward = 0
|
||||
|
||||
self.full_hp = 176
|
||||
self.prev_player_health = self.full_hp
|
||||
self.prev_oppont_health = self.full_hp
|
||||
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(100, 128, 3), dtype=np.uint8)
|
||||
|
||||
self.reset_round = reset_round
|
||||
self.rendering = rendering
|
||||
|
||||
def _stack_observation(self):
|
||||
return np.stack([self.frame_stack[i * 3 + 2][:, :, i] for i in range(3)], axis=-1)
|
||||
|
||||
def reset(self):
|
||||
observation = self.env.reset()
|
||||
self.win = 0
|
||||
self.loss = 0
|
||||
|
||||
self.prev_player_health = self.full_hp
|
||||
self.prev_oppont_health = self.full_hp
|
||||
|
||||
self.total_timesteps = 0
|
||||
|
||||
# Clear the frame stack and add the first observation [num_frames] times
|
||||
self.frame_stack.clear()
|
||||
for _ in range(self.num_frames):
|
||||
self.frame_stack.append(observation[::2, ::2, :])
|
||||
|
||||
return np.stack([self.frame_stack[i * 3 + 2][:, :, i] for i in range(3)], axis=-1)
|
||||
|
||||
|
||||
def step(self, action):
|
||||
custom_done = False
|
||||
|
||||
obs, _reward, _done, info = self.env.step(action)
|
||||
if not self.jump:
|
||||
self.frame_stack.append(obs[::2, ::2, :])
|
||||
# Render the game if rendering flag is set to True.
|
||||
if self.rendering:
|
||||
self.env.render()
|
||||
#time.sleep(0.003)
|
||||
for _ in range(self.num_step_frames - 1):
|
||||
# Keep the button pressed for (num_step_frames - 1) frames.
|
||||
obs, _reward, _done, info= self.env.step(action)
|
||||
if not self.jump:
|
||||
self.frame_stack.append(obs[::2, ::2, :])
|
||||
if self.rendering:
|
||||
self.env.render()
|
||||
#time.sleep(0.003)
|
||||
|
||||
curr_player_health = info['agent_hp']
|
||||
curr_oppont_health = info['enemy_hp']
|
||||
|
||||
if not custom_done:
|
||||
if self.round_end:
|
||||
while not(curr_player_health == self.full_hp and curr_oppont_health == self.full_hp):
|
||||
self.jump = True
|
||||
obs, _reward, _done, info = self.env.step([0] * 12)
|
||||
curr_player_health = info['agent_hp']
|
||||
curr_oppont_health = info['enemy_hp']
|
||||
if self.rendering:
|
||||
self.env.render()
|
||||
#time.sleep(0.01)
|
||||
self.round_end = False
|
||||
self.jump = False
|
||||
if self.jump :
|
||||
custom_reward = np.nan
|
||||
self.prev_player_health = self.full_hp
|
||||
self.prev_oppont_health = self.full_hp
|
||||
custom_done = False
|
||||
else :
|
||||
self.total_timesteps += self.num_step_frames
|
||||
reduce_enermy_health = self.prev_oppont_health - curr_oppont_health
|
||||
reduce_player_health = self.prev_player_health - curr_player_health
|
||||
|
||||
# Determine game status and calculate rewards.
|
||||
if reduce_player_health > 0 and curr_player_health < 0:
|
||||
custom_reward = -math.pow(self.full_hp, (curr_oppont_health + 1) / (self.full_hp + 1))
|
||||
self.loss += 1
|
||||
self.round +=1
|
||||
self.round_end = True
|
||||
#print('loss')
|
||||
if self.loss == 2:
|
||||
self.loss = 0
|
||||
self.round = 0
|
||||
custom_done = True
|
||||
|
||||
elif reduce_enermy_health > 0 and curr_oppont_health <0 :
|
||||
custom_reward = math.pow(self.full_hp, (curr_player_health + 1) / (self.full_hp + 1)) * self.reward_coeff
|
||||
self.win += 1
|
||||
self.round +=1
|
||||
self.round_end = True
|
||||
#print('win')
|
||||
if self.win == 2:
|
||||
self.win = 0
|
||||
self.round = 0
|
||||
custom_done = True
|
||||
else:
|
||||
if reduce_enermy_health >= 0 and reduce_player_health >= 0:
|
||||
if reduce_enermy_health > 0 or reduce_player_health > 0:
|
||||
custom_reward = self.reward_coeff * (reduce_enermy_health - reduce_player_health)
|
||||
else:
|
||||
custom_reward = 0
|
||||
else:
|
||||
custom_reward = 0
|
||||
|
||||
# Update health states.
|
||||
self.prev_player_health = curr_player_health
|
||||
self.prev_oppont_health = curr_oppont_health
|
||||
|
||||
|
||||
# Max reward is 2 * 6 * full_hp = 2108 2 * (damage * 3 + winning_reward * 3) norm_coefficient = 0.001 for win first two
|
||||
# Or in three round 2 * 6 * full_hp - full_hp - 1 + 3 * full_hp = 2464 (damage * 3 + winning_reward * 3)
|
||||
#print("reward{}, reduce_player_health:{}, reduce_enermy_health:{}, agentHP:{}, enemyHP:{}".format(custom_reward,reduce_player_health,reduce_enermy_health,curr_player_health,curr_oppont_health))
|
||||
return self._stack_observation(), 0.001 * custom_reward, custom_done, info # reward normalization
|
||||
|
||||
|
86
main/general/test_general.py
Normal file
86
main/general/test_general.py
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2024 WANG Jing. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
from street_fighter_custom_wrapper_general import StreetFighterCustomWrapper
|
||||
|
||||
RESET_ROUND = False # Already in rule best of three keep in False
|
||||
RENDERING = True # Whether to render the game screen.
|
||||
|
||||
lin = r"ppo_ryu_2500000_steps_updated"
|
||||
king = r"ppo_ryu_king_10000000_steps"
|
||||
|
||||
#the trained result performance in best of three games
|
||||
# 2500000_steps_updated which with 1 round trained got win rate in 0.4
|
||||
# king_10000000_steps which with entire best of three rules got win rate in 0.99
|
||||
|
||||
MODEL_NAME = king
|
||||
RANDOM_ACTION = False
|
||||
NUM_EPISODES = 50 #
|
||||
MODEL_DIR = r"./trained_models"
|
||||
|
||||
def make_env(game, state):
|
||||
def _init():
|
||||
env = retro.make(
|
||||
game=game,
|
||||
state=state,
|
||||
use_restricted_actions=retro.Actions.FILTERED,
|
||||
obs_type=retro.Observations.IMAGE
|
||||
)
|
||||
env = StreetFighterCustomWrapper(env, reset_round=RESET_ROUND, rendering=RENDERING)
|
||||
return env
|
||||
return _init
|
||||
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
||||
env = make_env(game, state="Champion.Level12.RyuVsBison")()
|
||||
|
||||
if not RANDOM_ACTION:
|
||||
model = PPO.load(os.path.join(MODEL_DIR, MODEL_NAME), env=env)
|
||||
|
||||
obs = env.reset()
|
||||
num_episodes = NUM_EPISODES
|
||||
episode_reward_sum = 0
|
||||
num_victory = 0
|
||||
|
||||
print("\nFighting Begins!\n")
|
||||
|
||||
for _ in range(num_episodes):
|
||||
done = False
|
||||
total_reward = 0
|
||||
while not done:
|
||||
timestamp = time.time()
|
||||
if RANDOM_ACTION:
|
||||
obs, reward, done, info = env.step(env.action_space.sample())
|
||||
else:
|
||||
action, _states = model.predict(obs)
|
||||
obs, reward, done, info = env.step(action)
|
||||
if reward != 0:
|
||||
total_reward += reward
|
||||
print("Reward: {:.3f}, playerHP: {}, enemyHP:{}".format(reward, info['agent_hp'], info['enemy_hp']))
|
||||
|
||||
if info['enemy_hp'] < 0 and info['agent_hp']>=0:
|
||||
print("Victory!")
|
||||
num_victory += 1
|
||||
print("Total reward: {}\n".format(total_reward))
|
||||
episode_reward_sum += total_reward
|
||||
obs = env.reset()
|
||||
|
||||
env.close()
|
||||
print("Winning rate: {}".format(1.0 * num_victory / num_episodes))
|
||||
if RANDOM_ACTION:
|
||||
print("Average reward for random action: {}".format(episode_reward_sum/num_episodes))
|
||||
else:
|
||||
print("Average reward for {}: {}".format(MODEL_NAME, episode_reward_sum/num_episodes))
|
120
main/general/train_general.py
Normal file
120
main/general/train_general.py
Normal file
@ -0,0 +1,120 @@
|
||||
# Copyright 2024 WANG Jing. All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from street_fighter_custom_wrapper_general import StreetFighterCustomWrapper
|
||||
|
||||
NUM_ENV = 16
|
||||
LOG_DIR = 'logs_gerneral'
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
|
||||
# Linear scheduler
|
||||
def linear_schedule(initial_value, final_value=0.0):
|
||||
|
||||
if isinstance(initial_value, str):
|
||||
initial_value = float(initial_value)
|
||||
final_value = float(final_value)
|
||||
assert (initial_value > 0.0)
|
||||
|
||||
def scheduler(progress):
|
||||
return final_value + progress * (initial_value - final_value)
|
||||
|
||||
return scheduler
|
||||
|
||||
def make_env(game, state, seed=0):
|
||||
def _init():
|
||||
env = retro.make(
|
||||
game=game,
|
||||
state=state,
|
||||
use_restricted_actions=retro.Actions.FILTERED,
|
||||
obs_type=retro.Observations.IMAGE
|
||||
)
|
||||
env = StreetFighterCustomWrapper(env)
|
||||
env = Monitor(env)
|
||||
env.seed(seed)
|
||||
return env
|
||||
return _init
|
||||
|
||||
def main():
|
||||
# Set up the environment and model
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
||||
env = SubprocVecEnv([make_env(game, state="Champion.Level12.RyuVsBison", seed=i) for i in range(NUM_ENV)])
|
||||
|
||||
# Set linear schedule for learning rate
|
||||
# Start
|
||||
lr_schedule = linear_schedule(2.5e-4, 2.5e-6)
|
||||
|
||||
clip_range_schedule = linear_schedule(0.15, 0.025)
|
||||
|
||||
|
||||
model = PPO(
|
||||
"CnnPolicy",
|
||||
env,
|
||||
device="cuda",
|
||||
verbose=1,
|
||||
n_steps=512,
|
||||
batch_size=512,
|
||||
n_epochs=4,
|
||||
gamma=0.94,
|
||||
learning_rate=lr_schedule,
|
||||
clip_range=clip_range_schedule,
|
||||
tensorboard_log="logs_final"
|
||||
)
|
||||
|
||||
# Set the save directory
|
||||
save_dir = "trained_models_gerneral_10"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Load the model from file
|
||||
# model_path = "trained_models/ppo_ryu_7000000_steps.zip"
|
||||
|
||||
# Load model and modify the learning rate and entropy coefficient
|
||||
# custom_objects = {
|
||||
# "learning_rate": lr_schedule,
|
||||
# "clip_range": clip_range_schedule,
|
||||
# "n_steps": 512
|
||||
# }
|
||||
# model = PPO.load(model_path, env=env, device="cuda", custom_objects=custom_objects)
|
||||
|
||||
# Set up callbacks
|
||||
# Note that 1 timesetp = 6 frame
|
||||
checkpoint_interval = 31250 # checkpoint_interval * num_envs = total_steps_per_checkpoint
|
||||
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_ryu_gerneral")
|
||||
|
||||
# Writing the training logs from stdout to a file
|
||||
original_stdout = sys.stdout
|
||||
log_file_path = os.path.join(save_dir, "training_log.txt")
|
||||
with open(log_file_path, 'w') as log_file:
|
||||
sys.stdout = log_file
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(10000000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
|
||||
callback=[checkpoint_callback]#, stage_increase_callback]
|
||||
)
|
||||
env.close()
|
||||
|
||||
# Restore stdout
|
||||
sys.stdout = original_stdout
|
||||
|
||||
# Save the final model
|
||||
model.save(os.path.join(save_dir, "ppo_sf2_ryu_final.zip"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Binary file not shown.
BIN
main/trained_models/ppo_ryu_king_10000000_steps.zip
Normal file
BIN
main/trained_models/ppo_ryu_king_10000000_steps.zip
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user