random training

This commit is contained in:
linyiLYi 2023-03-29 02:27:06 +08:00
parent 41650747b0
commit dc4aad6e8a
7 changed files with 93 additions and 38 deletions

View File

@ -2,6 +2,7 @@ import gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from torchvision.models import mobilenet_v3_small
# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
@ -21,3 +22,4 @@ class CustomCNN(BaseFeaturesExtractor):
def forward(self, observations: torch.Tensor) -> torch.Tensor:
return self.cnn(observations.permute(0, 3, 1, 2)) # Swap the channel dimension

View File

@ -4,7 +4,7 @@ import numpy as np
# Custom environment wrapper
class StreetFighterCustomWrapper(gym.Wrapper):
def __init__(self, env, win_template, lose_template, threshold=0.65):
def __init__(self, env, win_template, lose_template, testing=False, threshold=0.65):
super(StreetFighterCustomWrapper, self).__init__(env)
self.win_template = win_template
self.lose_template = lose_template
@ -19,23 +19,20 @@ class StreetFighterCustomWrapper(gym.Wrapper):
low=0.0, high=1.0, shape=(84, 84, 1), dtype=np.float32
)
self.testing = testing
def _preprocess_observation(self, observation):
self.game_screen_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
# Print the size of self.game_screen_gray
# print("self.game_screen_gray size: ", self.game_screen_gray.shape)
# Print the size of the observation
# print("Observation size: ", observation.shape)
resized_image = cv2.resize(self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0
return np.expand_dims(resized_image, axis=-1)
def _check_game_over(self):
win_res = cv2.matchTemplate(self.game_screen_gray, self.win_template, cv2.TM_CCOEFF_NORMED)
lose_res = cv2.matchTemplate(self.game_screen_gray, self.lose_template, cv2.TM_CCOEFF_NORMED)
if np.max(win_res) >= self.threshold:
return True
if np.max(lose_res) >= self.threshold:
return True
return False
def _get_win_or_lose_bonus(self):
if self.prev_player_health > self.prev_opponent_health:
# print('You win!')
return 200
else:
# print('You lose!')
return -200
def _get_reward(self):
player_health_area = self.game_screen_gray[15:20, 32:120]
@ -48,17 +45,13 @@ class StreetFighterCustomWrapper(gym.Wrapper):
player_health_diff = self.prev_player_health - player_health
opponent_health_diff = self.prev_opponent_health - opponent_health
reward = (opponent_health_diff - player_health_diff) * 100
# Add bonus for successful attacks or penalize for taking damage
if opponent_health_diff > player_health_diff:
reward += 10 # Bonus for successful attacks
elif opponent_health_diff < player_health_diff:
reward -= 10 # Penalty for taking damage
reward = (opponent_health_diff - player_health_diff) * 100 # max would be 100
self.prev_player_health = player_health
self.prev_opponent_health = opponent_health
# Print the health values of the player and the opponent
# print("Player health: %f Opponent health:%f" % (player_health, opponent_health))
return reward
def reset(self):
@ -68,7 +61,17 @@ class StreetFighterCustomWrapper(gym.Wrapper):
return self._preprocess_observation(observation)
def step(self, action):
observation, _, _, info = self.env.step(action)
# observation, _, _, info = self.env.step(action)
observation, _reward, _done, info = self.env.step(action)
custom_reward = self._get_reward()
custom_done = self._check_game_over() or False
custom_done = False
if self.prev_player_health <= 0.00001 or self.prev_opponent_health <= 0.00001:
custom_reward += self._get_win_or_lose_bonus()
if not self.testing:
custom_done = True
else:
self.prev_player_health = 1.0
self.prev_opponent_health = 1.0
return self._preprocess_observation(observation), custom_reward, custom_done, info

View File

Before

Width:  |  Height:  |  Size: 292 B

After

Width:  |  Height:  |  Size: 292 B

View File

@ -13,11 +13,11 @@ from custom_sf2_cv_env import StreetFighterCustomWrapper
def make_env(game, state, seed=0):
def _init():
win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE)
win_template = cv2.imread('images/pattern_win_gray.png', cv2.IMREAD_GRAYSCALE)
lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE)
env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
env = StreetFighterCustomWrapper(env, win_template, lose_template)
env.seed(seed)
env = StreetFighterCustomWrapper(env, win_template, lose_template, testing=True)
# env.seed(seed)
return env
return _init
@ -27,6 +27,14 @@ state_stages = [
"Champion.Level2.ChunLiVsKen",
"Champion.Level3.ChunLiVsChunLi",
"Champion.Level4.ChunLiVsZangief",
"Champion.Level5.ChunLiVsDhalsim",
"Champion.Level6.ChunLiVsRyu",
"Champion.Level7.ChunLiVsEHonda",
"Champion.Level8.ChunLiVsBlanka",
"Champion.Level9.ChunLiVsBalrog",
"Champion.Level10.ChunLiVsVega",
"Champion.Level11.ChunLiVsSagat",
"Champion.Level12.ChunLiVsBison"
# Add other stages as necessary
]
@ -45,7 +53,7 @@ model = PPO(
policy_kwargs=policy_kwargs,
verbose=1
)
model.load("ppo_sf2_cnn_new")
model.load(r"trained_models_cv_test/ppo_sf2_chunli_final")
obs = env.reset()
done = False

