mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-03 22:50:43 +00:00
ram_based_image_stack
This commit is contained in:
parent
d4fb6dbc59
commit
02e39f0a52
@ -0,0 +1,234 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "10d267bb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import retro"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "1ef8ff20",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"game = \"StreetFighterIISpecialChampionEdition-Genesis\"\n",
|
||||
"state = \"Champion.Level1.ChunLiVsGuile\"\n",
|
||||
"env = retro.make(game=game, state=state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "5ce656b8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1], dtype=int8)"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env.action_space.sample()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "8c3f0a4d",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(200, 256, 3)"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env.observation_space.sample().shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "46db7b05",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(200, 256, 3)\n",
|
||||
"{'enemy_matches_won': 0, 'score': 0, 'matches_won': 0, 'continuetimer': 0, 'enemy_health': 176, 'health': 176}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"observation = env.reset()\n",
|
||||
"print(observation.shape)\n",
|
||||
"\n",
|
||||
"action = env.action_space.sample()\n",
|
||||
"obs, rewards, done, info = env.step(action)\n",
|
||||
"print(info)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "09f0c6b0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"MultiBinary(12)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from gym.spaces import Box, MultiBinary\n",
|
||||
"\n",
|
||||
"print(MultiBinary(12))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "97df18cf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gym\n",
|
||||
"import numpy as np\n",
|
||||
"from gym.spaces import Box, MultiBinary\n",
|
||||
"\n",
|
||||
"class StreetFighter(gym.Env):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" self.observation_space = Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)\n",
|
||||
" self.action_space = MultiBinary(12)\n",
|
||||
" self.game = retro.make(game=\"StreetFighterIISpecialChampionEdition-Genesis\", use_restricted_actions=retro.Actions.FILTERED)\n",
|
||||
" \n",
|
||||
" self.full_hp = 176\n",
|
||||
" self.player_health = self.full_hp\n",
|
||||
" self.oppont_health = self.full_hp\n",
|
||||
" \n",
|
||||
" self.score = 0\n",
|
||||
" \n",
|
||||
" def __preprocess(self, observation):\n",
|
||||
" gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)\n",
|
||||
" resize = cv2.resize(gray, (84,84), interpolation=cv2.INTER_CUBIC)\n",
|
||||
" return resize\n",
|
||||
"\n",
|
||||
" def step(self, action):\n",
|
||||
"\n",
|
||||
" obs, reward, done, info = self.game.step(action)\n",
|
||||
" custom_obs = self.__preprocess(obs) # It's just frame, not frame_delta\n",
|
||||
"\n",
|
||||
" # During fighting, either player or opponent has positive health points.\n",
|
||||
" if info['health'] > 0 or info['enemy_health'] > 0:\n",
|
||||
"\n",
|
||||
" # Player Loses\n",
|
||||
" if info['health'] < 0 and info['health'] != self.player_health and info['enemy_health'] != 0:\n",
|
||||
" reward = (-self.full_hp) * info['enemy_health']\n",
|
||||
"\n",
|
||||
" # Player Wins\n",
|
||||
" elif info['enemy_health'] < 0 and info['enemy_health'] != self.oppont_health and info['health'] != 0:\n",
|
||||
" reward = self.full_hp * info['health']\n",
|
||||
"\n",
|
||||
" # During Fighting\n",
|
||||
" else:\n",
|
||||
" reward = (self.oppont_health - info['enemy_health']) - (self.player_health - info['health'])\n",
|
||||
" \n",
|
||||
" self.player_health = info['health']\n",
|
||||
" self.oppont_health = info['enemy_health']\n",
|
||||
" \n",
|
||||
" return custom_obs, reward, done, info\n",
|
||||
" \n",
|
||||
" def render(self, *args, **kwargs):\n",
|
||||
" self.game.render()\n",
|
||||
" \n",
|
||||
" def reset(self):\n",
|
||||
" obs = self.game.reset()\n",
|
||||
" custom_obs = self.__preprocess(obs)\n",
|
||||
" self.previous_frame = obs\n",
|
||||
" \n",
|
||||
" self.player_health = self.full_hp\n",
|
||||
" self.oppont_health = self.full_hp\n",
|
||||
" return custom_obs\n",
|
||||
"\n",
|
||||
" def close(self):\n",
|
||||
" self.game.close()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"id": "0b137b88",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(84, 84, 1)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env.close()\n",
|
||||
"env = StreetFighter()\n",
|
||||
"print(env.observation_space.shape)\n",
|
||||
"env.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2da50dbc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
46
000_image_stack_ram_based_reward/check_reward.py
Normal file
46
000_image_stack_ram_based_reward/check_reward.py
Normal file
@ -0,0 +1,46 @@
|
||||
import time
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
|
||||
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
    """Return a zero-argument factory that builds the wrapped retro environment.

    Args:
        game: Game identifier forwarded to ``retro.make``.
        state: Save-state name forwarded to ``retro.make``.

    Returns:
        A callable that creates the retro environment (filtered action set,
        image observations) and wraps it in ``StreetFighterCustomWrapper``.
    """
    def _init():
        raw_env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(raw_env)

    return _init
|
||||
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
state = "Champion.Level1.ChunLiVsGuile"  # alternative: "ChampionX.Level1.ChunLiVsKen"

# Build one (non-vectorised) environment; Monitor is given 'logs/' as its
# output location.
env = make_env(game, state)()
env = Monitor(env, 'logs/')

# Play `num_episodes` episodes with uniformly sampled actions and report the
# average total (wrapper-shaped) reward as a random-policy baseline.
num_episodes = 30
episode_reward_sum = 0
try:
    for _ in range(num_episodes):
        done = False
        obs = env.reset()
        total_reward = 0
        while not done:
            obs, reward, done, info = env.step(env.action_space.sample())

            # Only log steps where the reward actually changed.
            if reward != 0:
                total_reward += reward
                print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
            env.render()
        print("Total reward: {}".format(total_reward))
        episode_reward_sum += total_reward
finally:
    # Release the emulator even if an episode raises.
    env.close()

print("Average reward for random strategy: {}".format(episode_reward_sum/num_episodes))
|
52
000_image_stack_ram_based_reward/evaluate.py
Normal file
52
000_image_stack_ram_based_reward/evaluate.py
Normal file
@ -0,0 +1,52 @@
|
||||
import retro
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
    """Create a deferred constructor for the Street Fighter environment.

    Args:
        game: Game identifier passed through to ``retro.make``.
        state: Save-state name passed through to ``retro.make``.

    Returns:
        A no-argument function; calling it builds the retro environment with
        the filtered action set and image observations, then returns it
        wrapped in ``StreetFighterCustomWrapper``.
    """
    retro_kwargs = dict(game=game, state=state)

    def _init():
        emulator = retro.make(
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
            **retro_kwargs,
        )
        wrapped = StreetFighterCustomWrapper(emulator)
        return wrapped

    return _init
|
||||
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

# Evaluate on the first stage only.
env = make_env(game, state_stages[0])()

# Wrap the environment
# env = Monitor(env, 'logs/')

# PPO.load builds the model from the checkpoint, so no throw-away PPO(...)
# instance is constructed first. The file-level CustomCNN import is kept:
# presumably the checkpoint references it when unpickled (TODO confirm).
model = PPO.load(r"dummy_model_ppo_chunli")
# model.load(r"trained_models/ppo_chunli_864000_steps")

try:
    # With the default return_episode_rewards=False, evaluate_policy returns
    # (mean_reward, std_reward) as floats, which is what the format below
    # expects. Passing return_episode_rewards=True would return two lists and
    # make the ':.2f' formatting raise TypeError.
    mean_reward, std_reward = evaluate_policy(model, env, render=True, n_eval_episodes=10, deterministic=False)
finally:
    # Release the emulator even if evaluation fails.
    env.close()

