rgb image stack

This commit is contained in:
linyiLYi 2023-04-03 00:19:56 +08:00
parent 16c80d5fba
commit ded261ba69
33 changed files with 635 additions and 35 deletions

View File

Can't render this file because it contains an unexpected character in line 1 and column 3.

View File

@@ -1,35 +0,0 @@
import torch
import torch.nn as nn

def conv2d_custom_init(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
    nn.init.xavier_uniform_(conv.weight)
    return conv

def custom_conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    return nn.Sequential(
        conv2d_custom_init(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        nn.ReLU(),
        nn.MaxPool2d((2, 2))
    )

# Custom feature extractor (CNN)
class CustomCNN(nn.Module):
    def __init__(self, num_frames, num_moves, num_attacks):
        super(CustomCNN, self).__init__()
        self.num_moves = num_moves
        self.num_attacks = num_attacks
        self.features_dim = 512  # Output size of the final linear layer (assumed 512, matching the replacement extractor).
        self.cnn = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16384, self.features_dim),
            nn.ReLU()
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.cnn(observations)

View File

@@ -0,0 +1,47 @@
import time

import retro
from stable_baselines3.common.monitor import Monitor

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state = "Champion.Level1.RyuVsGuile"

env = make_env(game, state)()
env = Monitor(env, 'logs/')

num_episodes = 30
episode_reward_sum = 0
for _ in range(num_episodes):
    done = False
    obs = env.reset()
    total_reward = 0
    while not done:
        timestamp = time.time()
        obs, reward, done, info = env.step(env.action_space.sample())

        # Note that if the player wins with only 0 HP left, the winning reward is still 0, so it won't be printed.
        if reward != 0:
            total_reward += reward
            print("Reward: {}, playerHP: {}, enemyHP: {}".format(reward, info['health'], info['enemy_health']))
        env.render()
        # time.sleep(0.005)
    print("Total reward: {}".format(total_reward))
    episode_reward_sum += total_reward

env.close()
print("Average reward for random strategy: {}".format(episode_reward_sum / num_episodes))

View File

@@ -0,0 +1,24 @@
import gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space):
        super(CustomCNN, self).__init__(observation_space, features_dim=512)
        self.cnn = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16384, self.features_dim),
            nn.ReLU()
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.cnn(observations)
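
A note on shapes: the first convolution above expects 4 input channels and the final nn.Linear expects 16384 flattened features, both of which are tied to a specific input resolution; with the wrapper's (100, 128, 3) RGB-stacked observations those numbers would not line up. The sketch below (not part of this commit; the class name ShapeAgnosticCNN is made up) shows the common stable-baselines3 pattern of reading the channel count from the observation space and measuring the flattened size with a dummy forward pass.

import gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class ShapeAgnosticCNN(BaseFeaturesExtractor):
    """Like CustomCNN, but derives the channel count and flatten size from the observation space."""
    def __init__(self, observation_space: gym.Space, features_dim: int = 512):
        super().__init__(observation_space, features_dim=features_dim)
        h, w, c = observation_space.shape  # e.g. (100, 128, 3) from the custom wrapper
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Run one dummy batch (channels-first, as SB3 feeds images) to measure the flattened size.
        with torch.no_grad():
            n_flatten = self.cnn(torch.zeros(1, c, h, w)).shape[1]
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.linear(self.cnn(observations))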

View File

@@ -0,0 +1,52 @@
import retro
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy

from custom_cnn import CustomCNN
from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

# Wrap the environment
# env = Monitor(env, 'logs/')

policy_kwargs = {'features_extractor_class': CustomCNN}
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
model = PPO.load(r"dummy_model_ppo_chunli")
# model.load(r"trained_models/ppo_chunli_864000_steps")

# evaluate_policy returns (mean, std) only when return_episode_rewards is False (the default);
# with return_episode_rewards=True it returns per-episode lists, which the formatted print below cannot handle.
mean_reward, std_reward = evaluate_policy(model, env, render=True, n_eval_episodes=10, deterministic=False)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

View File

@@ -0,0 +1,203 @@
#{"t_start": 1680450538.0135336, "env_id": null}
r,l,t
140,2721,9.666247
475,2051,16.014586
659,1841,22.078956
767,1374,25.788732
253,2591,35.3779
337,2369,42.159742
280,2490,51.242315
316,2761,59.177625
498,1532,65.235702
794,1294,68.963929
758,1295,72.564837
343,1862,78.748257
366,1293,82.443477
344,3356,95.206057
655,1716,99.877595
717,1496,105.722482
410,2494,113.515226
906,1496,119.400876
230,1769,124.036038
966,1043,127.542873
321,2284,136.329181
812,1519,140.423345
519,1729,146.851307
722,1412,150.934866
700,1646,157.093265
387,1965,163.827162
343,2604,171.611118
123,1625,177.426801
417,3407,187.666056
187,2616,196.882013
26,2142,203.498234
1056,1218,207.239627
927,1124,210.857355
47,1994,217.59863
510,1774,223.863903
377,1381,227.737895
481,1390,233.374429
658,1447,237.396851
349,3437,248.157832
718,1829,254.408287
180,2025,261.068938
313,1839,267.477898
902,975,270.676313
292,2950,280.273983
182,2268,287.275446
187,1953,293.603531
362,3130,303.389702
638,1779,309.646411
228,2179,316.57769
341,2109,323.229764
85,2241,330.295613
260,3871,343.38057
-72,2846,353.025726
319,2520,360.176952
121,2334,367.161462
-143,2696,376.380194
-10,2199,383.130368
254,3800,395.603224
51,3540,406.048309
340,2647,415.170668
331,6905,437.351264
741,1698,441.462111
331,3519,454.219213
489,2276,461.541512
368,3599,472.906958
-126,3151,482.855691
-235,3220,494.629744
-17,2824,502.354869
-82,3238,514.654498
-160,3254,525.209512
-150,2923,535.200898
-93,7337,558.950392
-275,2822,566.908611
245,3758,579.955454
-117,2586,589.282403
-230,2788,597.257352
90,2896,607.108898
967,1681,613.189581
-162,3178,623.550789
-111,2603,632.824294
-111,2160,639.682543
-91,3059,649.811759
552,3472,660.344764
156,2905,669.981916
506,1973,676.494263
-29,3483,687.2097
-342,2388,694.675364
-98,7522,720.510979
268,3198,730.537025
123,7866,755.786125
41,3279,766.078602
454,3852,778.612931
241,3330,788.852483
-197,2695,796.350531
16,2708,805.798072
-276,2921,815.975549
230,7838,842.84947
-343,2955,852.604732
-124,2710,860.170467
-318,1477,866.041014
-37,3970,877.800582
-26,2989,888.043861
732,8207,914.516506
326,3080,924.643157
-189,2372,933.448068
275,8785,961.099634
-189,2881,970.522553
108,3190,980.352367
-351,2851,989.817974
-287,3287,1000.238004
-262,2847,1010.036809
-206,7848,1034.80143
102,7532,1058.580046
0,8037,1084.80434
110,4073,1097.11781
421,3413,1107.327183
-203,3154,1117.49272
655,14205,1161.34724
126,3993,1174.695066
48,3832,1187.983629
68,2995,1197.869093
80,3252,1207.617986
84,3776,1219.716212
-192,3176,1229.660275
-143,2819,1237.98842
10,2730,1247.57504
-191,2460,1256.580759
-28,2546,1263.612297
192,3534,1273.532951
268,3797,1285.940749
-98,3139,1296.872381
75,3568,1309.84493
-123,7274,1332.694059
326,3440,1342.804182
349,3737,1355.733267
22,2943,1366.020118
-202,3018,1375.968116
888,1928,1382.114744
-209,1646,1386.164465
46,1613,1391.748134
-318,2434,1398.519993
-275,2288,1405.521706
397,3578,1418.447945
317,2150,1425.714679
-75,2716,1435.669649
-93,2679,1442.908609
564,2987,1452.369631
216,2904,1461.84269
44,2300,1468.801676
401,1470,1474.693752
381,3590,1485.786877
256,2522,1494.610591
-141,1773,1498.837729
335,2651,1507.840033
860,1561,1511.859426
357,1743,1517.856122
846,1433,1523.547
702,720,1524.743536
81,3314,1535.593991
608,1468,1541.5119
464,2507,1549.079192
382,1465,1554.894497
661,2153,1561.58434
-220,2172,1568.434792
470,2597,1577.624233
606,1471,1581.617123
128,2485,1589.245833
-151,2076,1596.189308
-34,1775,1602.548944
7,2518,1611.704006
-73,1256,1615.417475
981,952,1618.628284
537,1555,1622.779646
336,2464,1631.451718
490,2070,1638.048966
337,3439,1648.578499
367,2505,1657.904252
365,2554,1665.472756
654,1061,1669.007638
334,3193,1679.491146
-125,1751,1685.74332
342,2740,1695.388833
541,1674,1699.803759
303,3218,1709.898359
62,2140,1716.650506
37,1838,1722.973549
-9,2999,1732.744556
-47,2898,1742.188218
462,3518,1752.401364
206,2255,1760.763199
494,2294,1767.674661
198,2530,1775.444748
149,2196,1782.305408
593,2317,1791.355
349,2208,1798.589246
-74,1673,1804.620136
41,2712,1811.981201
432,2759,1821.619351
75,2880,1831.46073
397,1858,1837.717627
-204,3008,1845.908291
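
The file above is a stable-baselines3 Monitor log (monitor.csv style): the first line is a JSON header with the run's start time, and each subsequent row records one episode's total reward r, length in steps l, and elapsed wall-clock time t. A minimal sketch for inspecting it with pandas (the path logs/monitor.csv is assumed):

import pandas as pd

# Skip the JSON header line, then read the r,l,t episode records.
df = pd.read_csv("logs/monitor.csv", skiprows=1)
print(df["r"].mean(), df["r"].max())          # average and best episode reward
print(df["l"].sum(), "environment steps logged")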

View File

@@ -0,0 +1,79 @@
import collections

import gym
import numpy as np

# Custom environment wrapper
class StreetFighterCustomWrapper(gym.Wrapper):
    def __init__(self, env, testing=False):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.env = env

        # Use a deque to store the last 3 frames.
        self.num_frames = 3
        self.frame_stack = collections.deque(maxlen=self.num_frames)

        self.reward_coeff = 3

        self.full_hp = 176
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        # Update observation space to a single RGB-like image whose three channels come from the stacked frames.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(100, 128, 3), dtype=np.uint8)

        self.testing = testing

    def _preprocess_observation(self, observation):
        # Store the downsampled frame (every other pixel in both dimensions).
        self.frame_stack.append(observation[::2, ::2, :])

        # Take channel i of frame i (R from the oldest, G from the middle, B from the newest) and stack them into one "image".
        stacked_image = np.stack([frame[:, :, i] for i, frame in enumerate(self.frame_stack)], axis=-1)
        return stacked_image

    def reset(self):
        observation = self.env.reset()
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        # Clear the frame stack and add the first observation [num_frames] times.
        self.frame_stack.clear()
        for _ in range(self.num_frames):
            self.frame_stack.append(observation[::2, ::2, :])

        return np.stack([frame[:, :, i] for i, frame in enumerate(self.frame_stack)], axis=-1)

    def step(self, action):
        obs, _reward, _done, info = self.env.step(action)

        curr_player_health = info['health']
        curr_oppont_health = info['enemy_health']

        # Game is over and the player loses.
        if curr_player_health < 0:
            custom_reward = -curr_oppont_health  # Use the opponent's remaining health points as the penalty.
            # If the opponent also has negative health points, it's an even game and the reward is +1.
            custom_done = True
        # Game is over and the player wins.
        elif curr_oppont_health < 0:
            custom_reward = curr_player_health * self.reward_coeff  # Use the player's remaining health points as the reward.
            # Multiply by reward_coeff to make the reward larger than the penalty, so the agent is not pushed toward overly defensive play.
            custom_done = True
        # The fight is still going on.
        else:
            custom_reward = self.reward_coeff * (self.prev_oppont_health - curr_oppont_health) - (self.prev_player_health - curr_player_health)
            self.prev_player_health = curr_player_health
            self.prev_oppont_health = curr_oppont_health
            custom_done = False

        # During testing, the session should always keep going.
        if self.testing:
            custom_done = False

        # Max reward is about 6 * full_hp = 1056 (damage reward * 3 + winning reward * 3).
        return self._preprocess_observation(obs), custom_reward, custom_done, info
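
To make the RGB stacking concrete: with three downsampled frames in the deque, the list comprehension picks channel i from frame i, so the result is a single (100, 128, 3) uint8 array whose three "color" channels come from three different points in time. A standalone sketch with dummy data (shapes only, no emulator needed; the frame size follows the wrapper's halved resolution):

import collections
import numpy as np

frame_stack = collections.deque(maxlen=3)
for _ in range(3):
    # Stand-in for one downsampled RGB frame, matching the wrapper's (100, 128, 3) shape.
    frame_stack.append(np.random.randint(0, 256, size=(100, 128, 3), dtype=np.uint8))

# Channel 0 (R) from the oldest frame, channel 1 (G) from the middle one, channel 2 (B) from the newest.
stacked = np.stack([frame[:, :, i] for i, frame in enumerate(frame_stack)], axis=-1)
print(stacked.shape, stacked.dtype)  # (100, 128, 3) uint8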

View File

@@ -0,0 +1,76 @@
import time

import retro
from stable_baselines3 import PPO

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

def make_env(game, state):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.RyuVsGuile",
    "Champion.Level1.ChunLiVsGuile",   # Average reward for random strategy: -102.3 | -20.4
    "ChampionX.Level1.ChunLiVsKen",    # Average reward for random strategy: -247.6
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

model = PPO(
    "CnnPolicy",
    env,
    verbose=1
)
model_path = r"trained_models_level_1/ppo_ryu_000000_steps"
# PPO.load is a classmethod that returns a new model; assign the returned object so the loaded weights are actually used.
model = PPO.load(model_path, env=env)
# Average reward for optuna/trial_1_best_model: -82.3
# Average reward for optuna/trial_9_best_model: 36.7 | -86.23
# Average reward for trained_models/ppo_chunli_5376000_steps: -77.8

obs = env.reset()
done = False

num_episodes = 30
episode_reward_sum = 0
for _ in range(num_episodes):
    done = False
    obs = env.reset()
    total_reward = 0
    while not done:
        timestamp = time.time()
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)

        if reward != 0:
            total_reward += reward
            print("Reward: {}, playerHP: {}, enemyHP: {}".format(reward, info['health'], info['enemy_health']))
        env.render()
        time.sleep(0.01)
    print("Total reward: {}".format(total_reward))
    episode_reward_sum += total_reward

env.close()
print("Average reward for {}: {}".format(model_path, episode_reward_sum / num_episodes))

View File

@@ -0,0 +1,154 @@
import os
import random

import retro
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback

from street_fighter_custom_wrapper import StreetFighterCustomWrapper

LOG_DIR = 'logs/'
os.makedirs(LOG_DIR, exist_ok=True)

class RandomOpponentChangeCallback(BaseCallback):
    def __init__(self, stages, opponent_interval, verbose=0):
        super(RandomOpponentChangeCallback, self).__init__(verbose)
        self.stages = stages
        self.opponent_interval = opponent_interval

    def _on_step(self) -> bool:
        if self.n_calls % self.opponent_interval == 0:
            new_state = random.choice(self.stages)
            print("\nCurrent state:", new_state)
            self.training_env.env_method("load_state", new_state, indices=None)
        return True

# class StageIncreaseCallback(BaseCallback):
#     def __init__(self, stages, stage_interval, save_dir, verbose=0):
#         super(StageIncreaseCallback, self).__init__(verbose)
#         self.stages = stages
#         self.stage_interval = stage_interval
#         self.save_dir = save_dir
#         self.current_stage = 0
#
#     def _on_step(self) -> bool:
#         if self.n_calls % self.stage_interval == 0 and self.current_stage < len(self.stages) - 1:
#             self.current_stage += 1
#             new_state = self.stages[self.current_stage]
#             self.training_env.env_method("load_state", new_state, indices=None)
#             self.model.save(os.path.join(self.save_dir, f"ppo_chunli_stage_{self.current_stage}.zip"))
#         return True

def make_env(game, state):
    def _init():
        env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE
        )
        env = StreetFighterCustomWrapper(env)
        return env
    return _init

def main():
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    state_stages = [
        "Champion.Level1.RyuVsGuile",
        "Champion.Level1.ChunLiVsGuile",   # Average reward for random strategy: -102.3
        "Champion.Level2.ChunLiVsKen",
        "Champion.Level3.ChunLiVsChunLi",
        "Champion.Level4.ChunLiVsZangief",
        "Champion.Level5.ChunLiVsDhalsim",
        "Champion.Level6.ChunLiVsRyu",
        "Champion.Level7.ChunLiVsEHonda",
        "Champion.Level8.ChunLiVsBlanka",
        "Champion.Level9.ChunLiVsBalrog",
        "Champion.Level10.ChunLiVsVega",
        "Champion.Level11.ChunLiVsSagat",
        "Champion.Level12.ChunLiVsBison"
        # Add other stages as necessary
    ]

    # state_stages = [
    #     "ChampionX.Level1.ChunLiVsKen",   # Average reward for random strategy: -247.6
    #     "ChampionX.Level2.ChunLiVsChunLi",
    #     "ChampionX.Level3.ChunLiVsZangief",
    #     "ChampionX.Level4.ChunLiVsDhalsim",
    #     "ChampionX.Level5.ChunLiVsRyu",
    #     "ChampionX.Level6.ChunLiVsEHonda",
    #     "ChampionX.Level7.ChunLiVsBlanka",
    #     "ChampionX.Level8.ChunLiVsGuile",
    #     "ChampionX.Level9.ChunLiVsBalrog",
    #     "ChampionX.Level10.ChunLiVsVega",
    #     "ChampionX.Level11.ChunLiVsSagat",
    #     "ChampionX.Level12.ChunLiVsBison"
    #     # Add other stages as necessary
    # ]

    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.

    env = make_env(game, state_stages[0])()
    # Wrap the env in a Monitor wrapper to record training progress
    env = Monitor(env, LOG_DIR)

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        verbose=1,
        n_steps=1024,
        batch_size=64,
        learning_rate=1e-4,
        ent_coef=0.01,
        clip_range=0.2,
        gamma=0.95,
        gae_lambda=0.81322,
        tensorboard_log="logs/"
    )

    # Set the save directory
    save_dir = "trained_models_ryu_level_1_reward_x3"
    os.makedirs(save_dir, exist_ok=True)

    # Load the model from file
    # model_path = "trained_models/ppo_chunli_1296000_steps.zip"

    # Load model and modify the learning rate and entropy coefficient
    # custom_objects = {
    #     "learning_rate": 0.0002
    # }
    # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)

    # Set up callbacks
    # opponent_interval = 35840  # opponent_interval * num_envs = total_steps_per_stage
    checkpoint_interval = 200000  # checkpoint_interval * num_envs = total_steps_per_checkpoint (every 80 rounds)
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_ryu")
    # opponent_change_callback = RandomOpponentChangeCallback(state_stages, opponent_interval)

    # model_params = {
    #     'n_steps': 5,
    #     'gamma': 0.99,
    #     'gae_lambda': 1,
    #     'learning_rate': 7e-4,
    #     'vf_coef': 0.5,
    #     'ent_coef': 0.0,
    #     'max_grad_norm': 0.5,
    #     'rms_prop_eps': 1e-05
    # }
    # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))

    model.learn(
        total_timesteps=int(10000000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
        callback=[checkpoint_callback]  # , opponent_change_callback]
    )
    env.close()

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_ryu_final.zip"))

if __name__ == "__main__":
    main()
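
RandomOpponentChangeCallback is defined above but left out of the callback list, so training stays on the first stage. Wiring it in only takes the stage list and a switch interval; a minimal sketch assuming the names from main(), with the interval of 35840 taken from the commented-out line rather than a tested value:

# Inside main(), after creating checkpoint_callback:
opponent_interval = 35840  # steps between random opponent/stage switches (value from the commented-out line)
opponent_change_callback = RandomOpponentChangeCallback(state_stages, opponent_interval)

model.learn(
    total_timesteps=int(10000000),
    callback=[checkpoint_callback, opponent_change_callback]
)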