2023-03-27 17:31:23 +00:00
|
|
|
import gym
|
|
|
|
import cv2
|
|
|
|
import retro
|
|
|
|
import numpy as np
|
|
|
|
from stable_baselines3 import PPO
|
|
|
|
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
|
|
|
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
|
2023-03-28 07:39:30 +00:00
|
|
|
from custom_cnn import CustomCNN
|
|
|
|
from custom_sf2_cv_env import StreetFighterCustomWrapper
|
2023-03-27 17:31:23 +00:00
|
|
|
|
|
|
|
def make_env(game, state, seed=0):
    """Build a zero-argument factory that creates one wrapped retro environment.

    Environment creation is deferred to the returned callable, so the
    emulator and the template images are only loaded when the callable is
    invoked.

    Args:
        game: Name of the retro game to load.
        state: Name of the saved game state to start from.
        seed: Value passed to the environment's ``seed`` method.

    Returns:
        A callable taking no arguments that returns the seeded environment
        wrapped in ``StreetFighterCustomWrapper``.
    """

    def _load_template(path):
        """Read a grayscale template image, failing loudly if it is unreadable."""
        template = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports a missing or undecodable file by returning None
        # rather than raising, so check here instead of letting a None
        # template reach the wrapper.
        if template is None:
            raise FileNotFoundError(f"Could not read template image: {path!r}")
        return template

    def _init():
        # Paths are relative to the current working directory.
        win_template = _load_template('images/pattern_wins_gray.png')
        lose_template = _load_template('images/pattern_lose_gray.png')

        env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
        env = StreetFighterCustomWrapper(env, win_template, lose_template)
        env.seed(seed)
        return env

    return _init
|
|
|
|
|
|
|
|
def main():
    """Train a PPO agent with the custom CNN feature extractor and save it.

    Creates ``num_envs`` environments in subprocesses, all on the first
    entry of ``state_stages``, trains for a fixed number of timesteps and
    writes the model to ``ppo_sf2_cnn``. The vectorized environment is
    always closed, even when training or saving raises.
    """
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
        "Champion.Level1.ChunLiVsGuile",
        "Champion.Level2.ChunLiVsKen",
        "Champion.Level3.ChunLiVsChunLi",
        "Champion.Level4.ChunLiVsZangief",
        # Add other stages as necessary
    ]

    num_envs = 8

    # Only the first stage is used for now; each environment gets its own seed.
    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])

    try:
        policy_kwargs = {
            'features_extractor_class': CustomCNN,
        }

        model = PPO(
            "CnnPolicy",
            env,
            device="cuda",
            policy_kwargs=policy_kwargs,
            verbose=1,
        )
        model.learn(total_timesteps=1000)

        model.save("ppo_sf2_cnn")
    finally:
        # Shut down the worker processes so they do not outlive a failed run.
        env.close()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run training only when this file is executed as a script, not when it
    # is imported. NOTE(review): main() starts worker processes via
    # SubprocVecEnv; with process start methods that re-import this module
    # the guard is presumably also what keeps workers from re-running main().
    main()
|