print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
53
000_image_stack_ram_based_reward/logs/monitor.csv
Normal file
53
000_image_stack_ram_based_reward/logs/monitor.csv
Normal file
@ -0,0 +1,53 @@
|
||||
#{"t_start": 1680186251.3110938, "env_id": null}
|
||||
r,l,t
|
||||
-50,2150,3.695703
|
||||
-40,2886,12.564373
|
||||
-128,2196,20.599987
|
||||
-217,3000,25.620172
|
||||
-210,2753,34.631877
|
||||
27,2177,42.807461
|
||||
-161,2502,46.870715
|
||||
-227,2122,54.492589
|
||||
-289,1567,61.321581
|
||||
1,2075,64.463465
|
||||
130,2465,72.662509
|
||||
-192,3007,82.093462
|
||||
3927.0,6468,97.611361
|
||||
-109,1823,104.996175
|
||||
200,1820,112.333123
|
||||
-300,2478,116.020238
|
||||
-42,2351,124.010789
|
||||
-263,1990,127.212089
|
||||
-351,1486,134.405471
|
||||
-225,2611,143.112158
|
||||
-56,3290,153.69294
|
||||
-65,2138,157.640509
|
||||
62,3161,167.244644
|
||||
-189,2652,175.720904
|
||||
224,2138,179.193385
|
||||
-48,3706,189.4923
|
||||
-209,3172,199.319699
|
||||
-98,2059,207.148574
|
||||
51,2787,216.523835
|
||||
-88,3218,225.952495
|
||||
-263,1828,228.707771
|
||||
-38,2328,236.642072
|
||||
7,3179,245.83899
|
||||
-133,2421,249.558141
|
||||
-296,1684,256.702009
|
||||
-211,2881,266.1996
|
||||
-261,1710,269.33675
|
||||
-176,1974,277.229695
|
||||
184,1310,279.58493
|
||||
218,2222,288.236686
|
||||
-229,2460,291.904952
|
||||
-345,2510,299.876746
|
||||
-345,2510,302.781091
|
||||
-345,2510,305.701696
|
||||
-345,2510,308.687105
|
||||
-345,2510,311.624716
|
||||
-345,2510,314.566203
|
||||
-345,2510,317.608539
|
||||
-345,2510,320.618201
|
||||
-345,2510,323.649133
|
||||
-345,2510,326.561072
|
Can't render this file because it contains an unexpected character in line 1 and column 3.
|
8947
000_image_stack_ram_based_reward/optuna/tuning_log.txt
Normal file
8947
000_image_stack_ram_based_reward/optuna/tuning_log.txt
Normal file
File diff suppressed because it is too large
Load Diff
93
000_image_stack_ram_based_reward/rmsprop_optim.py
Normal file
93
000_image_stack_ram_based_reward/rmsprop_optim.py
Normal file
@ -0,0 +1,93 @@
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
class RMSpropTF(Optimizer):
    """RMSprop optimizer following the TensorFlow-style update noted inline.

    Differences from the stock PyTorch update, as marked by the comments in
    the update step: the squared-gradient average starts at one instead of
    zero, epsilon is added inside the square root, and (optionally) the
    learning rate is folded into the momentum buffer.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        alpha: Smoothing constant for the running averages.
        eps: Term added under the square root.
        weight_decay: Weight-decay coefficient (0 disables it).
        momentum: Momentum factor (0 disables the momentum buffer).
        centered: If true, subtract the squared running mean of the gradient
            from the squared average before taking the root.
        decoupled_decay: If true, weight decay shrinks the parameter directly
            instead of being added to the gradient.
        lr_in_momentum: If true, the learning rate scales the term accumulated
            in the momentum buffer; otherwise it scales the final update.

    Raises:
        ValueError: If ``lr``, ``eps``, ``momentum``, ``weight_decay`` or
            ``alpha`` is negative.
    """

    def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10,
                 weight_decay=0, momentum=0., centered=False,
                 decoupled_decay=False, lr_in_momentum=True):
        # Checked in a fixed order so the first offending argument is the one
        # reported.
        must_be_non_negative = (
            (lr, "Invalid learning rate: {}"),
            (eps, "Invalid epsilon value: {}"),
            (momentum, "Invalid momentum value: {}"),
            (weight_decay, "Invalid weight_decay value: {}"),
            (alpha, "Invalid alpha value: {}"),
        )
        for value, template in must_be_non_negative:
            if not 0.0 <= value:
                raise ValueError(template.format(value))

        defaults = {
            'lr': lr,
            'momentum': momentum,
            'alpha': alpha,
            'eps': eps,
            'centered': centered,
            'weight_decay': weight_decay,
            'decoupled_decay': decoupled_decay,
            'lr_in_momentum': lr_in_momentum,
        }
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in options a saved group may lack."""
        super().__setstate__(state)
        fallbacks = (('momentum', 0), ('centered', False))
        for group in self.param_groups:
            for key, value in fallbacks:
                group.setdefault(key, value)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The closure needs gradients even though step() runs under no_grad.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')

                momentum = group['momentum']
                centered = group['centered']
                state = self.state[p]

                # Lazy state initialization on the first update of this parameter.
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.ones_like(p)  # PyTorch inits to zero
                    if momentum > 0:
                        state['momentum_buffer'] = torch.zeros_like(p)
                    if centered:
                        state['grad_avg'] = torch.zeros_like(p)

                square_avg = state['square_avg']
                blend = 1. - group['alpha']
                state['step'] += 1

                weight_decay = group['weight_decay']
                if weight_decay != 0:
                    if group['decoupled_decay']:
                        # Shrink the weights directly.
                        p.mul_(1. - group['lr'] * weight_decay)
                    else:
                        # Fold the decay term into the gradient (out of place,
                        # so p.grad itself is left untouched).
                        grad = grad.add(p, alpha=weight_decay)

                # Tensorflow order of ops for updating squared avg
                square_avg.add_(grad.pow(2) - square_avg, alpha=blend)

                if centered:
                    grad_avg = state['grad_avg']
                    grad_avg.add_(grad - grad_avg, alpha=blend)
                    variance = square_avg.addcmul(grad_avg, grad_avg, value=-1)
                    denom = variance.add(group['eps']).sqrt_()  # eps in sqrt
                else:
                    denom = square_avg.add(group['eps']).sqrt_()  # eps moved in sqrt

                if momentum > 0:
                    buf = state['momentum_buffer']
                    if group['lr_in_momentum']:
                        # Tensorflow accumulates the LR scaling in the momentum buffer
                        buf.mul_(momentum).addcdiv_(grad, denom, value=group['lr'])
                        p.add_(-buf)
                    else:
                        # PyTorch scales the param update by LR
                        buf.mul_(momentum).addcdiv_(grad, denom)
                        p.add_(buf, alpha=-group['lr'])
                else:
                    p.addcdiv_(grad, denom, value=-group['lr'])

        return loss
|
||||
|
@ -0,0 +1,97 @@
|
||||
import collections
|
||||
|
||||
import gym
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# Custom environment wrapper
|
||||
class StreetFighterCustomWrapper(gym.Wrapper):
    """Gym wrapper that stacks grayscale frames and reshapes the reward.

    Observations are the last ``num_frames`` frames, each converted to
    grayscale and resized to 84x84, stacked along the last axis. The reward is
    computed from the ``health`` and ``enemy_health`` entries of the ``info``
    dict returned by the wrapped environment.

    Args:
        env: The environment to wrap.
        testing: If true, ``step`` always reports ``done = False``.
    """

    def __init__(self, env, testing=False):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.env = env

        # Use a deque to store the last `num_frames` frames
        self.num_frames = 3
        self.frame_stack = collections.deque(maxlen=self.num_frames)

        self.full_hp = 176
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        # Update observation space to include stacked grayscale images; the
        # channel count follows num_frames so the two cannot drift apart.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, self.num_frames), dtype=np.uint8)

        self.testing = testing

    @staticmethod
    def _to_gray_84(observation):
        """Convert one colour frame to an 84x84 grayscale image."""
        # NOTE(review): the conversion code assumes BGR channel order; confirm
        # the channel order of the frames produced by the wrapped environment.
        obs_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        return cv2.resize(obs_gray, (84, 84), interpolation=cv2.INTER_AREA)

    def _preprocess_observation(self, observation):
        """Push the processed frame onto the stack and return the stacked image."""
        # Add the resized image to the frame stack
        self.frame_stack.append(self._to_gray_84(observation))

        # Stack the frames and return the "image"
        return np.stack(self.frame_stack, axis=-1)

    def reset(self):
        """Reset the wrapped environment, health trackers and frame stack.

        Returns:
            The first observation, repeated ``num_frames`` times along the
            last axis.
        """
        observation = self.env.reset()
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        first_frame = self._to_gray_84(observation)

        # Clear the frame stack and add the first observation [num_frames] times
        self.frame_stack.clear()
        self.frame_stack.extend([first_frame] * self.num_frames)

        return np.stack(self.frame_stack, axis=-1)

    def step(self, action):
        """Step the wrapped environment and replace its reward.

        While at least one fighter has positive health:
          * player health below zero with enemy health above zero: the reward
            is minus the enemy's remaining health and the episode ends;
          * enemy health below zero with player health above zero: the reward
            is the player's remaining health and the episode ends;
          * otherwise the reward is damage dealt minus damage taken since the
            previous step.
        When neither health value is positive, the wrapped environment's own
        reward is passed through unchanged.

        Returns:
            ``(stacked_observation, reward, done, info)``.
        """
        obs, reward, done, info = self.env.step(action)

        # During fighting, either player or opponent has positive health points.
        if info['health'] > 0 or info['enemy_health'] > 0:

            # Player Loses
            if info['health'] < 0 and info['enemy_health'] > 0:
                # reward = (-self.full_hp) * info['enemy_health'] * 0.05 # max = 0.05 * 176 * 176 = 1548.8
                reward = -info['enemy_health']  # Use the left over health points as penalty

                # Prevent data overflow
                if reward < -self.full_hp:
                    reward = 0

                done = True

            # Player Wins
            elif info['enemy_health'] < 0 and info['health'] > 0:
                # reward = self.full_hp * info['health'] * 0.05
                reward = info['health']

                # Prevent data overflow
                if reward > self.full_hp:
                    reward = 0

                done = True

            # During Fighting
            else:
                reward = (self.prev_oppont_health - info['enemy_health']) - (self.prev_player_health - info['health'])

                # Prevent data overflow
                if reward > 99:
                    reward = 0

            self.prev_player_health = info['health']
            self.prev_oppont_health = info['enemy_health']

        if self.testing:
            done = False

        return self._preprocess_observation(obs), reward, done, info
|
||||
|
314
000_image_stack_ram_based_reward/street_fighter_notebook.ipynb
Normal file
314
000_image_stack_ram_based_reward/street_fighter_notebook.ipynb
Normal file
@ -0,0 +1,314 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "bfc79b8c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import retro"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "c24fbcab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"game = \"StreetFighterIISpecialChampionEdition-Genesis\"\n",
|
||||
"state = \"Champion.Level1.ChunLiVsGuile\"\n",
|
||||
"env = retro.make(game=game, state=state)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "59839d9c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array([1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1], dtype=int8)"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env.action_space.sample()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "e068cb0a",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(200, 256, 3)"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env.observation_space.sample().shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "1cb0297f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(200, 256, 3)\n",
|
||||
"{'enemy_matches_won': 0, 'score': 0, 'matches_won': 0, 'continuetimer': 0, 'enemy_health': 176, 'health': 176}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"observation = env.reset()\n",
|
||||
"print(observation.shape)\n",
|
||||
"\n",
|
||||
"action = env.action_space.sample()\n",
|
||||
"obs, rewards, done, info = env.step(action)\n",
|
||||
"print(info)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "0eaa5cc8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"MultiBinary(12)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from gym.spaces import Box, MultiBinary\n",
|
||||
"\n",
|
||||
"print(MultiBinary(12))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"id": "49f6cf5c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"\n",
|
||||
"import gym\n",
|
||||
"import numpy as np\n",
|
||||
"from gym.spaces import Box, MultiBinary\n",
|
||||
"\n",
|
||||
"class StreetFighter(gym.Env):\n",
|
||||
" def __init__(self):\n",
|
||||
" super().__init__()\n",
|
||||
" self.observation_space = Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)\n",
|
||||
" self.action_space = MultiBinary(12)\n",
|
||||
" self.game = retro.make(game=\"StreetFighterIISpecialChampionEdition-Genesis\", use_restricted_actions=retro.Actions.FILTERED)\n",
|
||||
" \n",
|
||||
" self.full_hp = 176\n",
|
||||
" self.player_health = self.full_hp\n",
|
||||
" self.oppont_health = self.full_hp\n",
|
||||
" \n",
|
||||
" self.score = 0\n",
|
||||
" \n",
|
||||
" def __preprocess(self, observation):\n",
|
||||
" gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)\n",
|
||||
" resize = cv2.resize(gray, (84,84), interpolation=cv2.INTER_CUBIC)\n",
|
||||
" return resize\n",
|
||||
"\n",
|
||||
" def step(self, action):\n",
|
||||
"\n",
|
||||
" obs, reward, done, info = self.game.step(action)\n",
|
||||
" custom_obs = self.__preprocess(obs) # It's just frame, not frame_delta\n",
|
||||
"\n",
|
||||
" # During fighting, either player or opponent has positive health points.\n",
|
||||
" if info['health'] > 0 or info['enemy_health'] > 0:\n",
|
||||
"\n",
|
||||
" # Player Loses\n",
|
||||
" if info['health'] < 0 and info['health'] != self.player_health and info['enemy_health'] != 0:\n",
|
||||
" reward = (-self.full_hp) * info['enemy_health']\n",
|
||||
"\n",
|
||||
" # Player Wins\n",
|
||||
" elif info['enemy_health'] < 0 and info['enemy_health'] != self.oppont_health and info['health'] != 0:\n",
|
||||
" reward = self.full_hp * info['health']\n",
|
||||
"\n",
|
||||
" # During Fighting\n",
|
||||
" else:\n",
|
||||
" reward = (self.oppont_health - info['enemy_health']) - (self.player_health - info['health'])\n",
|
||||
" \n",
|
||||
" self.player_health = info['health']\n",
|
||||
" self.oppont_health = info['enemy_health']\n",
|
||||
" \n",
|
||||
" return custom_obs, reward, done, info\n",
|
||||
" \n",
|
||||
" def render(self, *args, **kwargs):\n",
|
||||
" self.game.render()\n",
|
||||
" \n",
|
||||
" def reset(self):\n",
|
||||
" obs = self.game.reset()\n",
|
||||
" custom_obs = self.__preprocess(obs)\n",
|
||||
" self.previous_frame = obs\n",
|
||||
" \n",
|
||||
" self.player_health = self.full_hp\n",
|
||||
" self.oppont_health = self.full_hp\n",
|
||||
" return custom_obs\n",
|
||||
"\n",
|
||||
" def close(self):\n",
|
||||
" self.game.close()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"id": "6ec30177",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(84, 84)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"env.close()\n",
|
||||
"env = StreetFighter()\n",
|
||||
"print(env.observation_space.shape)\n",
|
||||
"env.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "7d9eab3a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"C:\\ProgramData\\Anaconda3\\envs\\StreetFighterAI\\lib\\site-packages\\pyglet\\image\\codecs\\wic.py:289: UserWarning: [WinError -2147417850] Cannot change thread mode after it is set\n",
|
||||
" warnings.warn(str(err))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"-22 154 176\n",
|
||||
"-32 122 176\n",
|
||||
"29 122 147\n",
|
||||
"7 122 140\n",
|
||||
"-31 91 140\n",
|
||||
"29 91 111\n",
|
||||
"-23 68 111\n",
|
||||
"-24 44 111\n",
|
||||
"-24 20 111\n",
|
||||
"31 20 80\n",
|
||||
"10 20 70\n",
|
||||
"45 20 25\n",
|
||||
"5 20 20\n",
|
||||
"-15 5 20\n",
|
||||
"19 5 1\n",
|
||||
"-176 -1 1\n",
|
||||
"46 176 130\n",
|
||||
"7 176 123\n",
|
||||
"-24 152 123\n",
|
||||
"29 152 94\n",
|
||||
"-24 128 94\n",
|
||||
"7 128 87\n",
|
||||
"39 128 48\n",
|
||||
"-31 97 48\n",
|
||||
"36 97 12\n",
|
||||
"-24 73 12\n",
|
||||
"-24 49 12\n",
|
||||
"8624 49 -1\n",
|
||||
"39 176 137\n",
|
||||
"-24 152 137\n",
|
||||
"-23 129 137\n",
|
||||
"-23 106 137\n",
|
||||
"-26 80 137\n",
|
||||
"-24 56 137\n",
|
||||
"-23 33 137\n",
|
||||
"-21 12 137\n",
|
||||
"-12 0 137\n",
|
||||
"-24112 -1 137\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"## Checking Rewards functionality\n",
|
||||
"import time\n",
|
||||
"\n",
|
||||
"env = StreetFighter()\n",
|
||||
"obs = env.reset()\n",
|
||||
"done = False\n",
|
||||
"\n",
|
||||
"for game in range(5):\n",
|
||||
" while not done:\n",
|
||||
" if done:\n",
|
||||
" obs = env.reset()\n",
|
||||
" env.render()\n",
|
||||
" obs, reward, done, info = env.step(env.action_space.sample())\n",
|
||||
" if reward != 0:\n",
|
||||
" print(reward, info['health'], info['enemy_health'])\n",
|
||||
" time.sleep(0.01)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b1ae8310",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
69
000_image_stack_ram_based_reward/test.py
Normal file
69
000_image_stack_ram_based_reward/test.py
Normal file
@ -0,0 +1,69 @@
|
||||
import time
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
    """Produce a factory for the wrapped Street Fighter environment.

    Args:
        game: Game identifier given to ``retro.make``.
        state: Save-state name given to ``retro.make``.

    Returns:
        A function taking no arguments that builds the retro environment
        (filtered actions, image observations) and returns it wrapped in
        ``StreetFighterCustomWrapper``.
    """
    def _init():
        return StreetFighterCustomWrapper(
            retro.make(
                game=game,
                state=state,
                use_restricted_actions=retro.Actions.FILTERED,
                obs_type=retro.Observations.IMAGE,
            )
        )

    return _init
|
||||
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile", # Average reward for random strategy: -102.3
    "ChampionX.Level1.ChunLiVsKen", # Average reward for random strategy: -247.6
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

model_path = r"optuna/trial_1_best_model" # Average reward for optuna/trial_1_best_model: -82.3
# PPO.load is a classmethod that returns a new model, so its result must be
# kept. Calling it on an existing instance and discarding the return value
# leaves that instance untrained.
model = PPO.load(model_path, env=env)

num_episodes = 30
episode_reward_sum = 0
try:
    for _ in range(num_episodes):
        done = False
        obs = env.reset()
        total_reward = 0
        while not done:
            # Act with the loaded policy, so the reported average measures the
            # model rather than uniformly random actions.
            action, _ = model.predict(obs)
            obs, reward, done, info = env.step(action)

            if reward != 0:
                total_reward += reward
                print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
            env.render()
        print("Total reward: {}".format(total_reward))
        episode_reward_sum += total_reward
finally:
    # Release the emulator even if an episode raises.
    env.close()

print("Average reward for {}: {}".format(model_path, episode_reward_sum/num_episodes))
|
125
000_image_stack_ram_based_reward/train.py
Normal file
125
000_image_stack_ram_based_reward/train.py
Normal file
@ -0,0 +1,125 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
|
||||
|
||||
from rmsprop_optim import RMSpropTF
|
||||
from custom_cnn import CustomCNN
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
class RandomOpponentChangeCallback(BaseCallback):
    """Callback that periodically points the training envs at a random stage.

    Args:
        stages: Sequence of save-state names to choose from.
        opponent_interval: Number of callback calls between stage changes.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, stages, opponent_interval, verbose=0):
        super().__init__(verbose)
        self.opponent_interval = opponent_interval
        self.stages = stages

    def _on_step(self) -> bool:
        """Pick a new stage every ``opponent_interval`` calls; always return True."""
        if self.n_calls % self.opponent_interval != 0:
            return True

        new_state = random.choice(self.stages)
        print("\nCurrent state:", new_state)
        # indices=None addresses every environment in the vectorised env.
        self.training_env.env_method("load_state", new_state, indices=None)
        return True
