# Mirror of https://github.com/linyiLYi/street-fighter-ai.git (synced 2025-04-04 23:20:43 +00:00)
import gym
import cv2
import retro
import numpy as np

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

import torch
import torch.nn as nn

# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space):
        super(CustomCNN, self).__init__(observation_space, features_dim=512)
        # Nature-CNN style stack over a single grayscale channel.
        # The flatten size of 3136 (= 64 * 7 * 7) assumes 84x84 input frames.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, self.features_dim),
            nn.ReLU()
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.cnn(observations)

# Custom environment wrapper for preprocessing
class CustomAtariWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        # Observations are reduced to 84x84 grayscale so they match the input
        # expected by CustomCNN (single channel, Linear(3136) flatten size).
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _preprocess_observation(self, observation):
        observation = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        observation = cv2.resize(observation, (84, 84), interpolation=cv2.INTER_AREA)
        return np.expand_dims(observation, axis=-1)

    def reset(self):
        observation = self.env.reset()
        return self._preprocess_observation(observation)

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        return self._preprocess_observation(observation), reward, done, info

def make_env(game, state, seed=0):
    def _init():
        env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
        env = CustomAtariWrapper(env)
        env.seed(seed)
        return env
    return _init

def main():
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
        "Champion.Level1.ChunLiVsGuile",
        "Champion.Level2.ChunLiVsKen",
        "Champion.Level3.ChunLiVsChunLi",
        "Champion.Level4.ChunLiVsZangief",
        # Add other stages as necessary
    ]

    num_envs = 8
    seed = 42

    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])

    policy_kwargs = {
        'features_extractor_class': CustomCNN
    }

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1
    )
    model.learn(total_timesteps=1000)

    model.save("ppo_sf2_cnn")

if __name__ == "__main__":
    main()

# Missing reward function: the script still trains on the raw environment reward;
# a possible shaping wrapper is sketched below.
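
# A minimal sketch of a reward-shaping wrapper, assuming the retro integration
# exposes both fighters' health bars through `info`. The key names "health" and
# "enemy_health" are assumptions; they must match the variables defined in the
# game's data.json.
class RewardShapingWrapper(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.prev_health = None
        self.prev_enemy_health = None

    def reset(self):
        # Forget the previous round's health values on reset.
        self.prev_health = None
        self.prev_enemy_health = None
        return self.env.reset()

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        health = info.get("health")
        enemy_health = info.get("enemy_health")
        if health is not None and enemy_health is not None:
            if self.prev_health is not None:
                # Reward damage dealt to the opponent, penalize damage taken.
                reward = (self.prev_enemy_health - enemy_health) - (self.prev_health - health)
            self.prev_health = health
            self.prev_enemy_health = enemy_health
        return observation, reward, done, info

# Usage sketch: wrap it around the preprocessed environment inside make_env, e.g.
#     env = RewardShapingWrapper(CustomAtariWrapper(env))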