diff --git a/main/general/street_fighter_custom_wrapper_general.py b/main/general/street_fighter_custom_wrapper_general.py
new file mode 100644
index 0000000..a8dc63b
--- /dev/null
+++ b/main/general/street_fighter_custom_wrapper_general.py
@@ -0,0 +1,159 @@
+# Copyright 2024 WANG Jing. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import math
+import time
+import collections
+
+import gym
+import numpy as np
+
+# Custom environment wrapper enforcing best-of-three match rules.
+class StreetFighterCustomWrapper(gym.Wrapper):
+
+    def __init__(self, env, reset_round=False, rendering=False):
+        super(StreetFighterCustomWrapper, self).__init__(env)
+        self.env = env
+
+        self.win = 0
+        self.loss = 0
+        self.round = 0
+        self.round_end = False
+        self.jump = False
+
+        # Use a deque to store the last 9 frames.
+        self.num_frames = 9
+        self.frame_stack = collections.deque(maxlen=self.num_frames)
+
+        # Each agent action is held for 6 emulator frames (frame skip).
+        self.num_step_frames = 6
+
+        self.reward_coeff = 3.0
+
+        self.total_timesteps = 0
+        self.match1_reward = 0
+
+        self.full_hp = 176
+        self.prev_player_health = self.full_hp
+        self.prev_opponent_health = self.full_hp
+
+        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(100, 128, 3), dtype=np.uint8)
+
+        self.reset_round = reset_round
+        self.rendering = rendering
+
+    def _stack_observation(self):
+        # Pick every third stored frame and a different color channel from each,
+        # packing temporal context from the last 9 frames into 3 channels.
+        return np.stack([self.frame_stack[i * 3 + 2][:, :, i] for i in range(3)], axis=-1)
+
+    def reset(self):
+        observation = self.env.reset()
+        self.win = 0
+        self.loss = 0
+
+        self.prev_player_health = self.full_hp
+        self.prev_opponent_health = self.full_hp
+
+        self.total_timesteps = 0
+
+        # Clear the frame stack and add the first observation [num_frames] times.
+        self.frame_stack.clear()
+        for _ in range(self.num_frames):
+            self.frame_stack.append(observation[::2, ::2, :])
+
+        return self._stack_observation()
+
+    def step(self, action):
+        custom_done = False
+
+        obs, _reward, _done, info = self.env.step(action)
+        if not self.jump:
+            self.frame_stack.append(obs[::2, ::2, :])
+        # Render the game if the rendering flag is set to True.
+        if self.rendering:
+            self.env.render()
+            # time.sleep(0.003)
+        for _ in range(self.num_step_frames - 1):
+            # Keep the button pressed for the remaining (num_step_frames - 1) frames.
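+            # Repeating env.step(action) holds the same button combination down,
+            # so one RL timestep spans 6 emulator frames (roughly 0.1 s at the
+            # console's ~60 fps) and the policy acts at a coarser timescale.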
+            obs, _reward, _done, info = self.env.step(action)
+            if not self.jump:
+                self.frame_stack.append(obs[::2, ::2, :])
+            if self.rendering:
+                self.env.render()
+                # time.sleep(0.003)
+
+        curr_player_health = info['agent_hp']
+        curr_opponent_health = info['enemy_hp']
+
+        if not custom_done:
+            if self.round_end:
+                # Idle through the between-round animation until both health bars refill.
+                while not (curr_player_health == self.full_hp and curr_opponent_health == self.full_hp):
+                    self.jump = True
+                    obs, _reward, _done, info = self.env.step([0] * 12)
+                    curr_player_health = info['agent_hp']
+                    curr_opponent_health = info['enemy_hp']
+                    if self.rendering:
+                        self.env.render()
+                        # time.sleep(0.01)
+                self.round_end = False
+                self.jump = False
+            if self.jump:
+                custom_reward = np.nan
+                self.prev_player_health = self.full_hp
+                self.prev_opponent_health = self.full_hp
+                custom_done = False
+            else:
+                self.total_timesteps += self.num_step_frames
+                reduce_enemy_health = self.prev_opponent_health - curr_opponent_health
+                reduce_player_health = self.prev_player_health - curr_player_health
+
+                # Determine the game status and calculate rewards.
+                if reduce_player_health > 0 and curr_player_health < 0:
+                    # The agent is KO'd: the penalty grows with the opponent's remaining health.
+                    custom_reward = -math.pow(self.full_hp, (curr_opponent_health + 1) / (self.full_hp + 1))
+                    self.loss += 1
+                    self.round += 1
+                    self.round_end = True
+                    # print('loss')
+                    if self.loss == 2:
+                        self.loss = 0
+                        self.round = 0
+                        custom_done = True
+
+                elif reduce_enemy_health > 0 and curr_opponent_health < 0:
+                    # The opponent is KO'd: the reward grows with the agent's remaining health.
+                    custom_reward = math.pow(self.full_hp, (curr_player_health + 1) / (self.full_hp + 1)) * self.reward_coeff
+                    self.win += 1
+                    self.round += 1
+                    self.round_end = True
+                    # print('win')
+                    if self.win == 2:
+                        self.win = 0
+                        self.round = 0
+                        custom_done = True
+                else:
+                    # Mid-round: reward the net health difference dealt this timestep.
+                    if reduce_enemy_health >= 0 and reduce_player_health >= 0:
+                        if reduce_enemy_health > 0 or reduce_player_health > 0:
+                            custom_reward = self.reward_coeff * (reduce_enemy_health - reduce_player_health)
+                        else:
+                            custom_reward = 0
+                    else:
+                        custom_reward = 0
+
+            # Update health states.
+            self.prev_player_health = curr_player_health
+            self.prev_opponent_health = curr_opponent_health
+
+        # Max reward for winning the first two rounds is about 2 * 6 * full_hp = 2108,
+        # i.e. 2 * (damage reward * 3 + winning reward * 3); with a 2-1 result over
+        # three rounds it is about 2 * 6 * full_hp - full_hp - 1 + 3 * full_hp = 2464.
+        # norm_coefficient = 0.001 keeps the scaled reward in a small range.
+        # print("reward: {}, reduce_player_health: {}, reduce_enemy_health: {}, agentHP: {}, enemyHP: {}".format(
+        #     custom_reward, reduce_player_health, reduce_enemy_health, curr_player_health, curr_opponent_health))
+        return self._stack_observation(), 0.001 * custom_reward, custom_done, info  # reward normalization
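For intuition about the terminal rewards above: the win/loss reward scales exponentially with the winner's remaining health. A minimal standalone sketch (not part of the committed file, just the wrapper's formulas pulled out):

import math

FULL_HP = 176
REWARD_COEFF = 3.0

def win_reward(player_hp_left):
    # A perfect round (176 HP left) yields 176 * 3 = 528; a razor-thin win tends toward ~3.
    return math.pow(FULL_HP, (player_hp_left + 1) / (FULL_HP + 1)) * REWARD_COEFF

def loss_penalty(enemy_hp_left):
    # Being KO'd by a full-health opponent costs -176; a close loss tends toward ~-1.
    return -math.pow(FULL_HP, (enemy_hp_left + 1) / (FULL_HP + 1))

print(win_reward(FULL_HP), win_reward(0))      # 528.0, ~3.09
print(loss_penalty(FULL_HP), loss_penalty(0))  # -176.0, ~-1.03

This asymmetry (wins scaled by reward_coeff = 3, losses not) biases the agent toward finishing rounds with as much health as possible rather than trading damage evenly.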
diff --git a/main/general/test_general.py b/main/general/test_general.py
new file mode 100644
index 0000000..787a460
--- /dev/null
+++ b/main/general/test_general.py
@@ -0,0 +1,86 @@
+# Copyright 2024 WANG Jing. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import time
+
+import retro
+from stable_baselines3 import PPO
+from street_fighter_custom_wrapper_general import StreetFighterCustomWrapper
+
+RESET_ROUND = False  # The wrapper already enforces the best-of-three rule, so keep this False.
+RENDERING = True     # Whether to render the game screen.
+
+lin = r"ppo_ryu_2500000_steps_updated"
+king = r"ppo_ryu_king_10000000_steps"
+
+# Performance of the trained models in best-of-three matches:
+# ppo_ryu_2500000_steps_updated (trained on single rounds) reaches a win rate of about 0.4;
+# ppo_ryu_king_10000000_steps (trained under the full best-of-three rule) reaches about 0.99.
+
+MODEL_NAME = king
+RANDOM_ACTION = False
+NUM_EPISODES = 50
+MODEL_DIR = r"./trained_models"
+
+def make_env(game, state):
+    def _init():
+        env = retro.make(
+            game=game,
+            state=state,
+            use_restricted_actions=retro.Actions.FILTERED,
+            obs_type=retro.Observations.IMAGE
+        )
+        env = StreetFighterCustomWrapper(env, reset_round=RESET_ROUND, rendering=RENDERING)
+        return env
+    return _init
+
+game = "StreetFighterIISpecialChampionEdition-Genesis"
+env = make_env(game, state="Champion.Level12.RyuVsBison")()
+
+if not RANDOM_ACTION:
+    model = PPO.load(os.path.join(MODEL_DIR, MODEL_NAME), env=env)
+
+obs = env.reset()
+num_episodes = NUM_EPISODES
+episode_reward_sum = 0
+num_victory = 0
+
+print("\nFighting Begins!\n")
+
+for _ in range(num_episodes):
+    done = False
+    total_reward = 0
+    while not done:
+        timestamp = time.time()
+        if RANDOM_ACTION:
+            obs, reward, done, info = env.step(env.action_space.sample())
+        else:
+            action, _states = model.predict(obs)
+            obs, reward, done, info = env.step(action)
+
+        if reward != 0:
+            total_reward += reward
+            print("Reward: {:.3f}, playerHP: {}, enemyHP: {}".format(reward, info['agent_hp'], info['enemy_hp']))
+
+    if info['enemy_hp'] < 0 and info['agent_hp'] >= 0:
+        print("Victory!")
+        num_victory += 1
+    print("Total reward: {}\n".format(total_reward))
+    episode_reward_sum += total_reward
+    obs = env.reset()
+
+env.close()
+print("Win rate: {}".format(1.0 * num_victory / num_episodes))
+if RANDOM_ACTION:
+    print("Average reward for random action: {}".format(episode_reward_sum / num_episodes))
+else:
+    print("Average reward for {}: {}".format(MODEL_NAME, episode_reward_sum / num_episodes))
\ No newline at end of file
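One evaluation detail worth noting: stable-baselines3's model.predict samples from the policy distribution by default, so repeated test runs can differ. For reproducible benchmarking, a greedy variant of the prediction call above can be swapped in (same API, only the deterministic flag changes):

action, _states = model.predict(obs, deterministic=True)
obs, reward, done, info = env.step(action)

Stochastic evaluation (the script's default) is arguably the fairer measure here, since the reported 0.99 win rate reflects the policy as it behaves during training.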
diff --git a/main/general/train_general.py b/main/general/train_general.py
new file mode 100644
index 0000000..66fb767
--- /dev/null
+++ b/main/general/train_general.py
@@ -0,0 +1,120 @@
+# Copyright 2024 WANG Jing. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import sys
+
+import retro
+from stable_baselines3 import PPO
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.callbacks import CheckpointCallback
+from stable_baselines3.common.vec_env import SubprocVecEnv
+from street_fighter_custom_wrapper_general import StreetFighterCustomWrapper
+
+NUM_ENV = 16
+LOG_DIR = 'logs_general'
+os.makedirs(LOG_DIR, exist_ok=True)
+
+# Linear scheduler: SB3 calls it with progress_remaining, which goes from 1.0 to 0.0.
+def linear_schedule(initial_value, final_value=0.0):
+    if isinstance(initial_value, str):
+        initial_value = float(initial_value)
+        final_value = float(final_value)
+        assert initial_value > 0.0
+
+    def scheduler(progress):
+        return final_value + progress * (initial_value - final_value)
+
+    return scheduler
+
+def make_env(game, state, seed=0):
+    def _init():
+        env = retro.make(
+            game=game,
+            state=state,
+            use_restricted_actions=retro.Actions.FILTERED,
+            obs_type=retro.Observations.IMAGE
+        )
+        env = StreetFighterCustomWrapper(env)
+        env = Monitor(env)
+        env.seed(seed)
+        return env
+    return _init
+
+def main():
+    # Set up the environment and model.
+    game = "StreetFighterIISpecialChampionEdition-Genesis"
+    env = SubprocVecEnv([make_env(game, state="Champion.Level12.RyuVsBison", seed=i) for i in range(NUM_ENV)])
+
+    # Set linear schedules for the learning rate and the PPO clip range.
+    lr_schedule = linear_schedule(2.5e-4, 2.5e-6)
+    clip_range_schedule = linear_schedule(0.15, 0.025)
+
+    model = PPO(
+        "CnnPolicy",
+        env,
+        device="cuda",
+        verbose=1,
+        n_steps=512,
+        batch_size=512,
+        n_epochs=4,
+        gamma=0.94,
+        learning_rate=lr_schedule,
+        clip_range=clip_range_schedule,
+        tensorboard_log="logs_final"
+    )
+
+    # Set the save directory.
+    save_dir = "trained_models_general_10"
+    os.makedirs(save_dir, exist_ok=True)
+
+    # To resume training instead, load a checkpoint and override the schedules:
+    # model_path = "trained_models/ppo_ryu_7000000_steps.zip"
+    # custom_objects = {
+    #     "learning_rate": lr_schedule,
+    #     "clip_range": clip_range_schedule,
+    #     "n_steps": 512
+    # }
+    # model = PPO.load(model_path, env=env, device="cuda", custom_objects=custom_objects)
+
+    # Set up callbacks. Note that 1 timestep = 6 frames.
+    checkpoint_interval = 31250  # checkpoint_interval * num_envs = total_steps_per_checkpoint (= 500,000)
+    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_ryu_general")
+
+    # Redirect the training logs from stdout to a file.
+    original_stdout = sys.stdout
+    log_file_path = os.path.join(save_dir, "training_log.txt")
+    with open(log_file_path, 'w') as log_file:
+        sys.stdout = log_file
+
+        model.learn(
+            total_timesteps=int(10000000),  # counted across all NUM_ENV workers (about 1120 rounds)
+            callback=[checkpoint_callback]
+        )
+        env.close()
+
+    # Restore stdout.
+    sys.stdout = original_stdout
+
+    # Save the final model.
+    model.save(os.path.join(save_dir, "ppo_sf2_ryu_final.zip"))
+
+if __name__ == "__main__":
+    main()
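A quick sanity check on the schedules above: stable-baselines3 evaluates a schedule with progress_remaining, which decays linearly from 1.0 at the start of training to 0.0 at the end, so linear_schedule interpolates from initial_value down to final_value:

lr = linear_schedule(2.5e-4, 2.5e-6)
print(lr(1.0))  # 2.5e-04 at the first update
print(lr(0.5))  # ~1.26e-04 halfway through training
print(lr(0.0))  # 2.5e-06 at the last update

Annealing the clip range from 0.15 to 0.025 in the same way makes PPO's policy updates progressively more conservative as training converges.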
diff --git a/main/logs/king_final_result/events.out.tfevents.1714670323.Wanjeans-PC.19056.0 b/main/logs/king_final_result/events.out.tfevents.1714670323.Wanjeans-PC.19056.0
new file mode 100644
index 0000000..a8aad85
Binary files /dev/null and b/main/logs/king_final_result/events.out.tfevents.1714670323.Wanjeans-PC.19056.0 differ
diff --git a/main/trained_models/ppo_ryu_king_10000000_steps.zip b/main/trained_models/ppo_ryu_king_10000000_steps.zip
new file mode 100644
index 0000000..b02fbfa
Binary files /dev/null and b/main/trained_models/ppo_ryu_king_10000000_steps.zip differ
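The binary artifacts above are a TensorBoard event file and the 10M-step checkpoint. Assuming the paths committed in this diff, the checkpoint can be reloaded for evaluation (a minimal sketch; stable-baselines3 appends the .zip extension automatically):

from stable_baselines3 import PPO

model = PPO.load("main/trained_models/ppo_ryu_king_10000000_steps")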