|
||||
|
||||
def make_env(game, state, seed=0):
    """Return a zero-argument factory that builds one wrapped, seeded retro environment.

    The environment is only created when the returned callable is invoked.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
        seed: Seed passed to ``env.seed``.
    """
    def _init():
        wrapped = StreetFighterCustomWrapper(
            retro.make(
                game=game,
                state=state,
                use_restricted_actions=retro.Actions.FILTERED,
                obs_type=retro.Observations.IMAGE,
            )
        )
        wrapped.seed(seed)
        return wrapped

    return _init
|
||||
|
||||
def main():
    """Train a PPO agent against randomly rotated opponents and save the result.

    Builds 8 subprocess environments, trains for 6,048,000 timesteps while a
    callback loads a random stage every ``opponent_interval`` callback calls,
    writes periodic checkpoints into ``trained_models/`` and finally saves the
    model as ``ppo_sf2_chunli_final.zip``.
    """
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
        "ChampionX.Level1.ChunLiVsKen",
        "ChampionX.Level2.ChunLiVsChunLi",
        "ChampionX.Level3.ChunLiVsZangief",
        "ChampionX.Level4.ChunLiVsDhalsim",
        "ChampionX.Level5.ChunLiVsRyu",
        "ChampionX.Level6.ChunLiVsEHonda",
        "ChampionX.Level7.ChunLiVsBlanka",
        "ChampionX.Level8.ChunLiVsGuile",
        "ChampionX.Level9.ChunLiVsBalrog",
        "ChampionX.Level10.ChunLiVsVega",
        "ChampionX.Level11.ChunLiVsSagat",
        "ChampionX.Level12.ChunLiVsBison"
        # Add other stages as necessary
    ]
    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.

    num_envs = 8

    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])

    # Using CustomCNN as the feature extractor
    policy_kwargs = {
        'features_extractor_class': CustomCNN
    }

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=5400,
        batch_size=64,
        learning_rate=0.0001,
        ent_coef=0.01,
        clip_range=0.2,
        gamma=0.99,
        gae_lambda=0.95,
        tensorboard_log="logs/"
    )

    # Set the save directory
    save_dir = "trained_models"
    os.makedirs(save_dir, exist_ok=True)

    # Load the model from file
    # model_path = "trained_models/ppo_chunli_1296000_steps.zip"

    # Load model and modify the learning rate and entropy coefficient
    # custom_objects = {
    #     "learning_rate": 0.0002
    # }
    # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)

    # Set up callbacks
    opponent_interval = 5400 # stage_interval * num_envs = total_steps_per_stage
    checkpoint_interval = 54000 # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
    # The callback's third parameter is `verbose`; the save directory must not
    # be passed there, so only the stages and the interval are given.
    stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval)

    # model_params = {
    #     'n_steps': 5,
    #     'gamma': 0.99,
    #     'gae_lambda':1,
    #     'learning_rate': 7e-4,
    #     'vf_coef': 0.5,
    #     'ent_coef': 0.0,
    #     'max_grad_norm':0.5,
    #     'rms_prop_eps':1e-05
    # }
    # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))

    try:
        model.learn(
            total_timesteps=int(6048000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
            callback=[checkpoint_callback, stage_increase_callback]
        )
    finally:
        # Always shut the worker processes down, even if training is interrupted.
        env.close()

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))
|
||||
|
||||
# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
81
000_image_stack_ram_based_reward/tune.py
Normal file
81
000_image_stack_ram_based_reward/tune.py
Normal file
@ -0,0 +1,81 @@
|
||||
import gym
|
||||
import retro
|
||||
import optuna
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state, seed=0):
    """Return a zero-argument factory for a wrapped, monitored and seeded retro environment.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
        seed: Seed passed to ``env.seed``.
    """
    def _init():
        emulator = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        monitored = Monitor(StreetFighterCustomWrapper(emulator))
        monitored.seed(seed)
        return monitored

    return _init