View File

@ -1,3 +1,6 @@
import os
import random
import gym
import cv2
import retro
@ -5,19 +8,33 @@ import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
import torch
import torch.nn as nn
from custom_cnn import CustomCNN
from custom_sf2_cv_env import StreetFighterCustomWrapper
class RandomOpponentChangeCallback(BaseCallback):
def __init__(self, stages, opponent_interval, save_dir, verbose=0):
super(RandomOpponentChangeCallback, self).__init__(verbose)
self.stages = stages
self.opponent_interval = opponent_interval
def _on_step(self) -> bool:
if self.n_calls % self.opponent_interval == 0:
new_state = random.choice(self.stages)
print("\nCurrent state:", new_state)
self.training_env.env_method("load_state", new_state, indices=None)
return True
def make_env(game, state, seed=0):
def _init():
win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE)
win_template = cv2.imread('images/pattern_win_gray.png', cv2.IMREAD_GRAYSCALE)
lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE)
env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
env = StreetFighterCustomWrapper(env, win_template, lose_template)
env.seed(seed)
# env.seed(seed)
return env
return _init
@ -25,15 +42,24 @@ def main():
# Set up the environment and model
game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
"Champion.Level1.ChunLiVsGuile",
"Champion.Level2.ChunLiVsKen",
"Champion.Level3.ChunLiVsChunLi",
"Champion.Level4.ChunLiVsZangief",
# "Champion.Level1.ChunLiVsGuile",
# "Champion.Level2.ChunLiVsKen",
# "Champion.Level3.ChunLiVsChunLi",
# "Champion.Level4.ChunLiVsZangief",
# "Champion.Level5.ChunLiVsDhalsim",
"Champion.Level6.ChunLiVsRyu",
"Champion.Level7.ChunLiVsEHonda",
"Champion.Level8.ChunLiVsBlanka",
"Champion.Level9.ChunLiVsBalrog",
"Champion.Level10.ChunLiVsVega",
"Champion.Level11.ChunLiVsSagat",
"Champion.Level12.ChunLiVsBison"
# Add other stages as necessary
]
num_envs = 8
# env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])
env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])
policy_kwargs = {
@ -46,7 +72,7 @@ def main():
device="cuda",
policy_kwargs=policy_kwargs,
verbose=1,
n_steps=2048,
n_steps=5400,
batch_size=64,
n_epochs=10,
learning_rate=0.0003,
@ -59,9 +85,25 @@ def main():
use_sde=False,
sde_sample_freq=-1
)
model.learn(total_timesteps=int(500000))
model.save("ppo_sf2_cnn_new")
# Set the save directory
save_dir = "trained_models_cv_level6up"
os.makedirs(save_dir, exist_ok=True)
# Set up callbacks
opponent_interval = 5400 # stage_interval * num_envs = total_steps_per_stage
checkpoint_interval = 54000 # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)
model.learn(
total_timesteps=int(6048000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
callback=[checkpoint_callback, stage_increase_callback]
)
# Save the final model
model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))
if __name__ == "__main__":
main()