ram_based_image_stack

This commit is contained in:
linyiLYi 2023-03-31 02:10:25 +08:00
parent d4fb6dbc59
commit 02e39f0a52
64 changed files with 13545 additions and 17 deletions

View File

@ -0,0 +1,234 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "10d267bb",
"metadata": {},
"outputs": [],
"source": [
"import retro"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1ef8ff20",
"metadata": {},
"outputs": [],
"source": [
"game = \"StreetFighterIISpecialChampionEdition-Genesis\"\n",
"state = \"Champion.Level1.ChunLiVsGuile\"\n",
"env = retro.make(game=game, state=state)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5ce656b8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1], dtype=int8)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.action_space.sample()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8c3f0a4d",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(200, 256, 3)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.observation_space.sample().shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "46db7b05",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(200, 256, 3)\n",
"{'enemy_matches_won': 0, 'score': 0, 'matches_won': 0, 'continuetimer': 0, 'enemy_health': 176, 'health': 176}\n"
]
}
],
"source": [
"observation = env.reset()\n",
"print(observation.shape)\n",
"\n",
"action = env.action_space.sample()\n",
"obs, rewards, done, info = env.step(action)\n",
"print(info)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "09f0c6b0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MultiBinary(12)\n"
]
}
],
"source": [
"from gym.spaces import Box, MultiBinary\n",
"\n",
"print(MultiBinary(12))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "97df18cf",
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"from gym.spaces import Box, MultiBinary\n",
"\n",
"class StreetFighter(gym.Env):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.observation_space = Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)\n",
" self.action_space = MultiBinary(12)\n",
" self.game = retro.make(game=\"StreetFighterIISpecialChampionEdition-Genesis\", use_restricted_actions=retro.Actions.FILTERED)\n",
" \n",
" self.full_hp = 176\n",
" self.player_health = self.full_hp\n",
" self.oppont_health = self.full_hp\n",
" \n",
" self.score = 0\n",
" \n",
" def __preprocess(self, observation):\n",
" gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)\n",
" resize = cv2.resize(gray, (84,84), interpolation=cv2.INTER_CUBIC)\n",
" return resize\n",
"\n",
" def step(self, action):\n",
"\n",
" obs, reward, done, info = self.game.step(action)\n",
" custom_obs = self.__preprocess(obs) # It's just frame, not frame_delta\n",
"\n",
" # During fighting, either player or opponent has positive health points.\n",
" if info['health'] > 0 or info['enemy_health'] > 0:\n",
"\n",
" # Player Loses\n",
" if info['health'] < 0 and info['health'] != self.player_health and info['enemy_health'] != 0:\n",
" reward = (-self.full_hp) * info['enemy_health']\n",
"\n",
" # Player Wins\n",
" elif info['enemy_health'] < 0 and info['enemy_health'] != self.oppont_health and info['health'] != 0:\n",
" reward = self.full_hp * info['health']\n",
"\n",
" # During Fighting\n",
" else:\n",
" reward = (self.oppont_health - info['enemy_health']) - (self.player_health - info['health'])\n",
" \n",
" self.player_health = info['health']\n",
" self.oppont_health = info['enemy_health']\n",
" \n",
" return custom_obs, reward, done, info\n",
" \n",
" def render(self, *args, **kwargs):\n",
" self.game.render()\n",
" \n",
" def reset(self):\n",
" obs = self.game.reset()\n",
" custom_obs = self.__preprocess(obs)\n",
" self.previous_frame = obs\n",
" \n",
" self.player_health = self.full_hp\n",
" self.oppont_health = self.full_hp\n",
" return custom_obs\n",
"\n",
" def close(self):\n",
" self.game.close()\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "0b137b88",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(84, 84, 1)\n"
]
}
],
"source": [
"env.close()\n",
"env = StreetFighter()\n",
"print(env.observation_space.shape)\n",
"env.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2da50dbc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,46 @@
import time

import retro
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state = "Champion.Level1.ChunLiVsGuile"  # "ChampionX.Level1.ChunLiVsKen"

env = make_env(game, state)()
env = Monitor(env, 'logs/')

num_episodes = 30
episode_reward_sum = 0
for _ in range(num_episodes):
    done = False
    obs = env.reset()
    total_reward = 0
    while not done:
        timestamp = time.time()
        obs, reward, done, info = env.step(env.action_space.sample())

        if reward != 0:
            total_reward += reward
            print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
        env.render()

    print("Total reward: {}".format(total_reward))
    episode_reward_sum += total_reward

env.close()
print("Average reward for random strategy: {}".format(episode_reward_sum/num_episodes))

View File

@ -0,0 +1,52 @@
import retro
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy

from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

# Wrap the environment
# env = Monitor(env, 'logs/')

policy_kwargs = {'features_extractor_class': CustomCNN}
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
model = PPO.load(r"dummy_model_ppo_chunli")
# model.load(r"trained_models/ppo_chunli_864000_steps")

mean_reward, std_reward = evaluate_policy(model, env, render=True, n_eval_episodes=10, deterministic=False, return_episode_rewards=True)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

View File

@ -0,0 +1,53 @@
#{"t_start": 1680186251.3110938, "env_id": null}
r,l,t
-50,2150,3.695703
-40,2886,12.564373
-128,2196,20.599987
-217,3000,25.620172
-210,2753,34.631877
27,2177,42.807461
-161,2502,46.870715
-227,2122,54.492589
-289,1567,61.321581
1,2075,64.463465
130,2465,72.662509
-192,3007,82.093462
3927.0,6468,97.611361
-109,1823,104.996175
200,1820,112.333123
-300,2478,116.020238
-42,2351,124.010789
-263,1990,127.212089
-351,1486,134.405471
-225,2611,143.112158
-56,3290,153.69294
-65,2138,157.640509
62,3161,167.244644
-189,2652,175.720904
224,2138,179.193385
-48,3706,189.4923
-209,3172,199.319699
-98,2059,207.148574
51,2787,216.523835
-88,3218,225.952495
-263,1828,228.707771
-38,2328,236.642072
7,3179,245.83899
-133,2421,249.558141
-296,1684,256.702009
-211,2881,266.1996
-261,1710,269.33675
-176,1974,277.229695
184,1310,279.58493
218,2222,288.236686
-229,2460,291.904952
-345,2510,299.876746
-345,2510,302.781091
-345,2510,305.701696
-345,2510,308.687105
-345,2510,311.624716
-345,2510,314.566203
-345,2510,317.608539
-345,2510,320.618201
-345,2510,323.649133
-345,2510,326.561072

File diff suppressed because it is too large

View File

@ -0,0 +1,93 @@
import torch
from torch.optim import Optimizer

class RMSpropTF(Optimizer):
    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10,
                 weight_decay=0, momentum=0., centered=False,
                 decoupled_decay=False, lr_in_momentum=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps,
                        centered=centered, weight_decay=weight_decay,
                        decoupled_decay=decoupled_decay,
                        lr_in_momentum=lr_in_momentum)
        super(RMSpropTF, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RMSpropTF, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('momentum', 0)
            group.setdefault('centered', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(p)  # PyTorch inits to zero
                    if group['momentum'] > 0:
                        state['momentum_buffer'] = torch.zeros_like(p)
                    if group['centered']:
                        state['grad_avg'] = torch.zeros_like(p)

                square_avg = state['square_avg']
                one_minus_alpha = 1. - group['alpha']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    if group['decoupled_decay']:
                        p.mul_(1. - group['lr'] * group['weight_decay'])
                    else:
                        grad = grad.add(p, alpha=group['weight_decay'])

                # Tensorflow order of ops for updating squared avg
                square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
                # square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)  # PyTorch original

                if group['centered']:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
                    avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_()  # eps in sqrt
                    # grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha)  # PyTorch original
                else:
                    avg = square_avg.add(group['eps']).sqrt_()  # eps moved in sqrt

                if group['momentum'] > 0:
                    buf = state['momentum_buffer']
                    # Tensorflow accumulates the LR scaling in the momentum buffer
                    if group['lr_in_momentum']:
                        buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
                        p.add_(-buf)
                    else:
                        # PyTorch scales the param update by LR
                        buf.mul_(group['momentum']).addcdiv_(grad, avg)
                        p.add_(buf, alpha=-group['lr'])
                else:
                    p.addcdiv_(grad, avg, value=-group['lr'])

        return loss
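
A minimal usage sketch (not part of this commit): the training scripts below import this file as rmsprop_optim and only reference RMSpropTF in a commented-out A2C call, so the snippet assumes that module name and uses a placeholder CartPole environment purely to show how the optimizer plugs into a Stable-Baselines3 policy via policy_kwargs.

import gym
from stable_baselines3 import A2C

from rmsprop_optim import RMSpropTF  # the optimizer defined above

# Placeholder environment; the Street Fighter scripts would pass their wrapped retro env instead.
env = gym.make("CartPole-v1")

# Swap the policy's default optimizer for the TensorFlow-style RMSprop above.
model = A2C(
    "MlpPolicy",
    env,
    policy_kwargs=dict(optimizer_class=RMSpropTF, optimizer_kwargs=dict(alpha=0.9, eps=1e-10)),
    verbose=1,
)
model.learn(total_timesteps=1000)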

View File

@ -0,0 +1,97 @@
import collections

import gym
import cv2
import numpy as np

# Custom environment wrapper
class StreetFighterCustomWrapper(gym.Wrapper):
    def __init__(self, env, testing=False):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.env = env

        # Use a deque to store the last [num_frames] frames
        self.num_frames = 3
        self.frame_stack = collections.deque(maxlen=self.num_frames)

        self.full_hp = 176
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        # Update observation space to include stacked grayscale images
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)

        self.testing = testing

    def _preprocess_observation(self, observation):
        obs_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        obs_gray_resized = cv2.resize(obs_gray, (84, 84), interpolation=cv2.INTER_AREA)
        # Add the resized image to the frame stack
        self.frame_stack.append(obs_gray_resized)

        # Stack the frames and return the "image"
        stacked_frames = np.stack(self.frame_stack, axis=-1)
        return stacked_frames

    def reset(self):
        observation = self.env.reset()
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        obs_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        obs_gray_resized = cv2.resize(obs_gray, (84, 84), interpolation=cv2.INTER_AREA)

        # Clear the frame stack and add the first observation [num_frames] times
        self.frame_stack.clear()
        for _ in range(self.num_frames):
            self.frame_stack.append(obs_gray_resized)

        return np.stack(self.frame_stack, axis=-1)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        # During fighting, either player or opponent has positive health points.
        if info['health'] > 0 or info['enemy_health'] > 0:

            # Player Loses
            if info['health'] < 0 and info['enemy_health'] > 0:
                # reward = (-self.full_hp) * info['enemy_health'] * 0.05 # max = 0.05 * 176 * 176 = 1548.8
                reward = -info['enemy_health']  # Use the left over health points as penalty

                # Prevent data overflow
                if reward < -self.full_hp:
                    reward = 0

                done = True

            # Player Wins
            elif info['enemy_health'] < 0 and info['health'] > 0:
                # reward = self.full_hp * info['health'] * 0.05
                reward = info['health']

                # Prevent data overflow
                if reward > self.full_hp:
                    reward = 0

                done = True

            # During Fighting
            else:
                reward = (self.prev_oppont_health - info['enemy_health']) - (self.prev_player_health - info['health'])

                # Prevent data overflow
                if reward > 99:
                    reward = 0

                self.prev_player_health = info['health']
                self.prev_oppont_health = info['enemy_health']

        if self.testing:
            done = False

        return self._preprocess_observation(obs), reward, done, info
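
A quick sanity-check sketch (not part of this commit) for the wrapper above; it assumes the Genesis ROM for this game has been imported into gym-retro and simply confirms that observations come back as an (84, 84, 3) stack of the last three grayscale frames.

import retro

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

env = retro.make(
    game="StreetFighterIISpecialChampionEdition-Genesis",
    state="Champion.Level1.ChunLiVsGuile",
    use_restricted_actions=retro.Actions.FILTERED,
    obs_type=retro.Observations.IMAGE,
)
env = StreetFighterCustomWrapper(env)

obs = env.reset()
print(obs.shape)  # (84, 84, 3)

obs, reward, done, info = env.step(env.action_space.sample())
print(obs.shape, reward, done, info['health'], info['enemy_health'])
env.close()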

View File

@ -0,0 +1,314 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bfc79b8c",
"metadata": {},
"outputs": [],
"source": [
"import retro"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c24fbcab",
"metadata": {},
"outputs": [],
"source": [
"game = \"StreetFighterIISpecialChampionEdition-Genesis\"\n",
"state = \"Champion.Level1.ChunLiVsGuile\"\n",
"env = retro.make(game=game, state=state)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "59839d9c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1], dtype=int8)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.action_space.sample()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e068cb0a",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(200, 256, 3)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.observation_space.sample().shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "1cb0297f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(200, 256, 3)\n",
"{'enemy_matches_won': 0, 'score': 0, 'matches_won': 0, 'continuetimer': 0, 'enemy_health': 176, 'health': 176}\n"
]
}
],
"source": [
"observation = env.reset()\n",
"print(observation.shape)\n",
"\n",
"action = env.action_space.sample()\n",
"obs, rewards, done, info = env.step(action)\n",
"print(info)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "0eaa5cc8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MultiBinary(12)\n"
]
}
],
"source": [
"from gym.spaces import Box, MultiBinary\n",
"\n",
"print(MultiBinary(12))"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "49f6cf5c",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"\n",
"import gym\n",
"import numpy as np\n",
"from gym.spaces import Box, MultiBinary\n",
"\n",
"class StreetFighter(gym.Env):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.observation_space = Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)\n",
" self.action_space = MultiBinary(12)\n",
" self.game = retro.make(game=\"StreetFighterIISpecialChampionEdition-Genesis\", use_restricted_actions=retro.Actions.FILTERED)\n",
" \n",
" self.full_hp = 176\n",
" self.player_health = self.full_hp\n",
" self.oppont_health = self.full_hp\n",
" \n",
" self.score = 0\n",
" \n",
" def __preprocess(self, observation):\n",
" gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)\n",
" resize = cv2.resize(gray, (84,84), interpolation=cv2.INTER_CUBIC)\n",
" return resize\n",
"\n",
" def step(self, action):\n",
"\n",
" obs, reward, done, info = self.game.step(action)\n",
" custom_obs = self.__preprocess(obs) # It's just frame, not frame_delta\n",
"\n",
" # During fighting, either player or opponent has positive health points.\n",
" if info['health'] > 0 or info['enemy_health'] > 0:\n",
"\n",
" # Player Loses\n",
" if info['health'] < 0 and info['health'] != self.player_health and info['enemy_health'] != 0:\n",
" reward = (-self.full_hp) * info['enemy_health']\n",
"\n",
" # Player Wins\n",
" elif info['enemy_health'] < 0 and info['enemy_health'] != self.oppont_health and info['health'] != 0:\n",
" reward = self.full_hp * info['health']\n",
"\n",
" # During Fighting\n",
" else:\n",
" reward = (self.oppont_health - info['enemy_health']) - (self.player_health - info['health'])\n",
" \n",
" self.player_health = info['health']\n",
" self.oppont_health = info['enemy_health']\n",
" \n",
" return custom_obs, reward, done, info\n",
" \n",
" def render(self, *args, **kwargs):\n",
" self.game.render()\n",
" \n",
" def reset(self):\n",
" obs = self.game.reset()\n",
" custom_obs = self.__preprocess(obs)\n",
" self.previous_frame = obs\n",
" \n",
" self.player_health = self.full_hp\n",
" self.oppont_health = self.full_hp\n",
" return custom_obs\n",
"\n",
" def close(self):\n",
" self.game.close()\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "6ec30177",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(84, 84)\n"
]
}
],
"source": [
"env.close()\n",
"env = StreetFighter()\n",
"print(env.observation_space.shape)\n",
"env.close()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "7d9eab3a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\ProgramData\\Anaconda3\\envs\\StreetFighterAI\\lib\\site-packages\\pyglet\\image\\codecs\\wic.py:289: UserWarning: [WinError -2147417850] Cannot change thread mode after it is set\n",
" warnings.warn(str(err))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"-22 154 176\n",
"-32 122 176\n",
"29 122 147\n",
"7 122 140\n",
"-31 91 140\n",
"29 91 111\n",
"-23 68 111\n",
"-24 44 111\n",
"-24 20 111\n",
"31 20 80\n",
"10 20 70\n",
"45 20 25\n",
"5 20 20\n",
"-15 5 20\n",
"19 5 1\n",
"-176 -1 1\n",
"46 176 130\n",
"7 176 123\n",
"-24 152 123\n",
"29 152 94\n",
"-24 128 94\n",
"7 128 87\n",
"39 128 48\n",
"-31 97 48\n",
"36 97 12\n",
"-24 73 12\n",
"-24 49 12\n",
"8624 49 -1\n",
"39 176 137\n",
"-24 152 137\n",
"-23 129 137\n",
"-23 106 137\n",
"-26 80 137\n",
"-24 56 137\n",
"-23 33 137\n",
"-21 12 137\n",
"-12 0 137\n",
"-24112 -1 137\n"
]
}
],
"source": [
"## Checking Rewards functionality\n",
"import time\n",
"\n",
"env = StreetFighter()\n",
"obs = env.reset()\n",
"done = False\n",
"\n",
"for game in range(5):\n",
" while not done:\n",
" if done:\n",
" obs = env.reset()\n",
" env.render()\n",
" obs, reward, done, info = env.step(env.action_space.sample())\n",
" if reward != 0:\n",
" print(reward, info['health'], info['enemy_health'])\n",
" time.sleep(0.01)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1ae8310",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,69 @@
import time

import retro
from stable_baselines3 import PPO

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",  # Average reward for random strategy: -102.3
    "ChampionX.Level1.ChunLiVsKen",   # Average reward for random strategy: -247.6
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

model = PPO(
    "CnnPolicy",
    env,
    verbose=1
)
model_path = r"optuna/trial_1_best_model"  # Average reward for optuna/trial_1_best_model: -82.3
model.load(model_path)

obs = env.reset()
done = False

num_episodes = 30
episode_reward_sum = 0
for _ in range(num_episodes):
    done = False
    obs = env.reset()
    total_reward = 0
    while not done:
        timestamp = time.time()
        obs, reward, done, info = env.step(env.action_space.sample())

        if reward != 0:
            total_reward += reward
            print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
        env.render()

    print("Total reward: {}".format(total_reward))
    episode_reward_sum += total_reward

env.close()
print("Average reward for {}: {}".format(model_path, episode_reward_sum/num_episodes))

View File

@ -0,0 +1,125 @@
import os
import random

import retro
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback

from rmsprop_optim import RMSpropTF
from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper

class RandomOpponentChangeCallback(BaseCallback):
    def __init__(self, stages, opponent_interval, verbose=0):
        super(RandomOpponentChangeCallback, self).__init__(verbose)
        self.stages = stages
        self.opponent_interval = opponent_interval

    def _on_step(self) -> bool:
        if self.n_calls % self.opponent_interval == 0:
            new_state = random.choice(self.stages)
            print("\nCurrent state:", new_state)
            self.training_env.env_method("load_state", new_state, indices=None)
        return True

def make_env(game, state, seed=0):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        env.seed(seed)
        return env
    return _init

def main():
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
        "ChampionX.Level1.ChunLiVsKen",
        "ChampionX.Level2.ChunLiVsChunLi",
        "ChampionX.Level3.ChunLiVsZangief",
        "ChampionX.Level4.ChunLiVsDhalsim",
        "ChampionX.Level5.ChunLiVsRyu",
        "ChampionX.Level6.ChunLiVsEHonda",
        "ChampionX.Level7.ChunLiVsBlanka",
        "ChampionX.Level8.ChunLiVsGuile",
        "ChampionX.Level9.ChunLiVsBalrog",
        "ChampionX.Level10.ChunLiVsVega",
        "ChampionX.Level11.ChunLiVsSagat",
        "ChampionX.Level12.ChunLiVsBison"
        # Add other stages as necessary
    ]
    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.

    num_envs = 8

    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])

    # Using CustomCNN as the feature extractor
    policy_kwargs = {
        'features_extractor_class': CustomCNN
    }

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=5400,
        batch_size=64,
        learning_rate=0.0001,
        ent_coef=0.01,
        clip_range=0.2,
        gamma=0.99,
        gae_lambda=0.95,
        tensorboard_log="logs/"
    )

    # Set the save directory
    save_dir = "trained_models"
    os.makedirs(save_dir, exist_ok=True)

    # Load the model from file
    # model_path = "trained_models/ppo_chunli_1296000_steps.zip"

    # Load model and modify the learning rate and entropy coefficient
    # custom_objects = {
    #     "learning_rate": 0.0002
    # }
    # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)

    # Set up callbacks
    opponent_interval = 5400  # stage_interval * num_envs = total_steps_per_stage
    checkpoint_interval = 54000  # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
    stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)

    # model_params = {
    #     'n_steps': 5,
    #     'gamma': 0.99,
    #     'gae_lambda': 1,
    #     'learning_rate': 7e-4,
    #     'vf_coef': 0.5,
    #     'ent_coef': 0.0,
    #     'max_grad_norm': 0.5,
    #     'rms_prop_eps': 1e-05
    # }
    # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))

    model.learn(
        total_timesteps=int(6048000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
        callback=[checkpoint_callback, stage_increase_callback]
    )
    env.close()

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))

if __name__ == "__main__":
    main()

View File

@ -0,0 +1,81 @@
import gym
import retro
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy

from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state, seed=0):
    def _init():
        env = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        env = Monitor(env)
        env.seed(seed)
        return env
    return _init

def objective(trial):
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    env = make_env(game, state="ChampionX.Level1.ChunLiVsKen")()

    # Suggest hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 5e-5, 1e-3, log=True)
    n_steps = trial.suggest_int("n_steps", 256, 8192, log=True)
    batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
    gamma = trial.suggest_float("gamma", 0.9, 0.9999)
    gae_lambda = trial.suggest_float("gae_lambda", 0.9, 1.0)
    clip_range = trial.suggest_float("clip_range", 0.1, 0.4)
    ent_coef = trial.suggest_float("ent_coef", 1e-4, 1e-2, log=True)
    vf_coef = trial.suggest_float("vf_coef", 0.1, 1.0)

    # Using CustomCNN as the feature extractor
    policy_kwargs = {
        'features_extractor_class': CustomCNN
    }

    # Train the model
    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=n_steps,
        batch_size=batch_size,
        learning_rate=learning_rate,
        ent_coef=ent_coef,
        clip_range=clip_range,
        vf_coef=vf_coef,
        gamma=gamma,
        gae_lambda=gae_lambda
    )

    for iteration in range(10):
        model.learn(total_timesteps=100000)
        mean_reward, _std_reward = evaluate_policy(model, env, n_eval_episodes=10)
        trial.report(mean_reward, iteration)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return mean_reward

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100, timeout=7200)  # Run optimization for 100 trials or 2 hours, whichever comes first

print("Best trial:")
trial = study.best_trial
print(" Value: ", trial.value)
print(" Params: ")
for key, value in trial.params.items():
    print(f"{key}: {value}")

View File

@ -0,0 +1,69 @@
import os

import retro
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

LOG_DIR = 'logs/'
OPT_DIR = 'optuna/'
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(OPT_DIR, exist_ok=True)

def optimize_ppo(trial):
    return {
        'n_steps': trial.suggest_int('n_steps', 1024, 8192, log=True),
        'gamma': trial.suggest_float('gamma', 0.9, 0.9999),
        'learning_rate': trial.suggest_float('learning_rate', 5e-5, 1e-4, log=True),
        'clip_range': trial.suggest_float('clip_range', 0.1, 0.4),
        'gae_lambda': trial.suggest_float('gae_lambda', 0.8, 0.99)
    }

def make_env(game, state):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

def optimize_agent(trial):
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state = "Champion.Level1.ChunLiVsGuile"  # "ChampionX.Level1.ChunLiVsKen"
    try:
        model_params = optimize_ppo(trial)

        # Create environment
        env = make_env(game, state)()
        env = Monitor(env, LOG_DIR)

        # Create algo
        model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, **model_params)
        model.learn(total_timesteps=100000)

        # Evaluate model
        mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=30)
        env.close()

        SAVE_PATH = os.path.join(OPT_DIR, 'trial_{}_best_model'.format(trial.number))
        model.save(SAVE_PATH)

        return mean_reward

    except Exception as e:
        return -1

# Creating the experiment
study = optuna.create_study(direction='maximize')
study.optimize(optimize_agent, n_trials=10, n_jobs=1)

print(study.best_params)
print(study.best_trial)

View File

@ -0,0 +1,39 @@
import time

import retro
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env, testing=True)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state = "Champion.Level1.ChunLiVsGuile"

env = make_env(game, state)()

model = PPO.load(r"trained_models_continued/ppo_chunli_6048000_steps")

obs = env.reset()
done = False

while not done:
    timestamp = time.time()
    action, _ = model.predict(obs)
    obs, reward, done, info = env.step(action)
    print(info)
    if reward != 0:
        print(reward, info['health'], info['enemy_health'])
    env.render()

env.close()

View File

@ -0,0 +1,24 @@
import gym
import torch
import torch.nn as nn

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space):
        super(CustomCNN, self).__init__(observation_space, features_dim=512)
        self.cnn = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            # For 84x84 inputs the three convolutions yield 64 maps of 16x16,
            # i.e. 64 * 16 * 16 = 16384 flattened features.
            nn.Linear(16384, self.features_dim),
            nn.ReLU()
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.cnn(observations)
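
For context (a sketch, not part of this file): the 4-input-channel first convolution expects a 4-frame stack, e.g. as produced by SB3's VecFrameStack(env, 4, channels_order='last') used elsewhere in this commit, which the CnnPolicy transposes to channels-first before calling forward. A minimal shape check under that assumption, importing the extractor as custom_cnn the way the training scripts do:

import gym
import torch

from custom_cnn import CustomCNN

# Dummy observation space standing in for a 4-frame 84x84 grayscale stack (channels first).
obs_space = gym.spaces.Box(low=0, high=255, shape=(4, 84, 84), dtype="uint8")
extractor = CustomCNN(obs_space)

dummy = torch.zeros((1, 4, 84, 84))
print(extractor(dummy).shape)  # torch.Size([1, 512])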

View File

@ -0,0 +1,47 @@
import retro
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

# Wrap the environment
env = Monitor(env, 'logs/')
env = DummyVecEnv([lambda: env])

model = PPO.load('trained_models/ppo_chunli_1296000_steps')

mean_reward, std_reward = evaluate_policy(model, env, render=True, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

View File

@ -0,0 +1,12 @@
#{"t_start": 1680163278.6497958, "env_id": null}
r,l,t
-1115.766667,2842,13.829476
-1115.766667,2842,22.367655
-1115.766667,2842,32.010939
-1115.766667,2842,41.401216
-1115.766667,2842,50.451062
-1115.766667,2842,59.522487
-1115.766667,2842,68.723222
-1115.766667,2842,78.205462
-1115.766667,2842,88.455592
-1115.766667,2842,97.656297

View File

@ -12,8 +12,6 @@ class StreetFighterCustomWrapper(gym.Wrapper):
    def __init__(self, env, testing=False, threshold=0.65):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.action_space = MultiBinary(12)
        # Use a deque to store the last 4 frames
        self.frame_stack = collections.deque(maxlen=4)
@ -89,7 +87,7 @@ class StreetFighterCustomWrapper(gym.Wrapper):
    def step(self, action):
        # observation, _, _, info = self.env.step(action)
        observation, _reward, _done, info = self.env.step(self.env.action_space.sample())
        observation, _reward, _done, info = self.env.step(action)
        custom_reward = self._get_reward()
        custom_reward -= 1.0 / 60.0  # penalty for each step (-1 points per second)

View File

@ -53,7 +53,7 @@ model = PPO(
    policy_kwargs=policy_kwargs,
    verbose=1
)
model.load(r"trained_models_continued/ppo_chunli_432000_steps")
model.load(r"trained_models/ppo_chunli_1296000_steps")
obs = env.reset()
done = False

View File

@ -1,13 +1,9 @@
import os
import random
import gym
import cv2
import retro
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from custom_cnn import CustomCNN
@ -77,20 +73,16 @@ def main():
        verbose=1,
        n_steps=5400,
        batch_size=64,
        n_epochs=10,
        learning_rate=0.0003,
        ent_coef=0.01,
        clip_range=0.2,
        clip_range_vf=None,
        gamma=0.99,
        gae_lambda=0.95,
        max_grad_norm=0.5,
        use_sde=False,
        sde_sample_freq=-1
        tensorboard_log="logs/"
    )
    # Set the save directory
    save_dir = "trained_models_continued"
    save_dir = "trained_models_continued_new"
    os.makedirs(save_dir, exist_ok=True)
    # Load the model from file
@ -99,8 +91,7 @@ def main():
    # Load model and modify the learning rate and entropy coefficient
    custom_objects = {
        "learning_rate": 0.00005,
        "ent_coef": 0.2
        "learning_rate": 0.0001
    }
    model = PPO.load(model_path, env=env, device="cuda", custom_objects=custom_objects)
@ -110,7 +101,6 @@ def main():
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
    stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)
    model.learn(
        total_timesteps=int(6048000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
        callback=[checkpoint_callback, stage_increase_callback]

File diff suppressed because it is too large

View File

@ -0,0 +1,81 @@
import gym
import retro
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy

from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state, seed=0):
    def _init():
        env = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        env = Monitor(env)
        env.seed(seed)
        return env
    return _init

def objective(trial):
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    env = make_env(game, state="ChampionX.Level1.ChunLiVsKen")()

    # Suggest hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 5e-5, 1e-3, log=True)
    n_steps = trial.suggest_int("n_steps", 256, 8192, log=True)
    batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
    gamma = trial.suggest_float("gamma", 0.9, 0.9999)
    gae_lambda = trial.suggest_float("gae_lambda", 0.9, 1.0)
    clip_range = trial.suggest_float("clip_range", 0.1, 0.4)
    ent_coef = trial.suggest_float("ent_coef", 1e-4, 1e-2, log=True)
    vf_coef = trial.suggest_float("vf_coef", 0.1, 1.0)

    # Using CustomCNN as the feature extractor
    policy_kwargs = {
        'features_extractor_class': CustomCNN
    }

    # Train the model
    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=n_steps,
        batch_size=batch_size,
        learning_rate=learning_rate,
        ent_coef=ent_coef,
        clip_range=clip_range,
        vf_coef=vf_coef,
        gamma=gamma,
        gae_lambda=gae_lambda
    )

    for iteration in range(10):
        model.learn(total_timesteps=100000)
        mean_reward, _std_reward = evaluate_policy(model, env, n_eval_episodes=10)
        trial.report(mean_reward, iteration)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return mean_reward

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100, timeout=7200)  # Run optimization for 100 trials or 2 hours, whichever comes first

print("Best trial:")
trial = study.best_trial
print(" Value: ", trial.value)
print(" Params: ")
for key, value in trial.params.items():
    print(f"{key}: {value}")

View File

@ -0,0 +1,25 @@
import gym
import torch
import torch.nn as nn

from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space):
        super(CustomCNN, self).__init__(observation_space, features_dim=512)
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            # For 84x84 inputs the convolutions produce 64 maps of 16x16 (64 * 16 * 16 = 16384).
            nn.Linear(16384, self.features_dim),
            nn.ReLU()
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        # Add a channel dimension so Conv2d receives (N, 1, 84, 84).
        observations = observations.unsqueeze(1)
        return self.cnn(observations)

View File

@ -0,0 +1,2 @@
#{"t_start": 1680175884.8182795, "env_id": null}
r,l,t

View File

@ -0,0 +1,72 @@
import gym
import cv2
import numpy as np

# Custom environment wrapper
class StreetFighterCustomWrapper(gym.Wrapper):
    def __init__(self, env, testing=False):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.env = env
        self.testing = testing

        # Store the previous frame
        self.prev_frame = None

        self.full_hp = 176
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        # Update observation space to include one grayscale frame difference image
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _preprocess_observation(self, observation):
        obs_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        obs_gray_resized = cv2.resize(obs_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0
        return obs_gray_resized

    def reset(self):
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        observation = self.env.reset()
        # Reset the previous frame
        self.prev_frame = self._preprocess_observation(observation)
        return np.zeros_like(self.prev_frame)

    def step(self, action):
        observation, _reward, _done, info = self.env.step(action)
        obs_gray_resized = self._preprocess_observation(observation)

        if self.prev_frame is not None:
            frame_delta = obs_gray_resized - self.prev_frame
        else:
            frame_delta = np.zeros_like(obs_gray_resized)

        self.prev_frame = obs_gray_resized

        # During fighting, either player or opponent has positive health points.
        if info['health'] > 0 or info['enemy_health'] > 0:

            # Player Loses
            if info['health'] < 0 and info['enemy_health'] > 0:
                reward = (-self.full_hp) * info['enemy_health']
                done = True

            # Player Wins
            elif info['enemy_health'] < 0 and info['health'] > 0:
                reward = self.full_hp * info['health']
                done = True

            # During Fighting
            else:
                reward = (self.prev_oppont_health - info['enemy_health']) - (self.prev_player_health - info['health'])

                self.prev_player_health = info['health']
                self.prev_oppont_health = info['enemy_health']

        if self.testing:
            done = False

        return frame_delta, reward, done, info

View File

@ -0,0 +1,70 @@
import time

import cv2
import retro
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env, testing=True)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

# Wrap the environment
env = DummyVecEnv([lambda: env])

policy_kwargs = {
    'features_extractor_class': CustomCNN
}

model = PPO(
    "CnnPolicy",
    env,
    device="cuda",
    policy_kwargs=policy_kwargs,
    verbose=1
)
model.load(r"trained_models_continued/ppo_chunli_6048000_steps")

obs = env.reset()
done = False

while True:
    timestamp = time.time()
    action, _ = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    env.render()
    render_time = time.time() - timestamp
    if render_time < 0.0111:
        time.sleep(0.0111 - render_time)  # Add a delay for 90 FPS

# env.close()

View File

@ -0,0 +1,124 @@
import os
import random

import retro
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback

from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper

class RandomOpponentChangeCallback(BaseCallback):
    def __init__(self, stages, opponent_interval, verbose=0):
        super(RandomOpponentChangeCallback, self).__init__(verbose)
        self.stages = stages
        self.opponent_interval = opponent_interval

    def _on_step(self) -> bool:
        if self.n_calls % self.opponent_interval == 0:
            new_state = random.choice(self.stages)
            print("\nCurrent state:", new_state)
            self.training_env.env_method("load_state", new_state, indices=None)
        return True

def make_env(game, state, seed=0):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        env.seed(seed)
        return env
    return _init

def main():
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
        "ChampionX.Level1.ChunLiVsKen",
        "ChampionX.Level2.ChunLiVsChunLi",
        "ChampionX.Level3.ChunLiVsZangief",
        "ChampionX.Level4.ChunLiVsDhalsim",
        "ChampionX.Level5.ChunLiVsRyu",
        "ChampionX.Level6.ChunLiVsEHonda",
        "ChampionX.Level7.ChunLiVsBlanka",
        "ChampionX.Level8.ChunLiVsGuile",
        "ChampionX.Level9.ChunLiVsBalrog",
        "ChampionX.Level10.ChunLiVsVega",
        "ChampionX.Level11.ChunLiVsSagat",
        "ChampionX.Level12.ChunLiVsBison"
        # Add other stages as necessary
    ]
    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.

    num_envs = 8

    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])

    # Using CustomCNN as the feature extractor
    policy_kwargs = {
        'features_extractor_class': CustomCNN
    }

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=5400,
        batch_size=64,
        learning_rate=0.0001,
        ent_coef=0.01,
        clip_range=0.2,
        gamma=0.99,
        gae_lambda=0.95,
        tensorboard_log="logs/"
    )

    # Set the save directory
    save_dir = "trained_models"
    os.makedirs(save_dir, exist_ok=True)

    # Load the model from file
    # model_path = "trained_models/ppo_chunli_1296000_steps.zip"

    # Load model and modify the learning rate and entropy coefficient
    # custom_objects = {
    #     "learning_rate": 0.0002
    # }
    # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)

    # Set up callbacks
    opponent_interval = 5400  # stage_interval * num_envs = total_steps_per_stage
    checkpoint_interval = 54000  # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
    stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)

    # model_params = {
    #     'n_steps': 5,
    #     'gamma': 0.99,
    #     'gae_lambda': 1,
    #     'learning_rate': 7e-4,
    #     'vf_coef': 0.5,
    #     'ent_coef': 0.0,
    #     'max_grad_norm': 0.5,
    #     'rms_prop_eps': 1e-05
    # }
    # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))

    model.learn(
        total_timesteps=int(6048000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
        callback=[checkpoint_callback, stage_increase_callback]
    )
    env.close()

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))

if __name__ == "__main__":
    main()

View File

@ -0,0 +1,73 @@
import os

import retro
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

LOG_DIR = 'logs/'
OPT_DIR = 'optuna/'
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(OPT_DIR, exist_ok=True)

def optimize_ppo(trial):
    return {
        'n_steps': trial.suggest_int('n_steps', 1024, 8192, log=True),
        'gamma': trial.suggest_float('gamma', 0.9, 0.9999),
        'learning_rate': trial.suggest_float('learning_rate', 5e-5, 1e-4, log=True),
        'clip_range': trial.suggest_float('clip_range', 0.1, 0.4),
        'gae_lambda': trial.suggest_float('gae_lambda', 0.8, 0.99)
    }

def make_env(game, state, seed=0):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        env.seed(seed)
        return env
    return _init

def optimize_agent(trial):
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state = "ChampionX.Level1.ChunLiVsKen"
    # try:
    model_params = optimize_ppo(trial)

    # Create environment
    env = make_env(game, state)()
    env = Monitor(env, LOG_DIR)
    env = DummyVecEnv([lambda: env])
    env = VecFrameStack(env, 4, channels_order='last')

    # Create algo
    model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=0, **model_params)
    model.learn(total_timesteps=100000)

    # Evaluate model
    mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)
    env.close()

    SAVE_PATH = os.path.join(OPT_DIR, 'trial_{}_best_model'.format(trial.number))
    model.save(SAVE_PATH)

    return mean_reward

    # except Exception as e:
    #     return -1

# Creating the experiment
study = optuna.create_study(direction='maximize')
study.optimize(optimize_agent, n_trials=10, n_jobs=1)

print(study.best_params)
print(study.best_trial)