|
||||
|
||||
def objective(trial):
    """Optuna objective: train PPO with sampled hyperparameters and return the mean reward.

    Training runs in 10 rounds of 100000 timesteps. After each round the policy
    is evaluated over 10 episodes and the mean reward is reported to the trial
    so that it can be pruned.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        Mean evaluation reward measured after the last training round.

    Raises:
        optuna.TrialPruned: If the pruner decides to stop the trial early.
    """
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    env = make_env(game, state="ChampionX.Level1.ChunLiVsKen")()

    try:
        # Suggest hyperparameters
        learning_rate = trial.suggest_float("learning_rate", 5e-5, 1e-3, log=True)
        n_steps = trial.suggest_int("n_steps", 256, 8192, log=True)
        batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
        gamma = trial.suggest_float("gamma", 0.9, 0.9999)
        gae_lambda = trial.suggest_float("gae_lambda", 0.9, 1.0)
        clip_range = trial.suggest_float("clip_range", 0.1, 0.4)
        ent_coef = trial.suggest_float("ent_coef", 1e-4, 1e-2, log=True)
        vf_coef = trial.suggest_float("vf_coef", 0.1, 1.0)

        # Using CustomCNN as the feature extractor
        policy_kwargs = {
            'features_extractor_class': CustomCNN
        }

        # Train the model
        model = PPO(
            "CnnPolicy",
            env,
            device="cuda",
            policy_kwargs=policy_kwargs,
            verbose=1,
            n_steps=n_steps,
            batch_size=batch_size,
            learning_rate=learning_rate,
            ent_coef=ent_coef,
            clip_range=clip_range,
            vf_coef=vf_coef,
            gamma=gamma,
            gae_lambda=gae_lambda
        )

        for iteration in range(10):
            model.learn(total_timesteps=100000)
            mean_reward, _std_reward = evaluate_policy(model, env, n_eval_episodes=10)

            trial.report(mean_reward, iteration)

            if trial.should_prune():
                raise optuna.TrialPruned()

        return mean_reward
    finally:
        # Gym Retro permits only one emulator per process, so the environment
        # has to be released (also on pruning or errors) before the next trial.
        env.close()
