From dc4aad6e8a2cf63f151e469957bcef81dc0aea21 Mon Sep 17 00:00:00 2001 From: linyiLYi <48440925+linyiLYi@users.noreply.github.com> Date: Wed, 29 Mar 2023 02:27:06 +0800 Subject: [PATCH] random training --- __pycache__/custom_cnn.cpython-38.pyc | Bin 1189 -> 1250 bytes __pycache__/custom_sf2_cv_env.cpython-38.pyc | Bin 2622 -> 2577 bytes custom_cnn.py | 4 +- custom_sf2_cv_env.py | 49 +++++++------- ...ern_wins_gray.png => pattern_win_gray.png} | Bin test_cv_sf2_ai.py | 18 ++++-- train_cv_sf2_ai.py | 60 +++++++++++++++--- 7 files changed, 93 insertions(+), 38 deletions(-) rename images/{pattern_wins_gray.png => pattern_win_gray.png} (100%) diff --git a/__pycache__/custom_cnn.cpython-38.pyc b/__pycache__/custom_cnn.cpython-38.pyc index 87ffae45ccb22b96221dade378e7e01e05f60e2d..db6b9b05b3c00728ac10cf48f7cf0af095315471 100644 GIT binary patch delta 419 zcmYk2zfQw25XR&D4lMg|yQjPe33 zJODd8@8FRq!IuHmmOr2GKK<v z*i^QpY$lv~V%|G#oe%17^Zjg1N5TVvm&B&n(3|oesmSLg)#pv9ZG&w6OTY2V3h|T( z(yiaDt7jl8rz}MJpxhhEWg)~aj*rJ$tnO8=a_ttsneMc z$SPt35u89mlc`7?NZn#DNG-}OElCCPB|vhrAVL~W$bi+PIu< zYe{}la>gyTkkq{5{GuWSkQi@#d~!}=adC2LPELG0k|B~n*&>ihMa)2g9f*rrfCLAl z2p?EPlf6gB(>be>H}sb`~s@Z8KYflC8av z;FZij&>r;U#e-MR#$denVvHolc=BSr;hTj7Byl$H&A0Ep%$x7qx9?Lssl>fRBBtT` zaxu$CcM>n|^lWQOz;#Zy?1Ih>Zf$}P@tH@MAXxSPC-$8ooKf_LN2wl%61>^aRz zb}}+&XGZehKkcfumn)g-_`{{JG94|rFJmwahP(`8W=KQL?U9cl4@udAH(ed8?gjO- zzwDTDmZYIyE)pLGch#ic%X>fIYga4xj_38XH=J{JtRTq z=u8tZLeOSZMx{e-MH#G8k>oI^ae}5H5qS}qi4j_O8(VIe zj(lZaK)N5z=G2MqT~T-8Z4?C`Q7u#GG)O~QPeF>rAV%WE1e0!#M!LPW=5G#QMaK|c z;MI@ViAxwox)DVO@d=({5@VC{I?FW~B8nDCxq?7R*@6(-@R&A8LvPTA5o#^c)*2>< zhJ|^ZlT`?=A)*S2hoOlTMq9gqH5fI6yIgNrew?duz!{8W&{|eV+JLoIXe?2b!RL>9kW=fm#e-|JPVe4AYZ|7 zdDZviZvCuZ>XH;NLLmilKxLJ>MBcW~6OxyU_IOh@5>emiE=srPfKI}7 zOO;{}l)OsCb&u&gs7loI3N@i>Y=byD2Hm~PV2q`h8D$((oK{s?bRENAU#mSAX@#u^ hqFAeWLRF0&|8htF1-rFY_(b<&1f%O^V;slr0fAkX^o*dJS-B8B zcru-MhKpXjcsE|X=pUGyiHY&#PtdBtC~9}=vDMYzd{x!^V&MINf7|y<4A=9wW9rEZ z|H-AkXKWS}S7P0(3ni7c&gxRx%0X)>S9xe{RZ=}@9o4IRwC*z|S>3zEf^t4=`n}=t zub)2B`f{OR5&Ex1zKtSKk{hN6MnRehIA-oa70Tu{yy1am?(#wNm@kYD(=Digy|p=7 zVFPEh<&Vw*{LmgVDn3sPVP4+TzMsT;B}(VpOLbqz$$TeP;e0Eai}gNC*is|rsNCB> zh`$peg7DUE=$t5^9?J1GVn!^dP?8#|lx;x97C4GX*@D~@UDy`0e#LldZ#WrjQm(sb z=}85*SP|pDV>c+KICe=)*Q$Z1#}GgFTT$A0P>Wm3%_v=qbp?akr;|$UdN?&TbG2Fx zr)O?nzj{UY(HW7?MJsxw%jg9{Q~7mgb$%(ts_SmEoy6pAubQq@!p4KRu^7%ptvGD2 z#CktQ6M|$gfV%<`Y*L*g)Q{zX^1|Hv5RKIxv=J- zCsB&2pFszwNOQ+|-jQA26}YjM4Zg{`7Fw~%mJcfE+UU!SCTw8``E)g~biqg_ zX}qinPah>!5ONAaxg9m=r540Q4xxGH9vJ_-C3+GsRZ$Wb6%EOSUNB{E48WZ9jyDKz z3%U3P)l=Nb{V;4aqa&1M*Gt9uWnjF94}=pT^uqH>|H?0UF&yLNFB*m@^rjhPSL c$iyF7g-7>uce2%151a9MAP_>rxqvHv0@cgkm;e9( diff --git a/custom_cnn.py b/custom_cnn.py index 8daa92e..5de99a7 100644 --- a/custom_cnn.py +++ b/custom_cnn.py @@ -2,6 +2,7 @@ import gym import torch import torch.nn as nn from stable_baselines3.common.torch_layers import BaseFeaturesExtractor +from torchvision.models import mobilenet_v3_small # Custom feature extractor (CNN) class CustomCNN(BaseFeaturesExtractor): @@ -20,4 +21,5 @@ class CustomCNN(BaseFeaturesExtractor): ) def forward(self, observations: torch.Tensor) -> torch.Tensor: - return self.cnn(observations.permute(0, 3, 1, 2)) # Swap the channel dimension \ No newline at end of file + return self.cnn(observations.permute(0, 3, 1, 2)) # Swap the channel dimension + \ No newline at end of file diff --git a/custom_sf2_cv_env.py b/custom_sf2_cv_env.py index 55b3d0f..8538440 100644 --- a/custom_sf2_cv_env.py +++ b/custom_sf2_cv_env.py @@ -4,7 +4,7 @@ import numpy as np # Custom environment wrapper class StreetFighterCustomWrapper(gym.Wrapper): - def __init__(self, env, win_template, lose_template, threshold=0.65): + def __init__(self, env, win_template, lose_template, testing=False, threshold=0.65): super(StreetFighterCustomWrapper, self).__init__(env) self.win_template = win_template self.lose_template = lose_template @@ -18,24 +18,21 @@ class StreetFighterCustomWrapper(gym.Wrapper): self.observation_space = gym.spaces.Box( low=0.0, high=1.0, shape=(84, 84, 1), dtype=np.float32 ) + + self.testing = testing def _preprocess_observation(self, observation): self.game_screen_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY) - # Print the size of self.game_screen_gray - # print("self.game_screen_gray size: ", self.game_screen_gray.shape) - # Print the size of the observation - # print("Observation size: ", observation.shape) resized_image = cv2.resize(self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0 return np.expand_dims(resized_image, axis=-1) - - def _check_game_over(self): - win_res = cv2.matchTemplate(self.game_screen_gray, self.win_template, cv2.TM_CCOEFF_NORMED) - lose_res = cv2.matchTemplate(self.game_screen_gray, self.lose_template, cv2.TM_CCOEFF_NORMED) - if np.max(win_res) >= self.threshold: - return True - if np.max(lose_res) >= self.threshold: - return True - return False + + def _get_win_or_lose_bonus(self): + if self.prev_player_health > self.prev_opponent_health: + # print('You win!') + return 200 + else: + # print('You lose!') + return -200 def _get_reward(self): player_health_area = self.game_screen_gray[15:20, 32:120] @@ -48,17 +45,13 @@ class StreetFighterCustomWrapper(gym.Wrapper): player_health_diff = self.prev_player_health - player_health opponent_health_diff = self.prev_opponent_health - opponent_health - reward = (opponent_health_diff - player_health_diff) * 100 - - # Add bonus for successful attacks or penalize for taking damage - if opponent_health_diff > player_health_diff: - reward += 10 # Bonus for successful attacks - elif opponent_health_diff < player_health_diff: - reward -= 10 # Penalty for taking damage + reward = (opponent_health_diff - player_health_diff) * 100 # max would be 100 self.prev_player_health = player_health self.prev_opponent_health = opponent_health + # Print the health values of the player and the opponent + # print("Player health: %f Opponent health:%f" % (player_health, opponent_health)) return reward def reset(self): @@ -68,7 +61,17 @@ class StreetFighterCustomWrapper(gym.Wrapper): return self._preprocess_observation(observation) def step(self, action): - observation, _, _, info = self.env.step(action) + # observation, _, _, info = self.env.step(action) + observation, _reward, _done, info = self.env.step(action) custom_reward = self._get_reward() - custom_done = self._check_game_over() or False + + custom_done = False + if self.prev_player_health <= 0.00001 or self.prev_opponent_health <= 0.00001: + custom_reward += self._get_win_or_lose_bonus() + if not self.testing: + custom_done = True + else: + self.prev_player_health = 1.0 + self.prev_opponent_health = 1.0 + return self._preprocess_observation(observation), custom_reward, custom_done, info \ No newline at end of file diff --git a/images/pattern_wins_gray.png b/images/pattern_win_gray.png similarity index 100% rename from images/pattern_wins_gray.png rename to images/pattern_win_gray.png diff --git a/test_cv_sf2_ai.py b/test_cv_sf2_ai.py index 5410f39..75cefb7 100644 --- a/test_cv_sf2_ai.py +++ b/test_cv_sf2_ai.py @@ -13,11 +13,11 @@ from custom_sf2_cv_env import StreetFighterCustomWrapper def make_env(game, state, seed=0): def _init(): - win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE) + win_template = cv2.imread('images/pattern_win_gray.png', cv2.IMREAD_GRAYSCALE) lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE) env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE) - env = StreetFighterCustomWrapper(env, win_template, lose_template) - env.seed(seed) + env = StreetFighterCustomWrapper(env, win_template, lose_template, testing=True) + # env.seed(seed) return env return _init @@ -27,6 +27,14 @@ state_stages = [ "Champion.Level2.ChunLiVsKen", "Champion.Level3.ChunLiVsChunLi", "Champion.Level4.ChunLiVsZangief", + "Champion.Level5.ChunLiVsDhalsim", + "Champion.Level6.ChunLiVsRyu", + "Champion.Level7.ChunLiVsEHonda", + "Champion.Level8.ChunLiVsBlanka", + "Champion.Level9.ChunLiVsBalrog", + "Champion.Level10.ChunLiVsVega", + "Champion.Level11.ChunLiVsSagat", + "Champion.Level12.ChunLiVsBison" # Add other stages as necessary ] @@ -45,7 +53,7 @@ model = PPO( policy_kwargs=policy_kwargs, verbose=1 ) -model.load("ppo_sf2_cnn_new") +model.load(r"trained_models_cv_test/ppo_sf2_chunli_final") obs = env.reset() done = False @@ -59,4 +67,4 @@ while True: if render_time < 0.0111: time.sleep(0.0111 - render_time) # Add a delay for 90 FPS -# env.close() \ No newline at end of file +# env.close() diff --git a/train_cv_sf2_ai.py b/train_cv_sf2_ai.py index 80c5253..d6bdd8f 100644 --- a/train_cv_sf2_ai.py +++ b/train_cv_sf2_ai.py @@ -1,3 +1,6 @@ +import os +import random + import gym import cv2 import retro @@ -5,19 +8,33 @@ import numpy as np from stable_baselines3 import PPO from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first +from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback import torch import torch.nn as nn from custom_cnn import CustomCNN from custom_sf2_cv_env import StreetFighterCustomWrapper +class RandomOpponentChangeCallback(BaseCallback): + def __init__(self, stages, opponent_interval, save_dir, verbose=0): + super(RandomOpponentChangeCallback, self).__init__(verbose) + self.stages = stages + self.opponent_interval = opponent_interval + + def _on_step(self) -> bool: + if self.n_calls % self.opponent_interval == 0: + new_state = random.choice(self.stages) + print("\nCurrent state:", new_state) + self.training_env.env_method("load_state", new_state, indices=None) + return True + def make_env(game, state, seed=0): def _init(): - win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE) + win_template = cv2.imread('images/pattern_win_gray.png', cv2.IMREAD_GRAYSCALE) lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE) env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE) env = StreetFighterCustomWrapper(env, win_template, lose_template) - env.seed(seed) + # env.seed(seed) return env return _init @@ -25,15 +42,24 @@ def main(): # Set up the environment and model game = "StreetFighterIISpecialChampionEdition-Genesis" state_stages = [ - "Champion.Level1.ChunLiVsGuile", - "Champion.Level2.ChunLiVsKen", - "Champion.Level3.ChunLiVsChunLi", - "Champion.Level4.ChunLiVsZangief", + # "Champion.Level1.ChunLiVsGuile", + # "Champion.Level2.ChunLiVsKen", + # "Champion.Level3.ChunLiVsChunLi", + # "Champion.Level4.ChunLiVsZangief", + # "Champion.Level5.ChunLiVsDhalsim", + "Champion.Level6.ChunLiVsRyu", + "Champion.Level7.ChunLiVsEHonda", + "Champion.Level8.ChunLiVsBlanka", + "Champion.Level9.ChunLiVsBalrog", + "Champion.Level10.ChunLiVsVega", + "Champion.Level11.ChunLiVsSagat", + "Champion.Level12.ChunLiVsBison" # Add other stages as necessary ] num_envs = 8 + # env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)]) env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)]) policy_kwargs = { @@ -46,7 +72,7 @@ def main(): device="cuda", policy_kwargs=policy_kwargs, verbose=1, - n_steps=2048, + n_steps=5400, batch_size=64, n_epochs=10, learning_rate=0.0003, @@ -59,9 +85,25 @@ def main(): use_sde=False, sde_sample_freq=-1 ) - model.learn(total_timesteps=int(500000)) - model.save("ppo_sf2_cnn_new") + # Set the save directory + save_dir = "trained_models_cv_level6up" + os.makedirs(save_dir, exist_ok=True) + + # Set up callbacks + opponent_interval = 5400 # stage_interval * num_envs = total_steps_per_stage + checkpoint_interval = 54000 # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds) + checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli") + stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir) + + + model.learn( + total_timesteps=int(6048000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds) + callback=[checkpoint_callback, stage_increase_callback] + ) + + # Save the final model + model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip")) if __name__ == "__main__": main()