From 41650747b0b1eef1621b26a48829bc4e99cbb98d Mon Sep 17 00:00:00 2001 From: linyiLYi <48440925+linyiLYi@users.noreply.github.com> Date: Tue, 28 Mar 2023 17:39:01 +0800 Subject: [PATCH] updated reward function (added bonus for hit and get-hit) --- __pycache__/custom_cnn.cpython-38.pyc | Bin 1153 -> 1189 bytes __pycache__/custom_sf2_cv_env.cpython-38.pyc | Bin 2429 -> 2622 bytes custom_cnn.py | 8 +++---- custom_sf2_cv_env.py | 22 +++++++++++++++---- test_cv_sf2_ai.py | 6 ++--- train_cv_sf2_ai.py | 18 ++++++++++++--- 6 files changed, 40 insertions(+), 14 deletions(-) diff --git a/__pycache__/custom_cnn.cpython-38.pyc b/__pycache__/custom_cnn.cpython-38.pyc index 9adaf11d4d4726e7c12ce42c2e81586f0189ec41..87ffae45ccb22b96221dade378e7e01e05f60e2d 100644 GIT binary patch delta 280 zcmZqVT*}EC%FD~e00hdLl~VdP^7=APc48Ef=S<;D;YwjnVNc=cWi4S|z_yS9NV0=T zjuh_6y^P|FJd>9)DzdQx^}b}9e1=hvof*i?WN?_w#=AVeiW@-ma(FdH%oPL^Xa)k8>uSu#NRTO2mI`6;D2 UsdgaSi$N~sVB}!rVdP;30B`{`y8r+H delta 234 zcmZ3=*~rNo%FD~e00i@~;-OlBaF>A*9Yk!c&F;N%@liIyM{ zXCN*H*;~V~fT4zQA!9IuCZnGwW03?9xG6;U34ky@l z_Aoz(Co$8aY%|C7Jv~NFo}p!A8;idc3ZXwwYUwTmygQT>1Hspwz;m~_aO&LkqIx1t za+_K!p6^C?XXGsTin07oUgG=mmNxmAX``r1BC~15gz3!KBfxGh+C&(@n;IJ@xkfwU zBJe#oYWVHmi$DE&`CfZAGz_?;pKk_kn>nn}3T;hXls~lb!ezyfnkx!3BA4`$K?M&@ zFbEfnGVaXCZ8{2(&v-)q((gS_>9m%+DFy(exL~VT8;!OjcvzvHA98i-5udOZoNAp{ zSDl*M;*P)0#Z~#%Sjtnlt;2c>EBq#pj5o`usmEFwb&VJ`jaY}18Jn?DMvRtF*HY(I zU3`cw*7pGZw+pEUi7(1Cs}LIKpvOuLBetMl8F|gTK7GvKjbo*rZHfYHqE}0})=KoS z(cxlHIg1KFp+)b^7dw^W8L^;$D=U?f9|V4zx1;|OrryL9e8Uy&eBJ#{jYrTB6`KO; za10E>!+2OL;S$N<49RK~XE4Q-nAjwp8_SX4JX&bf>hdEUl)5#!GY)s49cUjPpe|+@ zN+YR-Qup;3A7ChxEDF_1+YpyvOAJtfM<+!ITLnofm{6&lOI17q-#n~PiGm;{NC$Vu zAsx9(%ksctH&bFgjCdd(0!G|XCgrtK-o$jPY6QiUoThf^?1YIqKrcv-vK@7$cR8g| zFQ?HY=Qvf*4Z|w;JjXee;i!VBITZwq8qhEchT{AH?SrpO%`^hqUQnnBFp1G#YXw_9 Wovqix4FWC{s(8Bg-i%7OZ|on1NZQc= delta 765 zcmZ8f-D(p-6rMBtlTFjCv93)eB0+3qwR)jNBNi)oBO+B0Q3ATInIR@-H(_>TYSLL{XjNY zLc>YrQ$qTEcn6Ykqs!B_k4_1qSG913M^XrgyhUMR9R>dO4GmQf)! zo0o&2-4Bzb&HMczIFQ;2<)Q{GD1vIc?uG+y#VO({Ii_Z?sYjJp9U-^kQ8au#?|b2q f2!}&1l-I?BMeonp*@y;g)aN(UtJSZ^0bIsE91Nc% diff --git a/custom_cnn.py b/custom_cnn.py index bad8a1b..8daa92e 100644 --- a/custom_cnn.py +++ b/custom_cnn.py @@ -8,16 +8,16 @@ class CustomCNN(BaseFeaturesExtractor): def __init__(self, observation_space: gym.Space): super(CustomCNN, self).__init__(observation_space, features_dim=512) self.cnn = nn.Sequential( - nn.Conv2d(1, 32, kernel_size=8, stride=4, padding=0), + nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=0), nn.ReLU(), - nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), + nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0), nn.ReLU(), nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), nn.ReLU(), nn.Flatten(), - nn.Linear(3136, self.features_dim), + nn.Linear(16384, self.features_dim), nn.ReLU() ) def forward(self, observations: torch.Tensor) -> torch.Tensor: - return self.cnn(observations) \ No newline at end of file + return self.cnn(observations.permute(0, 3, 1, 2)) # Swap the channel dimension \ No newline at end of file diff --git a/custom_sf2_cv_env.py b/custom_sf2_cv_env.py index 017605b..55b3d0f 100644 --- a/custom_sf2_cv_env.py +++ b/custom_sf2_cv_env.py @@ -9,7 +9,6 @@ class StreetFighterCustomWrapper(gym.Wrapper): self.win_template = win_template self.lose_template = lose_template self.threshold = threshold - self.game_screen_gray = None self.prev_player_health = 1.0 @@ -17,7 +16,7 @@ class StreetFighterCustomWrapper(gym.Wrapper): # Update observation space to single-channel grayscale image self.observation_space = gym.spaces.Box( - low=0, high=255, shape=(84, 84, 1), dtype=np.uint8 + low=0.0, high=1.0, shape=(84, 84, 1), dtype=np.float32 ) def _preprocess_observation(self, observation): @@ -26,7 +25,7 @@ class StreetFighterCustomWrapper(gym.Wrapper): # print("self.game_screen_gray size: ", self.game_screen_gray.shape) # Print the size of the observation # print("Observation size: ", observation.shape) - resized_image = cv2.resize(self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA) + resized_image = cv2.resize(self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0 return np.expand_dims(resized_image, axis=-1) def _check_game_over(self): @@ -46,11 +45,26 @@ class StreetFighterCustomWrapper(gym.Wrapper): player_health = np.sum(player_health_area > 129) / player_health_area.size opponent_health = np.sum(oppoent_health_area > 129) / oppoent_health_area.size - reward = player_health - opponent_health + player_health_diff = self.prev_player_health - player_health + opponent_health_diff = self.prev_opponent_health - opponent_health + + reward = (opponent_health_diff - player_health_diff) * 100 + + # Add bonus for successful attacks or penalize for taking damage + if opponent_health_diff > player_health_diff: + reward += 10 # Bonus for successful attacks + elif opponent_health_diff < player_health_diff: + reward -= 10 # Penalty for taking damage + + self.prev_player_health = player_health + self.prev_opponent_health = opponent_health + return reward def reset(self): observation = self.env.reset() + self.prev_player_health = 1.0 + self.prev_opponent_health = 1.0 return self._preprocess_observation(observation) def step(self, action): diff --git a/test_cv_sf2_ai.py b/test_cv_sf2_ai.py index fc29a8a..5410f39 100644 --- a/test_cv_sf2_ai.py +++ b/test_cv_sf2_ai.py @@ -45,12 +45,12 @@ model = PPO( policy_kwargs=policy_kwargs, verbose=1 ) -model.load("ppo_sf2_cnn") +model.load("ppo_sf2_cnn_new") obs = env.reset() done = False -while not done: +while True: timestamp = time.time() action, _ = model.predict(obs) obs, rewards, done, info = env.step(action) @@ -59,4 +59,4 @@ while not done: if render_time < 0.0111: time.sleep(0.0111 - render_time) # Add a delay for 90 FPS -env.close() \ No newline at end of file +# env.close() \ No newline at end of file diff --git a/train_cv_sf2_ai.py b/train_cv_sf2_ai.py index 5036a19..80c5253 100644 --- a/train_cv_sf2_ai.py +++ b/train_cv_sf2_ai.py @@ -45,11 +45,23 @@ def main(): env, device="cuda", policy_kwargs=policy_kwargs, - verbose=1 + verbose=1, + n_steps=2048, + batch_size=64, + n_epochs=10, + learning_rate=0.0003, + ent_coef=0.01, + clip_range=0.2, + clip_range_vf=None, + gamma=0.99, + gae_lambda=0.95, + max_grad_norm=0.5, + use_sde=False, + sde_sample_freq=-1 ) - model.learn(total_timesteps=int(1000)) + model.learn(total_timesteps=int(500000)) - model.save("ppo_sf2_cnn") + model.save("ppo_sf2_cnn_new") if __name__ == "__main__": main()