|
||||
|
||||
# Run the hyperparameter search (higher mean reward is better) and report the best trial.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100, timeout=7200) # Run optimization for 100 trials or 2 hours, whichever comes first

print("Best trial:")
trial = study.best_trial

print(" Value: ", trial.value)
print(" Params: ")
# One line per tuned hyperparameter.
for key, value in trial.params.items():
    print(f"{key}: {value}")
|
69
000_image_stack_ram_based_reward/tune_ppo.py
Normal file
69
000_image_stack_ram_based_reward/tune_ppo.py
Normal file
@ -0,0 +1,69 @@
|
||||
import os
|
||||
|
||||
import retro
|
||||
import optuna
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
# Output directories, created up front: LOG_DIR is handed to Monitor and to
# PPO's tensorboard_log, OPT_DIR receives the model saved by each trial.
LOG_DIR = 'logs/'
OPT_DIR = 'optuna/'
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(OPT_DIR, exist_ok=True)
|
||||
|
||||
def optimize_ppo(trial):
    """Sample one set of PPO hyperparameters from ``trial``.

    Returns:
        Dict with the keys ``n_steps``, ``gamma``, ``learning_rate``,
        ``clip_range`` and ``gae_lambda``, ready to be unpacked into ``PPO``.
    """
    # Parameters are sampled in the same order in which they appear in the result.
    n_steps = trial.suggest_int('n_steps', 1024, 8192, log=True)
    gamma = trial.suggest_float('gamma', 0.9, 0.9999)
    learning_rate = trial.suggest_float('learning_rate', 5e-5, 1e-4, log=True)
    clip_range = trial.suggest_float('clip_range', 0.1, 0.4)
    gae_lambda = trial.suggest_float('gae_lambda', 0.8, 0.99)

    sampled = {}
    sampled['n_steps'] = n_steps
    sampled['gamma'] = gamma
    sampled['learning_rate'] = learning_rate
    sampled['clip_range'] = clip_range
    sampled['gae_lambda'] = gae_lambda
    return sampled
|
||||
|
||||
def make_env(game, state):
    """Return a zero-argument factory that builds one wrapped retro environment.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
    """
    def _init():
        emulator = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(emulator)

    return _init
|
||||
|
||||
def optimize_agent(trial):
    """Optuna objective: train PPO for 100000 timesteps and return its mean reward.

    The trained model is saved as ``optuna/trial_<number>_best_model``.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        Mean reward over 30 evaluation episodes.

    Raises:
        optuna.TrialPruned: If anything fails during the trial. The failure is
            printed and the trial is excluded from the ranking. A numeric
            sentinel such as -1 must not be returned here: episode rewards can
            be negative, so the sentinel could win a ``maximize`` study.
    """
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state = "Champion.Level1.ChunLiVsGuile"#"ChampionX.Level1.ChunLiVsKen"

    env = None
    try:
        model_params = optimize_ppo(trial)

        # Create environment
        env = make_env(game, state)()
        env = Monitor(env, LOG_DIR)

        # Create algo
        model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, **model_params)
        model.learn(total_timesteps=100000)

        # Evaluate model
        mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=30)

        SAVE_PATH = os.path.join(OPT_DIR, 'trial_{}_best_model'.format(trial.number))
        model.save(SAVE_PATH)

        return mean_reward

    except Exception as e:
        # Report the failure instead of hiding it, and keep the study running.
        print("Trial {} failed: {!r}".format(trial.number, e))
        raise optuna.TrialPruned() from e

    finally:
        # Gym Retro permits only one emulator per process; release it on
        # success and on failure so the next trial can create its own.
        if env is not None:
            env.close()
|
||||
|
||||
# Creating the experiment: 10 sequential trials (n_jobs=1), higher mean reward is better.
study = optuna.create_study(direction='maximize')
study.optimize(optimize_agent, n_trials=10, n_jobs=1)

# Print the best hyperparameters and the full record of the best trial.
print(study.best_params)
print(study.best_trial)
|
Binary file not shown.
Binary file not shown.
39
001_image_stack_vision_based_reward/check_reward.py
Normal file
39
001_image_stack_vision_based_reward/check_reward.py
Normal file
@ -0,0 +1,39 @@
|
||||
import time
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
    """Return a zero-argument factory for a retro environment wrapped in testing mode.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
    """
    def _init():
        emulator = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(emulator, testing=True)

    return _init
|
||||
|
||||
# Play one episode with a trained model and print the info dict every step,
# plus the reward and both health values whenever the reward is non-zero.
game = "StreetFighterIISpecialChampionEdition-Genesis"
state = "Champion.Level1.ChunLiVsGuile"

env = make_env(game, state)()

try:
    model = PPO.load(r"trained_models_continued/ppo_chunli_6048000_steps")
    obs = env.reset()
    done = False

    while not done:
        action, _ = model.predict(obs)
        obs, reward, done, info = env.step(action)
        print(info)
        if reward != 0:
            print(reward, info['health'], info['enemy_health'])
        env.render()
