mirror of https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-03 22:50:43 +00:00
random training
This commit is contained in:
parent 41650747b0
commit dc4aad6e8a
Binary file not shown.
Binary file not shown.
@@ -2,6 +2,7 @@ import gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+from torchvision.models import mobilenet_v3_small

# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
@@ -20,4 +21,5 @@ class CustomCNN(BaseFeaturesExtractor):
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
-        return self.cnn(observations.permute(0, 3, 1, 2))  # Swap the channel dimension
+        return self.cnn(observations.permute(0, 3, 1, 2))  # Swap the channel dimension
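For orientation, a minimal sketch of the kind of feature extractor custom_cnn.py defines. Only the class name, the BaseFeaturesExtractor base class, the forward signature, and the NHWC-to-NCHW permute come from the diff above; the layer stack and sizes below are illustrative assumptions, not the repository's network.

import gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomCNN(BaseFeaturesExtractor):
    """Illustrative layer stack; the real network in custom_cnn.py may differ."""
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256):
        super().__init__(observation_space, features_dim)
        n_channels = observation_space.shape[-1]  # observations arrive channels-last (H, W, C)
        self.cnn = nn.Sequential(
            nn.Conv2d(n_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.LazyLinear(features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # Stable-Baselines3 feeds (N, H, W, C) images; Conv2d expects (N, C, H, W), hence the permute.
        return self.cnn(observations.permute(0, 3, 1, 2))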
@@ -4,7 +4,7 @@ import numpy as np

# Custom environment wrapper
class StreetFighterCustomWrapper(gym.Wrapper):
-    def __init__(self, env, win_template, lose_template, threshold=0.65):
+    def __init__(self, env, win_template, lose_template, testing=False, threshold=0.65):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.win_template = win_template
        self.lose_template = lose_template
@@ -18,24 +18,21 @@ class StreetFighterCustomWrapper(gym.Wrapper):
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(84, 84, 1), dtype=np.float32
        )

+        self.testing = testing

    def _preprocess_observation(self, observation):
        self.game_screen_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        # Print the size of self.game_screen_gray
        # print("self.game_screen_gray size: ", self.game_screen_gray.shape)
        # Print the size of the observation
        # print("Observation size: ", observation.shape)
        resized_image = cv2.resize(self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0
        return np.expand_dims(resized_image, axis=-1)

    def _check_game_over(self):
        win_res = cv2.matchTemplate(self.game_screen_gray, self.win_template, cv2.TM_CCOEFF_NORMED)
        lose_res = cv2.matchTemplate(self.game_screen_gray, self.lose_template, cv2.TM_CCOEFF_NORMED)
        if np.max(win_res) >= self.threshold:
            return True
        if np.max(lose_res) >= self.threshold:
            return True
        return False

    def _get_win_or_lose_bonus(self):
        if self.prev_player_health > self.prev_opponent_health:
            # print('You win!')
            return 200
        else:
            # print('You lose!')
            return -200

    def _get_reward(self):
        player_health_area = self.game_screen_gray[15:20, 32:120]
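The game-over check above relies on OpenCV template matching: each template image is slid over the grayscale frame and the best normalized correlation score is compared against the 0.65 threshold. A self-contained sketch of that mechanism; the helper name is made up, while the method, threshold, and template path match what the diff shows.

import cv2
import numpy as np

def screen_matches_template(gray_frame: np.ndarray, template: np.ndarray, threshold: float = 0.65) -> bool:
    # TM_CCOEFF_NORMED yields a normalized correlation map; its peak approaches 1.0
    # wherever the template (the round's win/lose banner) appears in the frame.
    res = cv2.matchTemplate(gray_frame, template, cv2.TM_CCOEFF_NORMED)
    return float(np.max(res)) >= threshold

# Hypothetical usage mirroring _check_game_over:
# template = cv2.imread('images/pattern_win_gray.png', cv2.IMREAD_GRAYSCALE)
# game_over = screen_matches_template(game_screen_gray, template)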
@@ -48,17 +45,13 @@ class StreetFighterCustomWrapper(gym.Wrapper):
        player_health_diff = self.prev_player_health - player_health
        opponent_health_diff = self.prev_opponent_health - opponent_health

-        reward = (opponent_health_diff - player_health_diff) * 100
-
-        # Add bonus for successful attacks or penalize for taking damage
-        if opponent_health_diff > player_health_diff:
-            reward += 10  # Bonus for successful attacks
-        elif opponent_health_diff < player_health_diff:
-            reward -= 10  # Penalty for taking damage
+        reward = (opponent_health_diff - player_health_diff) * 100  # max would be 100

        self.prev_player_health = player_health
        self.prev_opponent_health = opponent_health

        # Print the health values of the player and the opponent
        # print("Player health: %f Opponent health:%f" % (player_health, opponent_health))
        return reward

    def reset(self):
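The reward hunk above reads the health bars straight from fixed pixel rows of the grayscale frame and pays out the difference between damage dealt and damage taken. A standalone sketch of that idea: the player-bar slice comes from the diff, the opponent-bar coordinates and the bar-to-fraction conversion are assumptions for illustration.

import numpy as np

def health_fraction(bar_pixels: np.ndarray) -> float:
    # Rough proxy: fraction of bright pixels in the bar strip (a full bar is bright).
    return float((bar_pixels > 128).mean())

def shaped_reward(game_screen_gray, prev_player_health, prev_opponent_health):
    player_health = health_fraction(game_screen_gray[15:20, 32:120])     # player bar slice from the diff
    opponent_health = health_fraction(game_screen_gray[15:20, 136:224])  # opponent slice: placeholder coordinates
    player_health_diff = prev_player_health - player_health              # health the agent lost this step
    opponent_health_diff = prev_opponent_health - opponent_health        # health the opponent lost this step
    # With health in [0, 1] and bars only draining, this stays within +/-100 ("max would be 100").
    reward = (opponent_health_diff - player_health_diff) * 100
    return reward, player_health, opponent_health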
@@ -68,7 +61,17 @@ class StreetFighterCustomWrapper(gym.Wrapper):
        return self._preprocess_observation(observation)

    def step(self, action):
-        observation, _, _, info = self.env.step(action)
+        # observation, _, _, info = self.env.step(action)
+        observation, _reward, _done, info = self.env.step(action)
        custom_reward = self._get_reward()
-        custom_done = self._check_game_over() or False
+
+        custom_done = False
+        if self.prev_player_health <= 0.00001 or self.prev_opponent_health <= 0.00001:
+            custom_reward += self._get_win_or_lose_bonus()
+            if not self.testing:
+                custom_done = True
+            else:
+                self.prev_player_health = 1.0
+                self.prev_opponent_health = 1.0

        return self._preprocess_observation(observation), custom_reward, custom_done, info
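Net effect of the new testing flag, shown as a hypothetical usage sketch; retro_env and the template variables are assumed to be prepared the way the scripts further down prepare them.

# Training: a KO ends the episode, so PPO optimizes one round at a time.
train_env = StreetFighterCustomWrapper(retro_env, win_template, lose_template)

# Evaluation: the episode keeps running across rounds; the wrapper resets its health
# trackers to 1.0 instead of signalling done (this is what testing=True below relies on).
eval_env = StreetFighterCustomWrapper(retro_env, win_template, lose_template, testing=True)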
(image file: 292 B before, 292 B after; preview not shown)
@@ -13,11 +13,11 @@ from custom_sf2_cv_env import StreetFighterCustomWrapper

def make_env(game, state, seed=0):
    def _init():
-        win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE)
+        win_template = cv2.imread('images/pattern_win_gray.png', cv2.IMREAD_GRAYSCALE)
        lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE)
        env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
-        env = StreetFighterCustomWrapper(env, win_template, lose_template)
-        env.seed(seed)
+        env = StreetFighterCustomWrapper(env, win_template, lose_template, testing=True)
+        # env.seed(seed)
        return env
    return _init

@@ -27,6 +27,14 @@ state_stages = [
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
+    "Champion.Level5.ChunLiVsDhalsim",
+    "Champion.Level6.ChunLiVsRyu",
+    "Champion.Level7.ChunLiVsEHonda",
+    "Champion.Level8.ChunLiVsBlanka",
+    "Champion.Level9.ChunLiVsBalrog",
+    "Champion.Level10.ChunLiVsVega",
+    "Champion.Level11.ChunLiVsSagat",
+    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

@@ -45,7 +53,7 @@ model = PPO(
    policy_kwargs=policy_kwargs,
    verbose=1
)
-model.load("ppo_sf2_cnn_new")
+model.load(r"trained_models_cv_test/ppo_sf2_chunli_final")

obs = env.reset()
done = False
@@ -59,4 +67,4 @@ while True:
    if render_time < 0.0111:
        time.sleep(0.0111 - render_time)  # Add a delay for 90 FPS

-# env.close()
+# env.close()
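For context, a minimal evaluation loop of the shape this test script uses: build one wrapped env, load the trained PPO model, and pace rendering to roughly 90 FPS. The loop body, variable names, and single-env wiring below are assumptions; only the sleep logic and the state/model names echo the hunks above.

import time

env = make_env(game, state_stages[0])()   # call the _init thunk directly for a single env
obs = env.reset()
done = False
while True:
    start = time.time()
    action, _states = model.predict(obs)
    obs, reward, done, info = env.step(action)
    env.render()
    render_time = time.time() - start
    if render_time < 0.0111:               # ~1/90 s per frame
        time.sleep(0.0111 - render_time)   # add a delay to cap at ~90 FPS
    if done:
        obs = env.reset()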
@@ -1,3 +1,6 @@
+import os
+import random
+
import gym
import cv2
import retro
@@ -5,19 +8,33 @@ import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
+from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
import torch
import torch.nn as nn

from custom_cnn import CustomCNN
from custom_sf2_cv_env import StreetFighterCustomWrapper

+class RandomOpponentChangeCallback(BaseCallback):
+    def __init__(self, stages, opponent_interval, save_dir, verbose=0):
+        super(RandomOpponentChangeCallback, self).__init__(verbose)
+        self.stages = stages
+        self.opponent_interval = opponent_interval
+
+    def _on_step(self) -> bool:
+        if self.n_calls % self.opponent_interval == 0:
+            new_state = random.choice(self.stages)
+            print("\nCurrent state:", new_state)
+            self.training_env.env_method("load_state", new_state, indices=None)
+        return True
+
def make_env(game, state, seed=0):
    def _init():
-        win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE)
+        win_template = cv2.imread('images/pattern_win_gray.png', cv2.IMREAD_GRAYSCALE)
        lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE)
        env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
        env = StreetFighterCustomWrapper(env, win_template, lose_template)
-        env.seed(seed)
+        # env.seed(seed)
        return env
    return _init
@@ -25,15 +42,24 @@ def main():
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
-        "Champion.Level1.ChunLiVsGuile",
-        "Champion.Level2.ChunLiVsKen",
-        "Champion.Level3.ChunLiVsChunLi",
-        "Champion.Level4.ChunLiVsZangief",
+        # "Champion.Level1.ChunLiVsGuile",
+        # "Champion.Level2.ChunLiVsKen",
+        # "Champion.Level3.ChunLiVsChunLi",
+        # "Champion.Level4.ChunLiVsZangief",
+        # "Champion.Level5.ChunLiVsDhalsim",
+        "Champion.Level6.ChunLiVsRyu",
+        "Champion.Level7.ChunLiVsEHonda",
+        "Champion.Level8.ChunLiVsBlanka",
+        "Champion.Level9.ChunLiVsBalrog",
+        "Champion.Level10.ChunLiVsVega",
+        "Champion.Level11.ChunLiVsSagat",
+        "Champion.Level12.ChunLiVsBison"
        # Add other stages as necessary
    ]

    num_envs = 8

+    # env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])
    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])

    policy_kwargs = {
@@ -46,7 +72,7 @@ def main():
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
-        n_steps=2048,
+        n_steps=5400,
        batch_size=64,
        n_epochs=10,
        learning_rate=0.0003,
@@ -59,9 +85,25 @@ def main():
        use_sde=False,
        sde_sample_freq=-1
    )
-    model.learn(total_timesteps=int(500000))
-
-    model.save("ppo_sf2_cnn_new")
+    # Set the save directory
+    save_dir = "trained_models_cv_level6up"
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Set up callbacks
+    opponent_interval = 5400  # stage_interval * num_envs = total_steps_per_stage
+    checkpoint_interval = 54000  # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
+    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
+    stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)
+
+
+    model.learn(
+        total_timesteps=int(6048000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
+        callback=[checkpoint_callback, stage_increase_callback]
+    )
+
+    # Save the final model
+    model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))

if __name__ == "__main__":
    main()
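For reference, the interval arithmetic implied by the comments in this hunk, assuming num_envs = 8 as configured above and that the callback's n_calls advances once per vectorized step (i.e. once per num_envs frames). The n_steps change from 2048 to 5400 appears to align one PPO rollout per environment with one opponent interval.

num_envs = 8
n_steps = 5400                 # PPO rollout length per environment (apparently one 5,400-frame round)
opponent_interval = 5400       # callback steps between random opponent switches
checkpoint_interval = 54000    # callback steps between checkpoints

frames_per_opponent = opponent_interval * num_envs      # 43,200 frames per randomly chosen stage
frames_per_checkpoint = checkpoint_interval * num_envs  # 432,000 frames, i.e. 80 rounds of 5,400 frames
total_timesteps = 6_048_000
total_rounds = total_timesteps // n_steps               # 1,120 rounds, matching the "(1120 rounds)" comment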