mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-04 15:10:43 +00:00
add cv method
This commit is contained in:
parent
720b0488f7
commit
4a35b81937
51
custom_street_fighter_env.py
Normal file
51
custom_street_fighter_env.py
Normal file
@ -0,0 +1,51 @@
|
||||
import gym
|
||||
|
||||
# Create a custom environment for Street Fighter II
|
||||
class CustomStreetFighterEnv(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super(CustomStreetFighterEnv, self).__init__(env)
|
||||
self.previous_health = 0
|
||||
|
||||
def step(self, action):
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
|
||||
# Reward function
|
||||
custom_reward = self.custom_reward_function(info)
|
||||
|
||||
return observation, custom_reward, done, info
|
||||
|
||||
def reset(self):
|
||||
self.previous_health = 0
|
||||
return self.env.reset()
|
||||
|
||||
def custom_reward_function(self, info):
|
||||
# Reward weights
|
||||
health_weight = 1
|
||||
hit_weight = 2
|
||||
block_weight = 1
|
||||
knockdown_weight = 5
|
||||
|
||||
# Retrieve relevant information from info
|
||||
player_health = info["health1"]
|
||||
opponent_health = info["health2"]
|
||||
player_is_hit = info["is_hit1"]
|
||||
opponent_is_hit = info["is_hit2"]
|
||||
player_is_blocking = info["is_blocking1"]
|
||||
# opponent_is_blocking = info["is_blocking2"]
|
||||
player_is_knockdown = info["is_knockdown1"]
|
||||
opponent_is_knockdown = info["is_knockdown2"]
|
||||
|
||||
# Compute reward components
|
||||
health_reward = (player_health - opponent_health) * health_weight
|
||||
hit_reward = hit_weight if opponent_is_hit else 0
|
||||
block_reward = block_weight if player_is_blocking else 0
|
||||
knockdown_reward = knockdown_weight if opponent_is_knockdown else 0
|
||||
|
||||
# Penalty components
|
||||
hit_penalty = -hit_weight if player_is_hit else 0
|
||||
knockdown_penalty = -knockdown_weight if player_is_knockdown else 0
|
||||
|
||||
# Calculate total custom reward
|
||||
custom_reward = health_reward + hit_reward + block_reward + knockdown_reward + hit_penalty + knockdown_penalty
|
||||
|
||||
return custom_reward
|
@ -11,7 +11,7 @@ retro.data.Integrations.add_custom_path(rom_directory)
|
||||
|
||||
env = retro.RetroEnv(
|
||||
game='StreetFighterIISpecialChampionEdition-Genesis',
|
||||
state='Champion.Level1.ChunLiVsGuile'
|
||||
state='Champion.Level3.ChunLiVsChunLi'
|
||||
)
|
||||
# Champion.Level2.ChunLiVsKen
|
||||
# Champion.Level3.ChunLiVsChunLi
|
||||
|
92
train_cv_sf2_ai.py
Normal file
92
train_cv_sf2_ai.py
Normal file
@ -0,0 +1,92 @@
|
||||
import gym
|
||||
import cv2
|
||||
import retro
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Custom feature extractor (CNN)
|
||||
class CustomCNN(BaseFeaturesExtractor):
|
||||
def __init__(self, observation_space: gym.Space):
|
||||
super(CustomCNN, self).__init__(observation_space, features_dim=512)
|
||||
self.cnn = nn.Sequential(
|
||||
nn.Conv2d(1, 32, kernel_size=8, stride=4, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Flatten(),
|
||||
nn.Linear(3136, self.features_dim),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
return self.cnn(observations)
|
||||
|
||||
# Custom environment wrapper for preprocessing
|
||||
class CustomAtariWrapper(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
# self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
|
||||
|
||||
def _preprocess_observation(self, observation):
|
||||
observation = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
|
||||
return np.expand_dims(observation, axis=-1)
|
||||
|
||||
def reset(self):
|
||||
observation = self.env.reset()
|
||||
return self._preprocess_observation(observation)
|
||||
|
||||
def step(self, action):
|
||||
observation, reward, done, info = self.env.step(action)
|
||||
return self._preprocess_observation(observation), reward, done, info
|
||||
|
||||
def make_env(game, state, seed=0):
|
||||
def _init():
|
||||
env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
|
||||
env = CustomAtariWrapper(env)
|
||||
env.seed(seed)
|
||||
return env
|
||||
return _init
|
||||
|
||||
def main():
|
||||
|
||||
# Set up the environment and model
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
||||
state_stages = [
|
||||
"Champion.Level1.ChunLiVsGuile",
|
||||
"Champion.Level2.ChunLiVsKen",
|
||||
"Champion.Level3.ChunLiVsChunLi",
|
||||
"Champion.Level4.ChunLiVsZangief",
|
||||
# Add other stages as necessary
|
||||
]
|
||||
|
||||
num_envs = 8
|
||||
seed = 42
|
||||
|
||||
env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])
|
||||
|
||||
policy_kwargs = {
|
||||
'features_extractor_class': CustomCNN
|
||||
}
|
||||
|
||||
model = PPO(
|
||||
"CnnPolicy",
|
||||
env,
|
||||
device="cuda",
|
||||
policy_kwargs=policy_kwargs,
|
||||
verbose=1
|
||||
)
|
||||
model.learn(total_timesteps=int(1000))
|
||||
|
||||
model.save("ppo_sf2_cnn")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
# missing reward function
|
@ -76,12 +76,6 @@ def main():
|
||||
seed=None,
|
||||
)
|
||||
|
||||
|
||||
checkpoint_path = None
|
||||
if checkpoint_path is not None:
|
||||
model = model.load(checkpoint_path, env)
|
||||
|
||||
|
||||
# Set the save directory
|
||||
save_dir = "trained_models"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@ -101,3 +95,4 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
1344
training_history.txt
1344
training_history.txt
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user