finally:
    # Release the emulator even when the run is interrupted or fails.
    env.close()
|
24
001_image_stack_vision_based_reward/custom_cnn.py
Normal file
24
001_image_stack_vision_based_reward/custom_cnn.py
Normal file
@ -0,0 +1,24 @@
|
||||
import gym
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
# Custom feature extractor (CNN)
|
||||
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor that maps an observation to 512 features.

    The network takes 4 input channels and consists of three convolution + ReLU
    stages followed by a flatten and a linear layer with ReLU. The linear layer
    expects 16384 flattened values (64 x 16 x 16, which presumably corresponds
    to 84 x 84 inputs - TODO confirm against the wrapper).
    """

    def __init__(self, observation_space: gym.Space):
        super().__init__(observation_space, features_dim=512)
        conv_stages = [
            nn.Conv2d(4, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        ]
        projection = [
            nn.Flatten(),
            nn.Linear(16384, self.features_dim),
            nn.ReLU(),
        ]
        # Same layer order as a single Sequential literal, so parameter names are unchanged.
        self.cnn = nn.Sequential(*conv_stages, *projection)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Run the observations through the network and return the feature tensor."""
        features = self.cnn(observations)
        return features
|
||||
|
47
001_image_stack_vision_based_reward/evaluate.py
Normal file
47
001_image_stack_vision_based_reward/evaluate.py
Normal file
@ -0,0 +1,47 @@
|
||||
import retro
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
    """Return a zero-argument factory that builds one wrapped retro environment.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
    """
    def _init():
        emulator = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(emulator)

    return _init
|
||||
|
||||
# Evaluate a saved PPO model for 10 rendered episodes on the first stage.
game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

# Build the environment, record episode statistics under logs/ and vectorize it.
monitored_env = Monitor(make_env(game, state_stages[0])(), 'logs/')
env = DummyVecEnv([lambda: monitored_env])

model = PPO.load('trained_models/ppo_chunli_1296000_steps')
mean_reward, std_reward = evaluate_policy(model, env, render=True, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
|
12
001_image_stack_vision_based_reward/logs/monitor.csv
Normal file
12
001_image_stack_vision_based_reward/logs/monitor.csv
Normal file
@ -0,0 +1,12 @@
|
||||
#{"t_start": 1680163278.6497958, "env_id": null}
|
||||
r,l,t
|
||||
-1115.766667,2842,13.829476
|
||||
-1115.766667,2842,22.367655
|
||||
-1115.766667,2842,32.010939
|
||||
-1115.766667,2842,41.401216
|
||||
-1115.766667,2842,50.451062
|
||||
-1115.766667,2842,59.522487
|
||||
-1115.766667,2842,68.723222
|
||||
-1115.766667,2842,78.205462
|
||||
-1115.766667,2842,88.455592
|
||||
-1115.766667,2842,97.656297
|
Can't render this file because it contains an unexpected character in line 1 and column 3.
|
@ -12,8 +12,6 @@ class StreetFighterCustomWrapper(gym.Wrapper):
|
||||
def __init__(self, env, testing=False, threshold=0.65):
|
||||
super(StreetFighterCustomWrapper, self).__init__(env)
|
||||
|
||||
self.action_space = MultiBinary(12)
|
||||
|
||||
# Use a deque to store the last 4 frames
|
||||
self.frame_stack = collections.deque(maxlen=4)
|
||||
|
||||
@ -89,7 +87,7 @@ class StreetFighterCustomWrapper(gym.Wrapper):
|
||||
|
||||
def step(self, action):
|
||||
# observation, _, _, info = self.env.step(action)
|
||||
observation, _reward, _done, info = self.env.step(self.env.action_space.sample())
|
||||
observation, _reward, _done, info = self.env.step(action)
|
||||
custom_reward = self._get_reward()
|
||||
custom_reward -= 1.0 / 60.0 # penalty for each step (-1 points per second)
|
||||
|
@ -53,7 +53,7 @@ model = PPO(
|
||||
policy_kwargs=policy_kwargs,
|
||||
verbose=1
|
||||
)
|
||||
model.load(r"trained_models_continued/ppo_chunli_432000_steps")
|
||||
model.load(r"trained_models/ppo_chunli_1296000_steps")
|
||||
|
||||
obs = env.reset()
|
||||
done = False
|
@ -1,13 +1,9 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import gym
|
||||
import cv2
|
||||
import retro
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
@ -77,20 +73,16 @@ def main():
|
||||
verbose=1,
|
||||
n_steps=5400,
|
||||
batch_size=64,
|
||||
n_epochs=10,
|
||||
learning_rate=0.0003,
|
||||
ent_coef=0.01,
|
||||
clip_range=0.2,
|
||||
clip_range_vf=None,
|
||||
gamma=0.99,
|
||||
gae_lambda=0.95,
|
||||
max_grad_norm=0.5,
|
||||
use_sde=False,
|
||||
sde_sample_freq=-1
|
||||
tensorboard_log="logs/"
|
||||
)
|
||||
|
||||
# Set the save directory
|
||||
save_dir = "trained_models_continued"
|
||||
save_dir = "trained_models_continued_new"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Load the model from file
|
||||
@ -99,8 +91,7 @@ def main():
|
||||
|
||||
# Load model and modify the learning rate and entropy coefficient
|
||||
custom_objects = {
|
||||
"learning_rate": 0.00005,
|
||||
"ent_coef": 0.2
|
||||
"learning_rate": 0.0001
|
||||
}
|
||||
model = PPO.load(model_path, env=env, device="cuda", custom_objects=custom_objects)
|
||||
|
||||
@ -110,7 +101,6 @@ def main():
|
||||
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
|
||||
stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)
|
||||
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(6048000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
|
||||
callback=[checkpoint_callback, stage_increase_callback]
|
2791
001_image_stack_vision_based_reward/trainging_log_continued.txt
Normal file
2791
001_image_stack_vision_based_reward/trainging_log_continued.txt
Normal file
File diff suppressed because it is too large
Load Diff
81
001_image_stack_vision_based_reward/tune.py
Normal file
81
001_image_stack_vision_based_reward/tune.py
Normal file
@ -0,0 +1,81 @@
|
||||
import gym
|
||||
import retro
|
||||
import optuna
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state, seed=0):
    """Return a zero-argument factory for a wrapped, monitored and seeded retro environment.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
        seed: Seed passed to ``env.seed``.
    """
    def _init():
        emulator = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        monitored = Monitor(StreetFighterCustomWrapper(emulator))
        monitored.seed(seed)
        return monitored

    return _init
|
||||
|
||||
def objective(trial):
    """Optuna objective: train PPO with sampled hyperparameters and return the mean reward.

    Training runs in 10 rounds of 100000 timesteps. After each round the policy
    is evaluated over 10 episodes and the mean reward is reported to the trial
    so that it can be pruned.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        Mean evaluation reward measured after the last training round.

    Raises:
        optuna.TrialPruned: If the pruner decides to stop the trial early.
    """
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    env = make_env(game, state="ChampionX.Level1.ChunLiVsKen")()

    try:
        # Suggest hyperparameters
        learning_rate = trial.suggest_float("learning_rate", 5e-5, 1e-3, log=True)
        n_steps = trial.suggest_int("n_steps", 256, 8192, log=True)
        batch_size = trial.suggest_int("batch_size", 16, 128, log=True)
        gamma = trial.suggest_float("gamma", 0.9, 0.9999)
        gae_lambda = trial.suggest_float("gae_lambda", 0.9, 1.0)
        clip_range = trial.suggest_float("clip_range", 0.1, 0.4)
        ent_coef = trial.suggest_float("ent_coef", 1e-4, 1e-2, log=True)
        vf_coef = trial.suggest_float("vf_coef", 0.1, 1.0)

        # Using CustomCNN as the feature extractor
        policy_kwargs = {
            'features_extractor_class': CustomCNN
        }

        # Train the model
        model = PPO(
            "CnnPolicy",
            env,
            device="cuda",
            policy_kwargs=policy_kwargs,
            verbose=1,
            n_steps=n_steps,
            batch_size=batch_size,
            learning_rate=learning_rate,
            ent_coef=ent_coef,
            clip_range=clip_range,
            vf_coef=vf_coef,
            gamma=gamma,
            gae_lambda=gae_lambda
        )

        for iteration in range(10):
            model.learn(total_timesteps=100000)
            mean_reward, _std_reward = evaluate_policy(model, env, n_eval_episodes=10)

            trial.report(mean_reward, iteration)

            if trial.should_prune():
                raise optuna.TrialPruned()

        return mean_reward
    finally:
        # Gym Retro permits only one emulator per process, so the environment
        # has to be released (also on pruning or errors) before the next trial.
        env.close()
