mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-03 22:50:43 +00:00
rgb image stack
This commit is contained in:
parent
16c80d5fba
commit
ded261ba69
Can't render this file because it contains an unexpected character in line 1 and column 3.
|
@ -1,35 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
def conv2d_custom_init(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
|
||||
conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
|
||||
nn.init.xavier_uniform_(conv.weight)
|
||||
return conv
|
||||
|
||||
def custom_conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
|
||||
return nn.Sequential(
|
||||
conv2d_custom_init(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
|
||||
nn.Relu(),
|
||||
nn.MaxPool2d((2, 2))
|
||||
)
|
||||
|
||||
# Custom feature extractor (CNN)
|
||||
class CustomCNN(nn.Module):
|
||||
def __init__(self, num_frames, num_moves, num_attacks):
|
||||
super(CustomCNN, self).__init__()
|
||||
self.num_moves = num_moves
|
||||
self.num_attacks = num_attacks
|
||||
self.cnn = nn.Sequential(
|
||||
nn.Conv2d(4, 32, kernel_size=5, stride=2, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Flatten(),
|
||||
nn.Linear(16384, self.features_dim),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
return self.cnn(observations)
|
||||
|
Binary file not shown.
Binary file not shown.
47
004_image_stack_ram_based_reward_custom/check_reward.py
Normal file
47
004_image_stack_ram_based_reward_custom/check_reward.py
Normal file
@ -0,0 +1,47 @@
|
||||
import time
|
||||
|
||||
import retro
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
|
||||
def _init():
|
||||
env = retro.make(
|
||||
game=game,
|
||||
state=state,
|
||||
use_restricted_actions=retro.Actions.FILTERED,
|
||||
obs_type=retro.Observations.IMAGE
|
||||
)
|
||||
env = StreetFighterCustomWrapper(env)
|
||||
return env
|
||||
return _init
|
||||
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
||||
state = "Champion.Level1.RyuVsGuile"
|
||||
|
||||
env = make_env(game, state)()
|
||||
env = Monitor(env, 'logs/')
|
||||
|
||||
num_episodes = 30
|
||||
episode_reward_sum = 0
|
||||
for _ in range(num_episodes):
|
||||
done = False
|
||||
obs = env.reset()
|
||||
total_reward = 0
|
||||
while not done:
|
||||
timestamp = time.time()
|
||||
obs, reward, done, info = env.step(env.action_space.sample())
|
||||
|
||||
# Note that if player wins but only has 0 HP left, the winning reward is still 0, so it won't be printed.
|
||||
if reward != 0:
|
||||
total_reward += reward
|
||||
print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
|
||||
env.render()
|
||||
# time.sleep(0.005)
|
||||
|
||||
print("Total reward: {}".format(total_reward))
|
||||
episode_reward_sum += total_reward
|
||||
|
||||
env.close()
|
||||
print("Average reward for random strategy: {}".format(episode_reward_sum/num_episodes))
|
24
004_image_stack_ram_based_reward_custom/custom_cnn.py
Normal file
24
004_image_stack_ram_based_reward_custom/custom_cnn.py
Normal file
@ -0,0 +1,24 @@
|
||||
import gym
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
# Custom feature extractor (CNN)
|
||||
class CustomCNN(BaseFeaturesExtractor):
|
||||
def __init__(self, observation_space: gym.Space):
|
||||
super(CustomCNN, self).__init__(observation_space, features_dim=512)
|
||||
self.cnn = nn.Sequential(
|
||||
nn.Conv2d(4, 32, kernel_size=5, stride=2, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
|
||||
nn.ReLU(),
|
||||
nn.Flatten(),
|
||||
nn.Linear(16384, self.features_dim),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
def forward(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
return self.cnn(observations)
|
||||
|
52
004_image_stack_ram_based_reward_custom/evaluate.py
Normal file
52
004_image_stack_ram_based_reward_custom/evaluate.py
Normal file
@ -0,0 +1,52 @@
|
||||
import retro
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
|
||||
def _init():
|
||||
env = retro.make(
|
||||
game=game,
|
||||
state=state,
|
||||
use_restricted_actions=retro.Actions.FILTERED,
|
||||
obs_type=retro.Observations.IMAGE
|
||||
)
|
||||
env = StreetFighterCustomWrapper(env)
|
||||
return env
|
||||
return _init
|
||||
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
||||
state_stages = [
|
||||
"Champion.Level1.ChunLiVsGuile",
|
||||
"Champion.Level2.ChunLiVsKen",
|
||||
"Champion.Level3.ChunLiVsChunLi",
|
||||
"Champion.Level4.ChunLiVsZangief",
|
||||
"Champion.Level5.ChunLiVsDhalsim",
|
||||
"Champion.Level6.ChunLiVsRyu",
|
||||
"Champion.Level7.ChunLiVsEHonda",
|
||||
"Champion.Level8.ChunLiVsBlanka",
|
||||
"Champion.Level9.ChunLiVsBalrog",
|
||||
"Champion.Level10.ChunLiVsVega",
|
||||
"Champion.Level11.ChunLiVsSagat",
|
||||
"Champion.Level12.ChunLiVsBison"
|
||||
# Add other stages as necessary
|
||||
]
|
||||
|
||||
env = make_env(game, state_stages[0])()
|
||||
|
||||
# Wrap the environment
|
||||
# env = Monitor(env, 'logs/')
|
||||
|
||||
policy_kwargs = {'features_extractor_class': CustomCNN}
|
||||
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
|
||||
|
||||
model = PPO.load(r"dummy_model_ppo_chunli")
|
||||
# model.load(r"trained_models/ppo_chunli_864000_steps")
|
||||
|
||||
mean_reward, std_reward = evaluate_policy(model, env, render=True, n_eval_episodes=10, deterministic=False, return_episode_rewards=True)
|
||||
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
203
004_image_stack_ram_based_reward_custom/logs/monitor.csv
Normal file
203
004_image_stack_ram_based_reward_custom/logs/monitor.csv
Normal file
@ -0,0 +1,203 @@
|
||||
#{"t_start": 1680450538.0135336, "env_id": null}
|
||||
r,l,t
|
||||
140,2721,9.666247
|
||||
475,2051,16.014586
|
||||
659,1841,22.078956
|
||||
767,1374,25.788732
|
||||
253,2591,35.3779
|
||||
337,2369,42.159742
|
||||
280,2490,51.242315
|
||||
316,2761,59.177625
|
||||
498,1532,65.235702
|
||||
794,1294,68.963929
|
||||
758,1295,72.564837
|
||||
343,1862,78.748257
|
||||
366,1293,82.443477
|
||||
344,3356,95.206057
|
||||
655,1716,99.877595
|
||||
717,1496,105.722482
|
||||
410,2494,113.515226
|
||||
906,1496,119.400876
|
||||
230,1769,124.036038
|
||||
966,1043,127.542873
|
||||
321,2284,136.329181
|
||||
812,1519,140.423345
|
||||
519,1729,146.851307
|
||||
722,1412,150.934866
|
||||
700,1646,157.093265
|
||||
387,1965,163.827162
|
||||
343,2604,171.611118
|
||||
123,1625,177.426801
|
||||
417,3407,187.666056
|
||||
187,2616,196.882013
|
||||
26,2142,203.498234
|
||||
1056,1218,207.239627
|
||||
927,1124,210.857355
|
||||
47,1994,217.59863
|
||||
510,1774,223.863903
|
||||
377,1381,227.737895
|
||||
481,1390,233.374429
|
||||
658,1447,237.396851
|
||||
349,3437,248.157832
|
||||
718,1829,254.408287
|
||||
180,2025,261.068938
|
||||
313,1839,267.477898
|
||||
902,975,270.676313
|
||||
292,2950,280.273983
|
||||
182,2268,287.275446
|
||||
187,1953,293.603531
|
||||
362,3130,303.389702
|
||||
638,1779,309.646411
|
||||
228,2179,316.57769
|
||||
341,2109,323.229764
|
||||
85,2241,330.295613
|
||||
260,3871,343.38057
|
||||
-72,2846,353.025726
|
||||
319,2520,360.176952
|
||||
121,2334,367.161462
|
||||
-143,2696,376.380194
|
||||
-10,2199,383.130368
|
||||
254,3800,395.603224
|
||||
51,3540,406.048309
|
||||
340,2647,415.170668
|
||||
331,6905,437.351264
|
||||
741,1698,441.462111
|
||||
331,3519,454.219213
|
||||
489,2276,461.541512
|
||||
368,3599,472.906958
|
||||
-126,3151,482.855691
|
||||
-235,3220,494.629744
|
||||
-17,2824,502.354869
|
||||
-82,3238,514.654498
|
||||
-160,3254,525.209512
|
||||
-150,2923,535.200898
|
||||
-93,7337,558.950392
|
||||
-275,2822,566.908611
|
||||
245,3758,579.955454
|
||||
-117,2586,589.282403
|
||||
-230,2788,597.257352
|
||||
90,2896,607.108898
|
||||
967,1681,613.189581
|
||||
-162,3178,623.550789
|
||||
-111,2603,632.824294
|
||||
-111,2160,639.682543
|
||||
-91,3059,649.811759
|
||||
552,3472,660.344764
|
||||
156,2905,669.981916
|
||||
506,1973,676.494263
|
||||
-29,3483,687.2097
|
||||
-342,2388,694.675364
|
||||
-98,7522,720.510979
|
||||
268,3198,730.537025
|
||||
123,7866,755.786125
|
||||
41,3279,766.078602
|
||||
454,3852,778.612931
|
||||
241,3330,788.852483
|
||||
-197,2695,796.350531
|
||||
16,2708,805.798072
|
||||
-276,2921,815.975549
|
||||
230,7838,842.84947
|
||||
-343,2955,852.604732
|
||||
-124,2710,860.170467
|
||||
-318,1477,866.041014
|
||||
-37,3970,877.800582
|
||||
-26,2989,888.043861
|
||||
732,8207,914.516506
|
||||
326,3080,924.643157
|
||||
-189,2372,933.448068
|
||||
275,8785,961.099634
|
||||
-189,2881,970.522553
|
||||
108,3190,980.352367
|
||||
-351,2851,989.817974
|
||||
-287,3287,1000.238004
|
||||
-262,2847,1010.036809
|
||||
-206,7848,1034.80143
|
||||
102,7532,1058.580046
|
||||
0,8037,1084.80434
|
||||
110,4073,1097.11781
|
||||
421,3413,1107.327183
|
||||
-203,3154,1117.49272
|
||||
655,14205,1161.34724
|
||||
126,3993,1174.695066
|
||||
48,3832,1187.983629
|
||||
68,2995,1197.869093
|
||||
80,3252,1207.617986
|
||||
84,3776,1219.716212
|
||||
-192,3176,1229.660275
|
||||
-143,2819,1237.98842
|
||||
10,2730,1247.57504
|
||||
-191,2460,1256.580759
|
||||
-28,2546,1263.612297
|
||||
192,3534,1273.532951
|
||||
268,3797,1285.940749
|
||||
-98,3139,1296.872381
|
||||
75,3568,1309.84493
|
||||
-123,7274,1332.694059
|
||||
326,3440,1342.804182
|
||||
349,3737,1355.733267
|
||||
22,2943,1366.020118
|
||||
-202,3018,1375.968116
|
||||
888,1928,1382.114744
|
||||
-209,1646,1386.164465
|
||||
46,1613,1391.748134
|
||||
-318,2434,1398.519993
|
||||
-275,2288,1405.521706
|
||||
397,3578,1418.447945
|
||||
317,2150,1425.714679
|
||||
-75,2716,1435.669649
|
||||
-93,2679,1442.908609
|
||||
564,2987,1452.369631
|
||||
216,2904,1461.84269
|
||||
44,2300,1468.801676
|
||||
401,1470,1474.693752
|
||||
381,3590,1485.786877
|
||||
256,2522,1494.610591
|
||||
-141,1773,1498.837729
|
||||
335,2651,1507.840033
|
||||
860,1561,1511.859426
|
||||
357,1743,1517.856122
|
||||
846,1433,1523.547
|
||||
702,720,1524.743536
|
||||
81,3314,1535.593991
|
||||
608,1468,1541.5119
|
||||
464,2507,1549.079192
|
||||
382,1465,1554.894497
|
||||
661,2153,1561.58434
|
||||
-220,2172,1568.434792
|
||||
470,2597,1577.624233
|
||||
606,1471,1581.617123
|
||||
128,2485,1589.245833
|
||||
-151,2076,1596.189308
|
||||
-34,1775,1602.548944
|
||||
7,2518,1611.704006
|
||||
-73,1256,1615.417475
|
||||
981,952,1618.628284
|
||||
537,1555,1622.779646
|
||||
336,2464,1631.451718
|
||||
490,2070,1638.048966
|
||||
337,3439,1648.578499
|
||||
367,2505,1657.904252
|
||||
365,2554,1665.472756
|
||||
654,1061,1669.007638
|
||||
334,3193,1679.491146
|
||||
-125,1751,1685.74332
|
||||
342,2740,1695.388833
|
||||
541,1674,1699.803759
|
||||
303,3218,1709.898359
|
||||
62,2140,1716.650506
|
||||
37,1838,1722.973549
|
||||
-9,2999,1732.744556
|
||||
-47,2898,1742.188218
|
||||
462,3518,1752.401364
|
||||
206,2255,1760.763199
|
||||
494,2294,1767.674661
|
||||
198,2530,1775.444748
|
||||
149,2196,1782.305408
|
||||
593,2317,1791.355
|
||||
349,2208,1798.589246
|
||||
-74,1673,1804.620136
|
||||
41,2712,1811.981201
|
||||
432,2759,1821.619351
|
||||
75,2880,1831.46073
|
||||
397,1858,1837.717627
|
||||
-204,3008,1845.908291
|
Can't render this file because it contains an unexpected character in line 1 and column 3.
|
@ -0,0 +1,79 @@
|
||||
import collections
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
# Custom environment wrapper
|
||||
class StreetFighterCustomWrapper(gym.Wrapper):
|
||||
def __init__(self, env, testing=False):
|
||||
super(StreetFighterCustomWrapper, self).__init__(env)
|
||||
self.env = env
|
||||
|
||||
# Use a deque to store the last 4 frames
|
||||
self.num_frames = 3
|
||||
self.frame_stack = collections.deque(maxlen=self.num_frames)
|
||||
|
||||
self.reward_coeff = 3
|
||||
|
||||
self.full_hp = 176
|
||||
self.prev_player_health = self.full_hp
|
||||
self.prev_oppont_health = self.full_hp
|
||||
|
||||
# Update observation space to include stacked grayscale images
|
||||
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(100, 128, 3), dtype=np.uint8)
|
||||
|
||||
self.testing = testing
|
||||
|
||||
def _preprocess_observation(self, observation):
|
||||
|
||||
# Stack the downsampled frames.
|
||||
self.frame_stack.append(observation[::2, ::2, :])
|
||||
|
||||
# Stack the R, G, B channel of each frame and return the "image".
|
||||
stacked_image = np.stack([frame[:, :, i] for i, frame in enumerate(self.frame_stack)], axis=-1)
|
||||
return stacked_image
|
||||
|
||||
def reset(self):
|
||||
observation = self.env.reset()
|
||||
self.prev_player_health = self.full_hp
|
||||
self.prev_oppont_health = self.full_hp
|
||||
|
||||
# Clear the frame stack and add the first observation [num_frames] times
|
||||
self.frame_stack.clear()
|
||||
for _ in range(self.num_frames):
|
||||
self.frame_stack.append(observation[::2, ::2, :])
|
||||
|
||||
return np.stack([frame[:, :, i] for i, frame in enumerate(self.frame_stack)], axis=-1)
|
||||
|
||||
def step(self, action):
|
||||
|
||||
obs, _reward, _done, info = self.env.step(action)
|
||||
curr_player_health = info['health']
|
||||
curr_oppont_health = info['enemy_health']
|
||||
|
||||
# Game is over and player loses.
|
||||
if curr_player_health < 0:
|
||||
custom_reward = -curr_oppont_health # Use the remaining health points of opponent as penalty.
|
||||
# If the opponent also has negative health points, it's a even game and the reward is +1.
|
||||
custom_done = True
|
||||
|
||||
# Game is over and player wins.
|
||||
elif curr_oppont_health < 0:
|
||||
custom_reward = curr_player_health * self.reward_coeff # Use the remaining health points of player as reward.
|
||||
# Multiply by reward_coeff to make the reward larger than the penalty to avoid cowardice of agent.
|
||||
custom_done = True
|
||||
|
||||
# While the fighting is still going on.
|
||||
else:
|
||||
custom_reward = self.reward_coeff * (self.prev_oppont_health - curr_oppont_health) - (self.prev_player_health - curr_player_health)
|
||||
self.prev_player_health = curr_player_health
|
||||
self.prev_oppont_health = curr_oppont_health
|
||||
custom_done = False
|
||||
|
||||
# During testing, the session should always keep going.
|
||||
if self.testing:
|
||||
custom_done = False
|
||||
|
||||
# Max reward is 6 * full_hp = 1054 (damage * 3 + winning_reward * 3)
|
||||
return self._preprocess_observation(obs), custom_reward, custom_done, info
|
||||
|
76
004_image_stack_ram_based_reward_custom/test.py
Normal file
76
004_image_stack_ram_based_reward_custom/test.py
Normal file
@ -0,0 +1,76 @@
|
||||
import time
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state):
|
||||
def _init():
|
||||
env = retro.make(
|
||||
game=game,
|
||||
state=state,
|
||||
use_restricted_actions=retro.Actions.FILTERED,
|
||||
obs_type=retro.Observations.IMAGE
|
||||
)
|
||||
env = StreetFighterCustomWrapper(env)
|
||||
return env
|
||||
return _init
|
||||
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
||||
state_stages = [
|
||||
"Champion.Level1.RyuVsGuile",
|
||||
"Champion.Level1.ChunLiVsGuile", # Average reward for random strategy: -102.3 | -20.4
|
||||
"ChampionX.Level1.ChunLiVsKen", # Average reward for random strategy: -247.6
|
||||
"Champion.Level2.ChunLiVsKen",
|
||||
"Champion.Level3.ChunLiVsChunLi",
|
||||
"Champion.Level4.ChunLiVsZangief",
|
||||
"Champion.Level5.ChunLiVsDhalsim",
|
||||
"Champion.Level6.ChunLiVsRyu",
|
||||
"Champion.Level7.ChunLiVsEHonda",
|
||||
"Champion.Level8.ChunLiVsBlanka",
|
||||
"Champion.Level9.ChunLiVsBalrog",
|
||||
"Champion.Level10.ChunLiVsVega",
|
||||
"Champion.Level11.ChunLiVsSagat",
|
||||
"Champion.Level12.ChunLiVsBison"
|
||||
# Add other stages as necessary
|
||||
]
|
||||
|
||||
env = make_env(game, state_stages[0])()
|
||||
|
||||
model = PPO(
|
||||
"CnnPolicy",
|
||||
env,
|
||||
verbose=1
|
||||
)
|
||||
model_path = r"trained_models_level_1/ppo_ryu_000000_steps"
|
||||
model.load(model_path)
|
||||
# Average reward for optuna/trial_1_best_model: -82.3
|
||||
# Average reward for optuna/trial_9_best_model: 36.7 | -86.23
|
||||
# Average reward for trained_models/ppo_chunli_5376000_steps: -77.8
|
||||
|
||||
|
||||
obs = env.reset()
|
||||
done = False
|
||||
|
||||
num_episodes = 30
|
||||
episode_reward_sum = 0
|
||||
for _ in range(num_episodes):
|
||||
done = False
|
||||
obs = env.reset()
|
||||
total_reward = 0
|
||||
while not done:
|
||||
timestamp = time.time()
|
||||
action, _states = model.predict(obs)
|
||||
obs, reward, done, info = env.step(action)
|
||||
|
||||
if reward != 0:
|
||||
total_reward += reward
|
||||
print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
|
||||
env.render()
|
||||
time.sleep(0.01)
|
||||
print("Total reward: {}".format(total_reward))
|
||||
episode_reward_sum += total_reward
|
||||
|
||||
env.close()
|
||||
print("Average reward for {}: {}".format(model_path, episode_reward_sum/num_episodes))
|
154
004_image_stack_ram_based_reward_custom/train.py
Normal file
154
004_image_stack_ram_based_reward_custom/train.py
Normal file
@ -0,0 +1,154 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import retro
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
|
||||
|
||||
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
|
||||
|
||||
LOG_DIR = 'logs/'
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
|
||||
class RandomOpponentChangeCallback(BaseCallback):
|
||||
def __init__(self, stages, opponent_interval, verbose=0):
|
||||
super(RandomOpponentChangeCallback, self).__init__(verbose)
|
||||
self.stages = stages
|
||||
self.opponent_interval = opponent_interval
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
if self.n_calls % self.opponent_interval == 0:
|
||||
new_state = random.choice(self.stages)
|
||||
print("\nCurrent state:", new_state)
|
||||
self.training_env.env_method("load_state", new_state, indices=None)
|
||||
return True
|
||||
|
||||
# class StageIncreaseCallback(BaseCallback):
|
||||
# def __init__(self, stages, stage_interval, save_dir, verbose=0):
|
||||
# super(StageIncreaseCallback, self).__init__(verbose)
|
||||
# self.stages = stages
|
||||
# self.stage_interval = stage_interval
|
||||
# self.save_dir = save_dir
|
||||
# self.current_stage = 0
|
||||
|
||||
# def _on_step(self) -> bool:
|
||||
# if self.n_calls % self.stage_interval == 0 and self.current_stage < len(self.stages) - 1:
|
||||
# self.current_stage += 1
|
||||
# new_state = self.stages[self.current_stage]
|
||||
# self.training_env.env_method("load_state", new_state, indices=None)
|
||||
# self.model.save(os.path.join(self.save_dir, f"ppo_chunli_stage_{self.current_stage}.zip"))
|
||||
# return True
|
||||
|
||||
def make_env(game, state):
|
||||
def _init():
|
||||
env = retro.make(
|
||||
game=game,
|
||||
state=state,
|
||||
use_restricted_actions=retro.Actions.FILTERED,
|
||||
obs_type=retro.Observations.IMAGE
|
||||
)
|
||||
env = StreetFighterCustomWrapper(env)
|
||||
return env
|
||||
return _init
|
||||
|
||||
def main():
|
||||
# Set up the environment and model
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
||||
|
||||
state_stages = [
|
||||
"Champion.Level1.RyuVsGuile",
|
||||
"Champion.Level1.ChunLiVsGuile", # Average reward for random strategy: -102.3
|
||||
"Champion.Level2.ChunLiVsKen",
|
||||
"Champion.Level3.ChunLiVsChunLi",
|
||||
"Champion.Level4.ChunLiVsZangief",
|
||||
"Champion.Level5.ChunLiVsDhalsim",
|
||||
"Champion.Level6.ChunLiVsRyu",
|
||||
"Champion.Level7.ChunLiVsEHonda",
|
||||
"Champion.Level8.ChunLiVsBlanka",
|
||||
"Champion.Level9.ChunLiVsBalrog",
|
||||
"Champion.Level10.ChunLiVsVega",
|
||||
"Champion.Level11.ChunLiVsSagat",
|
||||
"Champion.Level12.ChunLiVsBison"
|
||||
# Add other stages as necessary
|
||||
]
|
||||
|
||||
# state_stages = [
|
||||
# "ChampionX.Level1.ChunLiVsKen", # Average reward for random strategy: -247.6
|
||||
# "ChampionX.Level2.ChunLiVsChunLi",
|
||||
# "ChampionX.Level3.ChunLiVsZangief",
|
||||
# "ChampionX.Level4.ChunLiVsDhalsim",
|
||||
# "ChampionX.Level5.ChunLiVsRyu",
|
||||
# "ChampionX.Level6.ChunLiVsEHonda",
|
||||
# "ChampionX.Level7.ChunLiVsBlanka",
|
||||
# "ChampionX.Level8.ChunLiVsGuile",
|
||||
# "ChampionX.Level9.ChunLiVsBalrog",
|
||||
# "ChampionX.Level10.ChunLiVsVega",
|
||||
# "ChampionX.Level11.ChunLiVsSagat",
|
||||
# "ChampionX.Level12.ChunLiVsBison"
|
||||
# # Add other stages as necessary
|
||||
# ]
|
||||
# Champion is at difficulty level 4, ChampionX is at difficulty level 8.
|
||||
|
||||
env = make_env(game, state_stages[0])()
|
||||
|
||||
# Warp env in Monitor wrapper to record training progress
|
||||
env = Monitor(env, LOG_DIR)
|
||||
|
||||
model = PPO(
|
||||
"CnnPolicy",
|
||||
env,
|
||||
device="cuda",
|
||||
verbose=1,
|
||||
n_steps=1024,
|
||||
batch_size=64,
|
||||
learning_rate=1e-4,
|
||||
ent_coef=0.01,
|
||||
clip_range=0.2,
|
||||
gamma=0.95,
|
||||
gae_lambda=0.81322,
|
||||
tensorboard_log="logs/"
|
||||
)
|
||||
|
||||
# Set the save directory
|
||||
save_dir = "trained_models_ryu_level_1_reward_x3"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Load the model from file
|
||||
# model_path = "trained_models/ppo_chunli_1296000_steps.zip"
|
||||
|
||||
# Load model and modify the learning rate and entropy coefficient
|
||||
# custom_objects = {
|
||||
# "learning_rate": 0.0002
|
||||
# }
|
||||
# model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)
|
||||
|
||||
# Set up callbacks
|
||||
# opponent_interval = 35840 # stage_interval * num_envs = total_steps_per_stage
|
||||
checkpoint_interval = 200000 # checkpoint_interval * num_envs = total_steps_per_checkpoint (Every 80 rounds)
|
||||
checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_ryu")
|
||||
# stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)
|
||||
|
||||
# model_params = {
|
||||
# 'n_steps': 5,
|
||||
# 'gamma': 0.99,
|
||||
# 'gae_lambda':1,
|
||||
# 'learning_rate': 7e-4,
|
||||
# 'vf_coef': 0.5,
|
||||
# 'ent_coef': 0.0,
|
||||
# 'max_grad_norm':0.5,
|
||||
# 'rms_prop_eps':1e-05
|
||||
# }
|
||||
# model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))
|
||||
|
||||
model.learn(
|
||||
total_timesteps=int(10000000), # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
|
||||
callback=[checkpoint_callback]#, stage_increase_callback]
|
||||
)
|
||||
env.close()
|
||||
|
||||
# Save the final model
|
||||
model.save(os.path.join(save_dir, "ppo_sf2_ryu_final.zip"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue
Block a user