mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-05 15:40:45 +00:00
rgb image stack
This commit is contained in:
parent
16c80d5fba
commit
ded261ba69
Can't render this file because it contains an unexpected character in line 1 and column 3.
|
@ -1,35 +0,0 @@
|
|||||||
import torch
import torch.nn as nn
|
|
||||||
|
|
||||||
def conv2d_custom_init(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    """Build an ``nn.Conv2d`` whose weights are Xavier-uniform initialised.

    Every argument is forwarded unchanged to ``nn.Conv2d``. The layer's
    weight tensor is re-initialised in place before the layer is returned.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    nn.init.xavier_uniform_(layer.weight)
    return layer
|
|
||||||
|
|
||||||
def custom_conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
    """Return a convolution -> ReLU -> 2x2 max-pool stage as an ``nn.Sequential``.

    The convolution is created by ``conv2d_custom_init`` with the arguments
    given here, so its weights are Xavier-uniform initialised.
    """
    return nn.Sequential(
        conv2d_custom_init(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        ),
        # The activation class is spelled ``ReLU``; ``nn.Relu`` does not exist.
        nn.ReLU(),
        nn.MaxPool2d((2, 2)),
    )
|
|
||||||
|
|
||||||
# Custom feature extractor (CNN)
class CustomCNN(nn.Module):
    """Three-layer convolutional feature extractor followed by a linear head.

    Args:
        num_frames: Accepted for interface compatibility; the first
            convolution is fixed at 4 input channels and does not read it.
        num_moves: Stored on the instance as ``num_moves``.
        num_attacks: Stored on the instance as ``num_attacks``.
        features_dim: Width of the output feature vector. Defaults to 512,
            the value used by the ``BaseFeaturesExtractor``-based CustomCNN.
    """

    def __init__(self, num_frames, num_moves, num_attacks, features_dim=512):
        super().__init__()
        self.num_moves = num_moves
        self.num_attacks = num_attacks
        # ``nn.Module`` has no ``features_dim`` attribute of its own, so it
        # must be set before the linear layer below reads it.
        self.features_dim = features_dim
        self.cnn = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            # 16384 is the flattened width the convolutions must produce;
            # inputs of a different spatial size will fail here.
            nn.Linear(16384, self.features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Run the observations through the CNN and return the features."""
        return self.cnn(observations)
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
47
004_image_stack_ram_based_reward_custom/check_reward.py
Normal file
47
004_image_stack_ram_based_reward_custom/check_reward.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import retro
|
||||||
|
from stable_baselines3.common.monitor import Monitor
|
||||||
|
|
||||||
|
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||||
|
|
||||||
|
def make_env(game, state):
    """Return a zero-argument factory that builds the wrapped retro environment.

    The factory creates the retro environment for ``game`` / ``state`` with
    filtered actions and image observations, then wraps it in
    ``StreetFighterCustomWrapper``.
    """
    def _init():
        retro_env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(retro_env)

    return _init
|
||||||
|
|
||||||
|
# Play random actions for a number of episodes and report the average
# episode reward produced by the custom wrapper.
game = "StreetFighterIISpecialChampionEdition-Genesis"
state = "Champion.Level1.RyuVsGuile"

env = make_env(game, state)()
env = Monitor(env, 'logs/')

num_episodes = 30
episode_reward_sum = 0
try:
    for _ in range(num_episodes):
        done = False
        obs = env.reset()
        total_reward = 0
        while not done:
            obs, reward, done, info = env.step(env.action_space.sample())

            # Note that if player wins but only has 0 HP left, the winning reward is still 0, so it won't be printed.
            if reward != 0:
                total_reward += reward
                print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
            env.render()
            # time.sleep(0.005)

        print("Total reward: {}".format(total_reward))
        episode_reward_sum += total_reward
finally:
    # Release the environment even when an episode is interrupted or raises.
    env.close()

print("Average reward for random strategy: {}".format(episode_reward_sum/num_episodes))
|
24
004_image_stack_ram_based_reward_custom/custom_cnn.py
Normal file
24
004_image_stack_ram_based_reward_custom/custom_cnn.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import gym
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||||
|
|
||||||
|
# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for image observations.

    Args:
        observation_space: Observation space handed over by the policy.
            NOTE(review): assumed to be channel-first (C, H, W), as in the
            stable-baselines3 custom-extractor example -- confirm.
        features_dim: Width of the output feature vector (default 512).
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512):
        super().__init__(observation_space, features_dim=features_dim)
        # Take the channel count from the space instead of hard-coding 4, so
        # the extractor also fits the wrapper's 3-channel stacked image.
        n_input_channels = observation_space.shape[0]
        conv_layers = [
            nn.Conv2d(n_input_channels, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]

        # Measure the flattened width with one dummy forward pass rather than
        # assuming 16384, which only holds for one particular input size.
        with torch.no_grad():
            sample = torch.as_tensor(observation_space.sample()[None]).float()
            n_flatten = nn.Sequential(*conv_layers)(sample).shape[1]

        # A single Sequential keeps the layer indices (and therefore the
        # state-dict keys) identical to the previous fixed-size definition.
        self.cnn = nn.Sequential(
            *conv_layers,
            nn.Linear(n_flatten, self.features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Run the observations through the CNN and return the features."""
        return self.cnn(observations)
|
||||||
|
|
52
004_image_stack_ram_based_reward_custom/evaluate.py
Normal file
52
004_image_stack_ram_based_reward_custom/evaluate.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
import retro
|
||||||
|
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||||
|
from stable_baselines3.common.monitor import Monitor
|
||||||
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
|
|
||||||
|
from custom_cnn import CustomCNN
|
||||||
|
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||||
|
|
||||||
|
def make_env(game, state):
    """Return a zero-argument factory that builds the wrapped retro environment.

    The factory creates the retro environment for ``game`` / ``state`` with
    filtered actions and image observations, then wraps it in
    ``StreetFighterCustomWrapper``.
    """
    def _init():
        retro_env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(retro_env)

    return _init
|
||||||
|
|
||||||
|
# Load a saved PPO model and report its mean episode reward on the first stage.
game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

# Wrap the environment
# env = Monitor(env, 'logs/')

# ``PPO.load`` builds the model from the saved file, so no throwaway PPO
# instance is constructed first.
# NOTE(review): a checkpoint saved with CustomCNN still needs ``custom_cnn``
# to be importable when it is loaded -- confirm for the file used here.
model = PPO.load(r"dummy_model_ppo_chunli")
# Alternative checkpoint: PPO.load(r"trained_models/ppo_chunli_864000_steps")

# With the default ``return_episode_rewards=False`` the call returns the
# (mean, std) floats that the formatted print below expects.
mean_reward, std_reward = evaluate_policy(model, env, render=True, n_eval_episodes=10, deterministic=False)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
203
004_image_stack_ram_based_reward_custom/logs/monitor.csv
Normal file
203
004_image_stack_ram_based_reward_custom/logs/monitor.csv
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
#{"t_start": 1680450538.0135336, "env_id": null}
|
||||||
|
r,l,t
|
||||||
|
140,2721,9.666247
|
||||||
|
475,2051,16.014586
|
||||||
|
659,1841,22.078956
|
||||||
|
767,1374,25.788732
|
||||||
|
253,2591,35.3779
|
||||||
|
337,2369,42.159742
|
||||||
|
280,2490,51.242315
|
||||||
|
316,2761,59.177625
|
||||||
|
498,1532,65.235702
|
||||||
|
794,1294,68.963929
|
||||||
|
758,1295,72.564837
|
||||||
|
343,1862,78.748257
|
||||||
|
366,1293,82.443477
|
||||||
|
344,3356,95.206057
|
||||||
|
655,1716,99.877595
|
||||||
|
717,1496,105.722482
|
||||||
|
410,2494,113.515226
|
||||||
|
906,1496,119.400876
|
||||||
|
230,1769,124.036038
|
||||||
|
966,1043,127.542873
|
||||||
|
321,2284,136.329181
|
||||||
|
812,1519,140.423345
|
||||||
|
519,1729,146.851307
|
||||||
|
722,1412,150.934866
|
||||||
|
700,1646,157.093265
|
||||||
|
387,1965,163.827162
|
||||||
|
343,2604,171.611118
|
||||||
|
123,1625,177.426801
|
||||||
|
417,3407,187.666056
|
||||||
|
187,2616,196.882013
|
||||||
|
26,2142,203.498234
|
||||||
|
1056,1218,207.239627
|
||||||
|
927,1124,210.857355
|
||||||
|
47,1994,217.59863
|
||||||
|
510,1774,223.863903
|
||||||
|
377,1381,227.737895
|
||||||
|
481,1390,233.374429
|
||||||
|
658,1447,237.396851
|
||||||
|
349,3437,248.157832
|
||||||
|
718,1829,254.408287
|
||||||
|
180,2025,261.068938
|
||||||
|
313,1839,267.477898
|
||||||
|
902,975,270.676313
|
||||||
|
292,2950,280.273983
|
||||||
|
182,2268,287.275446
|
||||||
|
187,1953,293.603531
|
||||||
|
362,3130,303.389702
|
||||||
|
638,1779,309.646411
|
||||||
|
228,2179,316.57769
|
||||||
|
341,2109,323.229764
|
||||||
|
85,2241,330.295613
|
||||||
|
260,3871,343.38057
|
||||||
|
-72,2846,353.025726
|
||||||
|
319,2520,360.176952
|
||||||
|
121,2334,367.161462
|
||||||
|
-143,2696,376.380194
|
||||||
|
-10,2199,383.130368
|
||||||
|
254,3800,395.603224
|
||||||
|
51,3540,406.048309
|
||||||
|
340,2647,415.170668
|
||||||
|
331,6905,437.351264
|
||||||
|
741,1698,441.462111
|
||||||
|
331,3519,454.219213
|
||||||
|
489,2276,461.541512
|
||||||
|
368,3599,472.906958
|
||||||
|
-126,3151,482.855691
|
||||||
|
-235,3220,494.629744
|
||||||
|
-17,2824,502.354869
|
||||||
|
-82,3238,514.654498
|
||||||
|
-160,3254,525.209512
|
||||||
|
-150,2923,535.200898
|
||||||
|
-93,7337,558.950392
|
||||||
|
-275,2822,566.908611
|
||||||
|
245,3758,579.955454
|
||||||
|
-117,2586,589.282403
|
||||||
|
-230,2788,597.257352
|
||||||
|
90,2896,607.108898
|
||||||
|
967,1681,613.189581
|
||||||
|
-162,3178,623.550789
|
||||||
|
-111,2603,632.824294
|
||||||
|
-111,2160,639.682543
|
||||||
|
-91,3059,649.811759
|
||||||
|
552,3472,660.344764
|
||||||
|
156,2905,669.981916
|
||||||
|
506,1973,676.494263
|
||||||
|
-29,3483,687.2097
|
||||||
|
-342,2388,694.675364
|
||||||
|
-98,7522,720.510979
|
||||||
|
268,3198,730.537025
|
||||||
|
123,7866,755.786125
|
||||||
|
41,3279,766.078602
|
||||||
|
454,3852,778.612931
|
||||||
|
241,3330,788.852483
|
||||||
|
-197,2695,796.350531
|
||||||
|
16,2708,805.798072
|
||||||
|
-276,2921,815.975549
|
||||||
|
230,7838,842.84947
|
||||||
|
-343,2955,852.604732
|
||||||
|
-124,2710,860.170467
|
||||||
|
-318,1477,866.041014
|
||||||
|
-37,3970,877.800582
|
||||||
|
-26,2989,888.043861
|
||||||
|
732,8207,914.516506
|
||||||
|
326,3080,924.643157
|
||||||
|
-189,2372,933.448068
|
||||||
|
275,8785,961.099634
|
||||||
|
-189,2881,970.522553
|
||||||
|
108,3190,980.352367
|
||||||
|
-351,2851,989.817974
|
||||||
|
-287,3287,1000.238004
|
||||||
|
-262,2847,1010.036809
|
||||||
|
-206,7848,1034.80143
|
||||||
|
102,7532,1058.580046
|
||||||
|
0,8037,1084.80434
|
||||||
|
110,4073,1097.11781
|
||||||
|
421,3413,1107.327183
|
||||||
|
-203,3154,1117.49272
|
||||||
|
655,14205,1161.34724
|
||||||
|
126,3993,1174.695066
|
||||||
|
48,3832,1187.983629
|
||||||
|
68,2995,1197.869093
|
||||||
|
80,3252,1207.617986
|
||||||
|
84,3776,1219.716212
|
||||||
|
-192,3176,1229.660275
|
||||||
|
-143,2819,1237.98842
|
||||||
|
10,2730,1247.57504
|
||||||
|
-191,2460,1256.580759
|
||||||
|
-28,2546,1263.612297
|
||||||
|
192,3534,1273.532951
|
||||||
|
268,3797,1285.940749
|
||||||
|
-98,3139,1296.872381
|
||||||
|
75,3568,1309.84493
|
||||||
|
-123,7274,1332.694059
|
||||||
|
326,3440,1342.804182
|
||||||
|
349,3737,1355.733267
|
||||||
|
22,2943,1366.020118
|
||||||
|
-202,3018,1375.968116
|
||||||
|
888,1928,1382.114744
|
||||||
|
-209,1646,1386.164465
|
||||||
|
46,1613,1391.748134
|
||||||
|
-318,2434,1398.519993
|
||||||
|
-275,2288,1405.521706
|
||||||
|
397,3578,1418.447945
|
||||||
|
317,2150,1425.714679
|
||||||
|
-75,2716,1435.669649
|
||||||
|
-93,2679,1442.908609
|
||||||
|
564,2987,1452.369631
|
||||||
|
216,2904,1461.84269
|
||||||
|
44,2300,1468.801676
|
||||||
|
401,1470,1474.693752
|
||||||
|
381,3590,1485.786877
|
||||||
|
256,2522,1494.610591
|
||||||
|
-141,1773,1498.837729
|
||||||
|
335,2651,1507.840033
|
||||||
|
860,1561,1511.859426
|
||||||
|
357,1743,1517.856122
|
||||||
|
846,1433,1523.547
|
||||||
|
702,720,1524.743536
|
||||||
|
81,3314,1535.593991
|
||||||
|
608,1468,1541.5119
|
||||||
|
464,2507,1549.079192
|
||||||
|
382,1465,1554.894497
|
||||||
|
661,2153,1561.58434
|
||||||
|
-220,2172,1568.434792
|
||||||
|
470,2597,1577.624233
|
||||||
|
606,1471,1581.617123
|
||||||
|
128,2485,1589.245833
|
||||||
|
-151,2076,1596.189308
|
||||||
|
-34,1775,1602.548944
|
||||||
|
7,2518,1611.704006
|
||||||
|
-73,1256,1615.417475
|
||||||
|
981,952,1618.628284
|
||||||
|
537,1555,1622.779646
|
||||||
|
336,2464,1631.451718
|
||||||
|
490,2070,1638.048966
|
||||||
|
337,3439,1648.578499
|
||||||
|
367,2505,1657.904252
|
||||||
|
365,2554,1665.472756
|
||||||
|
654,1061,1669.007638
|
||||||
|
334,3193,1679.491146
|
||||||
|
-125,1751,1685.74332
|
||||||
|
342,2740,1695.388833
|
||||||
|
541,1674,1699.803759
|
||||||
|
303,3218,1709.898359
|
||||||
|
62,2140,1716.650506
|
||||||
|
37,1838,1722.973549
|
||||||
|
-9,2999,1732.744556
|
||||||
|
-47,2898,1742.188218
|
||||||
|
462,3518,1752.401364
|
||||||
|
206,2255,1760.763199
|
||||||
|
494,2294,1767.674661
|
||||||
|
198,2530,1775.444748
|
||||||
|
149,2196,1782.305408
|
||||||
|
593,2317,1791.355
|
||||||
|
349,2208,1798.589246
|
||||||
|
-74,1673,1804.620136
|
||||||
|
41,2712,1811.981201
|
||||||
|
432,2759,1821.619351
|
||||||
|
75,2880,1831.46073
|
||||||
|
397,1858,1837.717627
|
||||||
|
-204,3008,1845.908291
|
Can't render this file because it contains an unexpected character in line 1 and column 3.
|
@ -0,0 +1,79 @@
|
|||||||
|
import collections
|
||||||
|
|
||||||
|
import gym
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Custom environment wrapper
class StreetFighterCustomWrapper(gym.Wrapper):
    """Gym wrapper that stacks recent frames and replaces the reward signal.

    Observations are built from the last ``num_frames`` half-resolution
    frames; rewards are derived from the ``health`` / ``enemy_health``
    entries of the ``info`` dict returned by the wrapped environment.
    """

    def __init__(self, env, testing=False):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.env = env

        # When true, ``step`` never reports the episode as done.
        self.testing = testing

        # Weight applied to damage dealt and to the winning bonus.
        self.reward_coeff = 3

        # Health value both fighters are assumed to start each round with.
        self.full_hp = 176
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        # A bounded deque holds the most recent ``num_frames`` (3) frames.
        self.num_frames = 3
        self.frame_stack = collections.deque(maxlen=self.num_frames)

        # The stacked observation is a single 100x128 image with 3 channels.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(100, 128, 3), dtype=np.uint8)

    def _stack_frames(self):
        """Combine the buffered frames into one image.

        Channel ``i`` of the result is channel ``i`` of the ``i``-th oldest
        buffered frame (presumably R, G, B in that order -- confirm).
        """
        channels = [frame[:, :, idx] for idx, frame in enumerate(self.frame_stack)]
        return np.stack(channels, axis=-1)

    def _preprocess_observation(self, observation):
        """Push a new frame, downsampled by 2 on both axes, and return the stack."""
        self.frame_stack.append(observation[::2, ::2, :])
        return self._stack_frames()

    def reset(self):
        """Reset the wrapped env and the health trackers, then prime the frame stack."""
        observation = self.env.reset()
        self.prev_player_health = self.full_hp
        self.prev_oppont_health = self.full_hp

        # Start from an empty buffer filled with [num_frames] copies of the first frame.
        self.frame_stack.clear()
        first_frame = observation[::2, ::2, :]
        for _ in range(self.num_frames):
            self.frame_stack.append(first_frame)

        return self._stack_frames()

    def step(self, action):
        """Advance one step and return ``(stacked_obs, custom_reward, custom_done, info)``.

        The wrapped environment's own reward and done flag are ignored.
        """
        obs, _reward, _done, info = self.env.step(action)
        player_hp = info['health']
        oppont_hp = info['enemy_health']

        if player_hp < 0:
            # Player lost: the opponent's remaining health is the penalty.
            # If the opponent's health is negative too, this comes out positive.
            custom_reward = -oppont_hp
        elif oppont_hp < 0:
            # Player won: remaining health scaled by reward_coeff, so winning
            # outweighs the losing penalty and discourages passive play.
            custom_reward = player_hp * self.reward_coeff
        else:
            # Round still in progress: weighted damage dealt minus damage taken.
            custom_reward = self.reward_coeff * (self.prev_oppont_health - oppont_hp) - (self.prev_player_health - player_hp)
            self.prev_player_health = player_hp
            self.prev_oppont_health = oppont_hp

        # The round is over once either health is negative; in testing mode
        # the session is always kept going.
        # NOTE(review): the tracked healths are not refreshed on a finished
        # round, which matters only when testing mode keeps stepping.
        round_over = bool(player_hp < 0 or oppont_hp < 0)
        custom_done = round_over and not self.testing

        # Max reward is 6 * full_hp = 1056 (damage * 3 + winning_reward * 3)
        return self._preprocess_observation(obs), custom_reward, custom_done, info
|
||||||
|
|
76
004_image_stack_ram_based_reward_custom/test.py
Normal file
76
004_image_stack_ram_based_reward_custom/test.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import retro
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
|
||||||
|
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||||
|
|
||||||
|
def make_env(game, state):
    """Return a zero-argument factory that builds the wrapped retro environment.

    The factory creates the retro environment for ``game`` / ``state`` with
    filtered actions and image observations, then wraps it in
    ``StreetFighterCustomWrapper``.
    """
    def _init():
        retro_env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(retro_env)

    return _init
|
||||||
|
|
||||||
|
# Run a saved PPO model for a number of episodes and report its average reward.
game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.RyuVsGuile",
    "Champion.Level1.ChunLiVsGuile", # Average reward for random strategy: -102.3 | -20.4
    "ChampionX.Level1.ChunLiVsKen", # Average reward for random strategy: -247.6
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    "Champion.Level5.ChunLiVsDhalsim",
    "Champion.Level6.ChunLiVsRyu",
    "Champion.Level7.ChunLiVsEHonda",
    "Champion.Level8.ChunLiVsBlanka",
    "Champion.Level9.ChunLiVsBalrog",
    "Champion.Level10.ChunLiVsVega",
    "Champion.Level11.ChunLiVsSagat",
    "Champion.Level12.ChunLiVsBison"
    # Add other stages as necessary
]

env = make_env(game, state_stages[0])()

model_path = r"trained_models_level_1/ppo_ryu_000000_steps"
# ``load`` is a classmethod that returns a new model; calling it on an
# existing instance and discarding the result leaves that instance untrained.
model = PPO.load(model_path, env=env)
# Average reward for optuna/trial_1_best_model: -82.3
# Average reward for optuna/trial_9_best_model: 36.7 | -86.23
# Average reward for trained_models/ppo_chunli_5376000_steps: -77.8

num_episodes = 30
episode_reward_sum = 0
try:
    for _ in range(num_episodes):
        done = False
        obs = env.reset()
        total_reward = 0
        while not done:
            action, _states = model.predict(obs)
            obs, reward, done, info = env.step(action)

            if reward != 0:
                total_reward += reward
                print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
            env.render()
            # Brief pause so the rendered game is watchable.
            time.sleep(0.01)
        print("Total reward: {}".format(total_reward))
        episode_reward_sum += total_reward
finally:
    # Release the environment even when an episode is interrupted or raises.
    env.close()

print("Average reward for {}: {}".format(model_path, episode_reward_sum/num_episodes))
|
154
004_image_stack_ram_based_reward_custom/train.py
Normal file
154
004_image_stack_ram_based_reward_custom/train.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import retro
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.monitor import Monitor
|
||||||
|
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
|
||||||
|
|
||||||
|
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||||
|
|
||||||
|
LOG_DIR = 'logs/'
|
||||||
|
os.makedirs(LOG_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
class RandomOpponentChangeCallback(BaseCallback):
    """Callback that switches the training env to a random state periodically.

    Args:
        stages: Sequence of state names to choose from.
        opponent_interval: Number of callback calls between two switches.
        verbose: Passed through to ``BaseCallback``.
    """

    def __init__(self, stages, opponent_interval, verbose=0):
        super(RandomOpponentChangeCallback, self).__init__(verbose)
        self.opponent_interval = opponent_interval
        self.stages = stages

    def _on_step(self) -> bool:
        """Load a random state every ``opponent_interval`` calls; always return True."""
        if self.n_calls % self.opponent_interval != 0:
            return True

        chosen_state = random.choice(self.stages)
        print("\nCurrent state:", chosen_state)
        # ``load_state`` is invoked on every sub-environment (indices=None).
        # NOTE(review): presumably resolved on the underlying retro env -- confirm.
        self.training_env.env_method("load_state", chosen_state, indices=None)
        return True
|
||||||
|
|
||||||
|
# class StageIncreaseCallback(BaseCallback):
|
||||||
|
# def __init__(self, stages, stage_interval, save_dir, verbose=0):
|
||||||
|
# super(StageIncreaseCallback, self).__init__(verbose)
|
||||||
|
# self.stages = stages
|
||||||
|
# self.stage_interval = stage_interval
|
||||||
|
# self.save_dir = save_dir
|
||||||
|
# self.current_stage = 0
|
||||||
|
|
||||||
|
# def _on_step(self) -> bool:
|
||||||
|
# if self.n_calls % self.stage_interval == 0 and self.current_stage < len(self.stages) - 1:
|
||||||
|
# self.current_stage += 1
|
||||||
|
# new_state = self.stages[self.current_stage]
|
||||||
|
# self.training_env.env_method("load_state", new_state, indices=None)
|
||||||
|
# self.model.save(os.path.join(self.save_dir, f"ppo_chunli_stage_{self.current_stage}.zip"))
|
||||||
|
# return True
|
||||||
|
|
||||||
|
def make_env(game, state):
    """Return a zero-argument factory that builds the wrapped retro environment.

    The factory creates the retro environment for ``game`` / ``state`` with
    filtered actions and image observations, then wraps it in
    ``StreetFighterCustomWrapper``.
    """
    def _init():
        retro_env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        return StreetFighterCustomWrapper(retro_env)

    return _init
|
||||||
|
|
||||||
|
def main():
    """Train a PPO agent on the first listed stage.

    Checkpoints are written to the save directory during training and the
    final model is saved there once training completes.
    """
    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"

    state_stages = [
        "Champion.Level1.RyuVsGuile",
        "Champion.Level1.ChunLiVsGuile", # Average reward for random strategy: -102.3
        "Champion.Level2.ChunLiVsKen",
        "Champion.Level3.ChunLiVsChunLi",
        "Champion.Level4.ChunLiVsZangief",
        "Champion.Level5.ChunLiVsDhalsim",
        "Champion.Level6.ChunLiVsRyu",
        "Champion.Level7.ChunLiVsEHonda",
        "Champion.Level8.ChunLiVsBlanka",
        "Champion.Level9.ChunLiVsBalrog",
        "Champion.Level10.ChunLiVsVega",
        "Champion.Level11.ChunLiVsSagat",
        "Champion.Level12.ChunLiVsBison"
        # Add other stages as necessary
    ]

    # state_stages = [
    #     "ChampionX.Level1.ChunLiVsKen", # Average reward for random strategy: -247.6
    #     "ChampionX.Level2.ChunLiVsChunLi",
    #     "ChampionX.Level3.ChunLiVsZangief",
    #     "ChampionX.Level4.ChunLiVsDhalsim",
    #     "ChampionX.Level5.ChunLiVsRyu",
    #     "ChampionX.Level6.ChunLiVsEHonda",
    #     "ChampionX.Level7.ChunLiVsBlanka",
    #     "ChampionX.Level8.ChunLiVsGuile",
    #     "ChampionX.Level9.ChunLiVsBalrog",
    #     "ChampionX.Level10.ChunLiVsVega",
    #     "ChampionX.Level11.ChunLiVsSagat",
    #     "ChampionX.Level12.ChunLiVsBison"
    #     # Add other stages as necessary
    # ]
    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.

    # Only the first stage is used while the opponent-change callback is disabled.
    env = make_env(game, state_stages[0])()

    # Wrap env in Monitor wrapper to record training progress
    env = Monitor(env, LOG_DIR)

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        verbose=1,
        n_steps=1024,
        batch_size=64,
        learning_rate=1e-4,
        ent_coef=0.01,
        clip_range=0.2,
        gamma=0.95,
        gae_lambda=0.81322,
        # Same directory as the Monitor log; use the shared constant.
        tensorboard_log=LOG_DIR
    )

    # Set the save directory
    save_dir = "trained_models_ryu_level_1_reward_x3"
    os.makedirs(save_dir, exist_ok=True)

    # Load the model from file
    # model_path = "trained_models/ppo_chunli_1296000_steps.zip"

    # Load model and modify the learning rate and entropy coefficient
    # custom_objects = {
    #     "learning_rate": 0.0002
    # }
    # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)

    # Set up callbacks
    # opponent_interval = 35840 # stage_interval * num_envs = total_steps_per_stage
    checkpoint_interval = 200000 # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_ryu")
    # stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)

    # model_params = {
    #     'n_steps': 5,
    #     'gamma': 0.99,
    #     'gae_lambda':1,
    #     'learning_rate': 7e-4,
    #     'vf_coef': 0.5,
    #     'ent_coef': 0.0,
    #     'max_grad_norm':0.5,
    #     'rms_prop_eps':1e-05
    # }
    # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))

    try:
        model.learn(
            total_timesteps=int(10000000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
            callback=[checkpoint_callback]#, stage_increase_callback]
        )
    finally:
        # Release the environment even if training is interrupted or raises.
        env.close()

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_ryu_final.zip"))


if __name__ == "__main__":
    main()
|
Loading…
Reference in New Issue
Block a user