|
||||
|
||||
# Run the hyperparameter search (higher mean reward is better) and report the best trial.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=100, timeout=7200) # Run optimization for 100 trials or 2 hours, whichever comes first

print("Best trial:")
trial = study.best_trial

print(" Value: ", trial.value)
print(" Params: ")
# One line per tuned hyperparameter.
for key, value in trial.params.items():
    print(f"{key}: {value}")
|
BIN
003_frame_delta_ram_based/__pycache__/custom_cnn.cpython-38.pyc
Normal file
BIN
003_frame_delta_ram_based/__pycache__/custom_cnn.cpython-38.pyc
Normal file
Binary file not shown.
Binary file not shown.
25
003_frame_delta_ram_based/custom_cnn.py
Normal file
25
003_frame_delta_ram_based/custom_cnn.py
Normal file
@ -0,0 +1,25 @@
|
||||
import gym
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
# Custom feature extractor (CNN)
|
||||
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for single-channel inputs, producing 512 features.

    ``forward`` inserts a channel axis at position 1 before running the
    network, so it presumably receives channel-less batches of frames
    (TODO confirm the shape delivered by the environment wrapper). The linear
    layer expects 16384 flattened values (64 x 16 x 16).
    """

    def __init__(self, observation_space: gym.Space):
        super().__init__(observation_space, features_dim=512)
        conv_stages = [
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        ]
        projection = [
            nn.Flatten(),
            nn.Linear(16384, self.features_dim),
            nn.ReLU(),
        ]
        # Same layer order as a single Sequential literal, so parameter names are unchanged.
        self.cnn = nn.Sequential(*conv_stages, *projection)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Add the channel axis and return the extracted feature tensor."""
        with_channel = observations.unsqueeze(1)
        return self.cnn(with_channel)
|
||||
|
2
003_frame_delta_ram_based/logs/monitor.csv
Normal file
2
003_frame_delta_ram_based/logs/monitor.csv
Normal file
@ -0,0 +1,2 @@
|
||||
#{"t_start": 1680175884.8182795, "env_id": null}
|
||||
r,l,t
|
Can't render this file because it contains an unexpected character in line 1 and column 3.
|
72
003_frame_delta_ram_based/street_fighter_custom_wrapper.py
Normal file
72
003_frame_delta_ram_based/street_fighter_custom_wrapper.py
Normal file
@ -0,0 +1,72 @@
|
||||
import gym
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# Custom environment wrapper
|
||||
class StreetFighterCustomWrapper(gym.Wrapper):
    """Environment wrapper that emits frame-difference observations and a health-based reward.

    Observations are the difference between the current and the previous
    preprocessed (grayscale, 84 x 84) frame. The reward is derived from the
    ``health`` and ``enemy_health`` entries of the wrapped env's ``info`` dict.

    Args:
        env: The environment to wrap.
        testing: When True, ``step`` never reports ``done=True``.
    """

    def __init__(self, env, testing=False):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.env = env
        self.testing = testing

        # Store the previous frame
        self.prev_frame = None

        self.full_hp = 176
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        # Update observation space to include one grayscale frame difference image
        # NOTE(review): the frames actually returned by reset()/step() are
        # float arrays of shape (84, 84) (values divided by 255, differences
        # may be negative), which does not match this uint8 (84, 84, 1) box.
        # Confirm which of the two is intended.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _preprocess_observation(self, observation):
        """Convert a colour frame to an 84 x 84 grayscale float array scaled by 1/255."""
        # NOTE(review): BGR2GRAY assumes BGR channel order - confirm the order
        # of the frames delivered by the wrapped environment.
        obs_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        obs_gray_resized = cv2.resize(obs_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0
        return obs_gray_resized

    def reset(self):
        """Reset the wrapped env and the health tracking; return an all-zero frame delta."""
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        observation = self.env.reset()
        # Reset the previous frame
        self.prev_frame = self._preprocess_observation(observation)
        return np.zeros_like(self.prev_frame)

    def step(self, action):
        """Advance the wrapped env by one step.

        Returns:
            Tuple ``(frame_delta, reward, done, info)``. ``reward`` and ``done``
            are computed here; the wrapped env's own reward and done flag are
            ignored.
        """
        observation, _reward, _done, info = self.env.step(action)

        obs_gray_resized = self._preprocess_observation(observation)

        if self.prev_frame is not None:
            frame_delta = obs_gray_resized - self.prev_frame
        else:
            frame_delta = np.zeros_like(obs_gray_resized)

        self.prev_frame = obs_gray_resized

        # Defaults cover the paths that assign nothing below (an ongoing fight
        # never sets `done`; neither health positive sets neither value), which
        # previously ended in an UnboundLocalError at the return statement.
        reward = 0
        done = False

        # During fighting, either player or opponent has positive health points.
        if info['health'] > 0 or info['enemy_health'] > 0:

            # Player Loses
            if info['health'] < 0 and info['enemy_health'] > 0:
                reward = (-self.full_hp) * info['enemy_health']
                done = True

            # Player Wins
            elif info['enemy_health'] < 0 and info['health'] > 0:
                reward = self.full_hp * info['health']
                done = True

            # During Fighting
            else:
                reward = (self.prev_oppont_health - info['enemy_health']) - (self.prev_player_health - info['health'])

                # NOTE(review): the nesting level of this bookkeeping was
                # reconstructed; confirm it belongs to the in-fight branch.
                self.prev_player_health = info['health']
                self.prev_oppont_health = info['enemy_health']

        if self.testing:
            done = False

        return frame_delta, reward, done, info
|
||||
|
70
003_frame_delta_ram_based/test.py
Normal file
70
003_frame_delta_ram_based/test.py
Normal file
@ -0,0 +1,70 @@
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
    """Return a zero-argument factory for a retro environment wrapped in testing mode.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
    """
    def _init():
        emulator = retro.RetroEnv(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(emulator, testing=True)

    return _init
|
||||
|
||||
# Watch a trained model play the first stage, rendered at roughly 90 FPS.
# Stop with Ctrl+C.
game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

# Wrap the environment
env = DummyVecEnv([lambda: env])

try:
    # `load` is a classmethod that returns a NEW model. Calling it on an
    # existing instance and discarding the result leaves that instance with
    # its random initial weights, so the returned model has to be used.
    model = PPO.load(r"trained_models_continued/ppo_chunli_6048000_steps", env=env, device="cuda")

    frame_time = 0.0111  # seconds per frame for 90 FPS
    obs = env.reset()

    while True:
        timestamp = time.time()
        action, _ = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        env.render()
        render_time = time.time() - timestamp
        if render_time < frame_time:
            time.sleep(frame_time - render_time)  # Add a delay for 90 FPS
except KeyboardInterrupt:
    # The loop has no other exit; treat Ctrl+C as a normal stop.
    pass
finally:
    env.close()
|
124
003_frame_delta_ram_based/train.py
Normal file
124
003_frame_delta_ram_based/train.py
Normal file
@ -0,0 +1,124 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO, A2C
|
||||
from stable_baselines3.common.vec_env import SubprocVecEnv
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
class RandomOpponentChangeCallback(BaseCallback):
    """Callback that periodically loads a randomly chosen state into the training envs.

    Args:
        stages: Sequence of state names; one is picked at random on each switch.
        opponent_interval: Number of callback calls between two switches.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, stages, opponent_interval, verbose=0):
        super().__init__(verbose)
        self.opponent_interval = opponent_interval
        self.stages = stages

    def _on_step(self) -> bool:
        """Switch state every ``opponent_interval`` calls; always returns True so training continues."""
        switch_due = self.n_calls % self.opponent_interval == 0
        if not switch_due:
            return True

        chosen_state = random.choice(self.stages)
        print("\nCurrent state:", chosen_state)
        # indices=None applies the call to every sub-environment.
        self.training_env.env_method("load_state", chosen_state, indices=None)
        return True
|
||||
|
||||
def make_env(game, state, seed=0):
    """Return a zero-argument factory that builds one wrapped, seeded retro environment.

    The environment is only created when the returned callable is invoked.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
        seed: Seed passed to ``env.seed``.
    """
    def _init():
        wrapped = StreetFighterCustomWrapper(
            retro.make(
                game=game,
                state=state,
                use_restricted_actions=retro.Actions.FILTERED,
                obs_type=retro.Observations.IMAGE,
            )
        )
        wrapped.seed(seed)
        return wrapped

    return _init
|
||||
|
||||
def main():
    """Train a PPO agent against randomly rotated opponents and save the result.

    Builds 8 subprocess environments, trains for 6,048,000 timesteps while a
    callback loads a random stage every ``opponent_interval`` callback calls,
    writes periodic checkpoints into ``trained_models/`` and finally saves the
    model as ``ppo_sf2_chunli_final.zip``.
    """
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
        "ChampionX.Level1.ChunLiVsKen",
        "ChampionX.Level2.ChunLiVsChunLi",
        "ChampionX.Level3.ChunLiVsZangief",
        "ChampionX.Level4.ChunLiVsDhalsim",
        "ChampionX.Level5.ChunLiVsRyu",
        "ChampionX.Level6.ChunLiVsEHonda",
        "ChampionX.Level7.ChunLiVsBlanka",
        "ChampionX.Level8.ChunLiVsGuile",
        "ChampionX.Level9.ChunLiVsBalrog",
        "ChampionX.Level10.ChunLiVsVega",
        "ChampionX.Level11.ChunLiVsSagat",
        "ChampionX.Level12.ChunLiVsBison"
        # Add other stages as necessary
    ]
    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.

    num_envs = 8

    env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])

    # Using CustomCNN as the feature extractor
    policy_kwargs = {
        'features_extractor_class': CustomCNN
    }

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        policy_kwargs=policy_kwargs,
        verbose=1,
        n_steps=5400,
        batch_size=64,
        learning_rate=0.0001,
        ent_coef=0.01,
        clip_range=0.2,
        gamma=0.99,
        gae_lambda=0.95,
        tensorboard_log="logs/"
    )

    # Set the save directory
    save_dir = "trained_models"
    os.makedirs(save_dir, exist_ok=True)

    # Load the model from file
    # model_path = "trained_models/ppo_chunli_1296000_steps.zip"

    # Load model and modify the learning rate and entropy coefficient
    # custom_objects = {
    #     "learning_rate": 0.0002
    # }
    # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)

    # Set up callbacks
    opponent_interval = 5400 # stage_interval * num_envs = total_steps_per_stage
    checkpoint_interval = 54000 # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_chunli")
    # The callback's third parameter is `verbose`; the save directory must not
    # be passed there, so only the stages and the interval are given.
    stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval)

    # model_params = {
    #     'n_steps': 5,
    #     'gamma': 0.99,
    #     'gae_lambda':1,
    #     'learning_rate': 7e-4,
    #     'vf_coef': 0.5,
    #     'ent_coef': 0.0,
    #     'max_grad_norm':0.5,
    #     'rms_prop_eps':1e-05
    # }
    # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))

    try:
        model.learn(
            total_timesteps=int(6048000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
            callback=[checkpoint_callback, stage_increase_callback]
        )
    finally:
        # Always shut the worker processes down, even if training is interrupted.
        env.close()

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_chunli_final.zip"))
|
||||
|
||||
# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
73
003_frame_delta_ram_based/tune_ppo.py
Normal file
73
003_frame_delta_ram_based/tune_ppo.py
Normal file
@ -0,0 +1,73 @@
|
||||
import os
|
||||
|
||||
import retro
|
||||
import optuna
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
|
||||
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
# Output directories, created up front: LOG_DIR is handed to Monitor and to
# PPO's tensorboard_log, OPT_DIR receives the model saved by each trial.
LOG_DIR = 'logs/'
OPT_DIR = 'optuna/'
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(OPT_DIR, exist_ok=True)
|
||||
|
||||
def optimize_ppo(trial):
    """Sample one set of PPO hyperparameters from ``trial``.

    Returns:
        Dict with the keys ``n_steps``, ``gamma``, ``learning_rate``,
        ``clip_range`` and ``gae_lambda``, ready to be unpacked into ``PPO``.
    """
    # Parameters are sampled in the same order in which they appear in the result.
    n_steps = trial.suggest_int('n_steps', 1024, 8192, log=True)
    gamma = trial.suggest_float('gamma', 0.9, 0.9999)
    learning_rate = trial.suggest_float('learning_rate', 5e-5, 1e-4, log=True)
    clip_range = trial.suggest_float('clip_range', 0.1, 0.4)
    gae_lambda = trial.suggest_float('gae_lambda', 0.8, 0.99)

    sampled = {}
    sampled['n_steps'] = n_steps
    sampled['gamma'] = gamma
    sampled['learning_rate'] = learning_rate
    sampled['clip_range'] = clip_range
    sampled['gae_lambda'] = gae_lambda
    return sampled
|
||||
|
||||
def make_env(game, state, seed=0):
    """Return a zero-argument factory that builds one wrapped, seeded retro environment.

    Args:
        game: Retro game identifier.
        state: Name of the state to start from.
        seed: Seed passed to ``env.seed``.
    """
    def _init():
        wrapped = StreetFighterCustomWrapper(
            retro.make(
                game=game,
                state=state,
                use_restricted_actions=retro.Actions.FILTERED,
                obs_type=retro.Observations.IMAGE,
            )
        )
        wrapped.seed(seed)
        return wrapped

    return _init
|
||||
|
||||
def optimize_agent(trial):
    """Optuna objective: train PPO for 100000 timesteps on 4 stacked frames.

    The trained model is saved as ``optuna/trial_<number>_best_model``. Any
    exception propagates to the caller.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        Mean reward over 5 evaluation episodes.
    """
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state = "ChampionX.Level1.ChunLiVsKen"

    model_params = optimize_ppo(trial)

    # Create environment: monitored single env -> vectorized -> 4-frame stack
    monitored_env = Monitor(make_env(game, state)(), LOG_DIR)
    vec_env = DummyVecEnv([lambda: monitored_env])
    env = VecFrameStack(vec_env, 4, channels_order='last')

    # Create algo
    model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=0, **model_params)
    model.learn(total_timesteps=100000)

    # Evaluate model
    mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)
    env.close()

    model.save(os.path.join(OPT_DIR, 'trial_{}_best_model'.format(trial.number)))

    return mean_reward
|
||||
|
||||
# Creating the experiment: 10 sequential trials (n_jobs=1), higher mean reward is better.
study = optuna.create_study(direction='maximize')
study.optimize(optimize_agent, n_trials=10, n_jobs=1)

# Print the best hyperparameters and the full record of the best trial.
print(study.best_params)
print(study.best_trial)
|
Loading…
Reference in New Issue
Block a user