diff --git a/000_image_stack_ram_based_reward/.ipynb_checkpoints/street_fighter_notebook-checkpoint.ipynb b/000_image_stack_ram_based_reward_ai_generated/.ipynb_checkpoints/street_fighter_notebook-checkpoint.ipynb
similarity index 100%
rename from 000_image_stack_ram_based_reward/.ipynb_checkpoints/street_fighter_notebook-checkpoint.ipynb
rename to 000_image_stack_ram_based_reward_ai_generated/.ipynb_checkpoints/street_fighter_notebook-checkpoint.ipynb
diff --git a/000_image_stack_ram_based_reward/__pycache__/custom_cnn.cpython-38.pyc b/000_image_stack_ram_based_reward_ai_generated/__pycache__/custom_cnn.cpython-38.pyc
similarity index 100%
rename from 000_image_stack_ram_based_reward/__pycache__/custom_cnn.cpython-38.pyc
rename to 000_image_stack_ram_based_reward_ai_generated/__pycache__/custom_cnn.cpython-38.pyc
diff --git a/000_image_stack_ram_based_reward/__pycache__/rmsprop_optim.cpython-38.pyc b/000_image_stack_ram_based_reward_ai_generated/__pycache__/rmsprop_optim.cpython-38.pyc
similarity index 100%
rename from 000_image_stack_ram_based_reward/__pycache__/rmsprop_optim.cpython-38.pyc
rename to 000_image_stack_ram_based_reward_ai_generated/__pycache__/rmsprop_optim.cpython-38.pyc
diff --git a/000_image_stack_ram_based_reward/__pycache__/street_fighter_custom_wrapper.cpython-38.pyc b/000_image_stack_ram_based_reward_ai_generated/__pycache__/street_fighter_custom_wrapper.cpython-38.pyc
similarity index 100%
rename from 000_image_stack_ram_based_reward/__pycache__/street_fighter_custom_wrapper.cpython-38.pyc
rename to 000_image_stack_ram_based_reward_ai_generated/__pycache__/street_fighter_custom_wrapper.cpython-38.pyc
diff --git a/000_image_stack_ram_based_reward/check_reward.py b/000_image_stack_ram_based_reward_ai_generated/check_reward.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/check_reward.py
rename to 000_image_stack_ram_based_reward_ai_generated/check_reward.py
diff --git a/000_image_stack_ram_based_reward/custom_cnn.py b/000_image_stack_ram_based_reward_ai_generated/custom_cnn.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/custom_cnn.py
rename to 000_image_stack_ram_based_reward_ai_generated/custom_cnn.py
diff --git a/000_image_stack_ram_based_reward/evaluate.py b/000_image_stack_ram_based_reward_ai_generated/evaluate.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/evaluate.py
rename to 000_image_stack_ram_based_reward_ai_generated/evaluate.py
diff --git a/000_image_stack_ram_based_reward/logs/PPO_29/events.out.tfevents.1680204062.DESKTOP-9E17TO7.19212.0 b/000_image_stack_ram_based_reward_ai_generated/logs/PPO_29/events.out.tfevents.1680204062.DESKTOP-9E17TO7.19212.0
similarity index 100%
rename from 000_image_stack_ram_based_reward/logs/PPO_29/events.out.tfevents.1680204062.DESKTOP-9E17TO7.19212.0
rename to 000_image_stack_ram_based_reward_ai_generated/logs/PPO_29/events.out.tfevents.1680204062.DESKTOP-9E17TO7.19212.0
diff --git a/000_image_stack_ram_based_reward/logs/PPO_30/events.out.tfevents.1680229915.DESKTOP-9E17TO7.2720.0 b/000_image_stack_ram_based_reward_ai_generated/logs/PPO_30/events.out.tfevents.1680229915.DESKTOP-9E17TO7.2720.0
similarity index 100%
rename from 000_image_stack_ram_based_reward/logs/PPO_30/events.out.tfevents.1680229915.DESKTOP-9E17TO7.2720.0
rename to 000_image_stack_ram_based_reward_ai_generated/logs/PPO_30/events.out.tfevents.1680229915.DESKTOP-9E17TO7.2720.0
diff --git a/000_image_stack_ram_based_reward/logs/monitor.csv b/000_image_stack_ram_based_reward_ai_generated/logs/monitor.csv
similarity index 100%
rename from 000_image_stack_ram_based_reward/logs/monitor.csv
rename to 000_image_stack_ram_based_reward_ai_generated/logs/monitor.csv
diff --git a/000_image_stack_ram_based_reward/optuna/tuning_log.txt b/000_image_stack_ram_based_reward_ai_generated/optuna/tuning_log.txt
similarity index 100%
rename from 000_image_stack_ram_based_reward/optuna/tuning_log.txt
rename to 000_image_stack_ram_based_reward_ai_generated/optuna/tuning_log.txt
diff --git a/000_image_stack_ram_based_reward/rmsprop_optim.py b/000_image_stack_ram_based_reward_ai_generated/rmsprop_optim.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/rmsprop_optim.py
rename to 000_image_stack_ram_based_reward_ai_generated/rmsprop_optim.py
diff --git a/000_image_stack_ram_based_reward/street_fighter_custom_wrapper.py b/000_image_stack_ram_based_reward_ai_generated/street_fighter_custom_wrapper.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/street_fighter_custom_wrapper.py
rename to 000_image_stack_ram_based_reward_ai_generated/street_fighter_custom_wrapper.py
diff --git a/000_image_stack_ram_based_reward/street_fighter_notebook.ipynb b/000_image_stack_ram_based_reward_ai_generated/street_fighter_notebook.ipynb
similarity index 100%
rename from 000_image_stack_ram_based_reward/street_fighter_notebook.ipynb
rename to 000_image_stack_ram_based_reward_ai_generated/street_fighter_notebook.ipynb
diff --git a/000_image_stack_ram_based_reward/test.py b/000_image_stack_ram_based_reward_ai_generated/test.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/test.py
rename to 000_image_stack_ram_based_reward_ai_generated/test.py
diff --git a/000_image_stack_ram_based_reward/train.py b/000_image_stack_ram_based_reward_ai_generated/train.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/train.py
rename to 000_image_stack_ram_based_reward_ai_generated/train.py
diff --git a/000_image_stack_ram_based_reward/trained_models/training_logs.txt b/000_image_stack_ram_based_reward_ai_generated/trained_models/training_logs.txt
similarity index 100%
rename from 000_image_stack_ram_based_reward/trained_models/training_logs.txt
rename to 000_image_stack_ram_based_reward_ai_generated/trained_models/training_logs.txt
diff --git a/000_image_stack_ram_based_reward/trained_models_level_1/training_logs.txt b/000_image_stack_ram_based_reward_ai_generated/trained_models_level_1/training_logs.txt
similarity index 100%
rename from 000_image_stack_ram_based_reward/trained_models_level_1/training_logs.txt
rename to 000_image_stack_ram_based_reward_ai_generated/trained_models_level_1/training_logs.txt
diff --git a/000_image_stack_ram_based_reward/tune.py b/000_image_stack_ram_based_reward_ai_generated/tune.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/tune.py
rename to 000_image_stack_ram_based_reward_ai_generated/tune.py
diff --git a/000_image_stack_ram_based_reward/tune_ppo.py b/000_image_stack_ram_based_reward_ai_generated/tune_ppo.py
similarity index 100%
rename from 000_image_stack_ram_based_reward/tune_ppo.py
rename to 000_image_stack_ram_based_reward_ai_generated/tune_ppo.py
diff --git a/004_custom_policy/custom_cnn.py b/004_custom_policy/custom_cnn.py
deleted file mode 100644
index 9ef0a3b..0000000
--- a/004_custom_policy/custom_cnn.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import torch.nn as nn
-
-def conv2d_custom_init(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
-    conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
-    nn.init.xavier_uniform_(conv.weight)
-    return conv
-
-def custom_conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False):
-    return nn.Sequential(
-        conv2d_custom_init(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
-        nn.Relu(),
-        nn.MaxPool2d((2, 2))
-    )
-
-# Custom feature extractor (CNN)
-class CustomCNN(nn.Module):
-    def __init__(self, num_frames, num_moves, num_attacks):
-        super(CustomCNN, self).__init__()
-        self.num_moves = num_moves
-        self.num_attacks = num_attacks
-        self.cnn = nn.Sequential(
-            nn.Conv2d(4, 32, kernel_size=5, stride=2, padding=0),
-            nn.ReLU(),
-            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
-            nn.ReLU(),
-            nn.Flatten(),
-            nn.Linear(16384, self.features_dim),
-            nn.ReLU()
-        )
-
-    def forward(self, observations: torch.Tensor) -> torch.Tensor:
-        return self.cnn(observations)
-
\ No newline at end of file
diff --git a/004_image_stack_ram_based_reward_custom/__pycache__/custom_cnn.cpython-38.pyc b/004_image_stack_ram_based_reward_custom/__pycache__/custom_cnn.cpython-38.pyc
new file mode 100644
index 0000000..8e0f165
Binary files /dev/null and b/004_image_stack_ram_based_reward_custom/__pycache__/custom_cnn.cpython-38.pyc differ
diff --git a/004_image_stack_ram_based_reward_custom/__pycache__/street_fighter_custom_wrapper.cpython-38.pyc b/004_image_stack_ram_based_reward_custom/__pycache__/street_fighter_custom_wrapper.cpython-38.pyc
new file mode 100644
index 0000000..d473d0a
Binary files /dev/null and b/004_image_stack_ram_based_reward_custom/__pycache__/street_fighter_custom_wrapper.cpython-38.pyc differ
diff --git a/004_image_stack_ram_based_reward_custom/check_reward.py b/004_image_stack_ram_based_reward_custom/check_reward.py
new file mode 100644
index 0000000..0821924
--- /dev/null
+++ b/004_image_stack_ram_based_reward_custom/check_reward.py
@@ -0,0 +1,47 @@
+import time
+
+import retro
+from stable_baselines3.common.monitor import Monitor
+
+from street_fighter_custom_wrapper import StreetFighterCustomWrapper
+
+def make_env(game, state):
+    def _init():
+        env = retro.make(
+            game=game,
+            state=state,
+            use_restricted_actions=retro.Actions.FILTERED,
+            obs_type=retro.Observations.IMAGE
+        )
+        env = StreetFighterCustomWrapper(env)
+        return env
+    return _init
+
+game = "StreetFighterIISpecialChampionEdition-Genesis"
+state = "Champion.Level1.RyuVsGuile"
+
+env = make_env(game, state)()
+env = Monitor(env, 'logs/')
+
+num_episodes = 30
+episode_reward_sum = 0
+for _ in range(num_episodes):
+    done = False
+    obs = env.reset()
+    total_reward = 0
+    while not done:
+        timestamp = time.time()
+        obs, reward, done, info = env.step(env.action_space.sample())
+
+        # Note that if the player wins but has 0 HP left, the winning reward is still 0, so it won't be printed.
+        if reward != 0:
+            total_reward += reward
+            print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
+        env.render()
+        # time.sleep(0.005)
+
+    print("Total reward: {}".format(total_reward))
+    episode_reward_sum += total_reward
+
+env.close()
+print("Average reward for random strategy: {}".format(episode_reward_sum/num_episodes))
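
Side note (not part of the patch): a 30-episode mean is a fairly noisy baseline estimate. A small variant of the loop above that also reports the spread, reusing the env and num_episodes already set up in check_reward.py; numpy is the only extra import:

    import numpy as np

    episode_rewards = []
    for _ in range(num_episodes):
        obs = env.reset()
        done = False
        total_reward = 0
        while not done:
            obs, reward, done, info = env.step(env.action_space.sample())
            total_reward += reward
        episode_rewards.append(total_reward)

    print("Random baseline: {:.1f} +/- {:.1f} over {} episodes".format(
        np.mean(episode_rewards), np.std(episode_rewards), len(episode_rewards)))
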
diff --git a/004_image_stack_ram_based_reward_custom/custom_cnn.py b/004_image_stack_ram_based_reward_custom/custom_cnn.py
new file mode 100644
index 0000000..25c50ea
--- /dev/null
+++ b/004_image_stack_ram_based_reward_custom/custom_cnn.py
@@ -0,0 +1,29 @@
+import gym
+import torch
+import torch.nn as nn
+from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+
+# Custom feature extractor (CNN)
+class CustomCNN(BaseFeaturesExtractor):
+    def __init__(self, observation_space: gym.Space):
+        super(CustomCNN, self).__init__(observation_space, features_dim=512)
+        # SB3 transposes image observations to channels-first, so the
+        # 3-frame stacked observation arrives as (3, 100, 128).
+        n_input_channels = observation_space.shape[0]
+        self.cnn = nn.Sequential(
+            nn.Conv2d(n_input_channels, 32, kernel_size=5, stride=2, padding=0),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
+            nn.ReLU(),
+            nn.Flatten()
+        )
+        # Infer the flattened size with a dummy forward pass instead of
+        # hard-coding it, so the linear layer always matches the input shape.
+        with torch.no_grad():
+            n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1]
+        self.linear = nn.Sequential(nn.Linear(n_flatten, self.features_dim), nn.ReLU())
+
+    def forward(self, observations: torch.Tensor) -> torch.Tensor:
+        return self.linear(self.cnn(observations))
\ No newline at end of file
diff --git a/004_image_stack_ram_based_reward_custom/evaluate.py b/004_image_stack_ram_based_reward_custom/evaluate.py
new file mode 100644
index 0000000..c435f08
--- /dev/null
+++ b/004_image_stack_ram_based_reward_custom/evaluate.py
@@ -0,0 +1,52 @@
+import retro
+
+from stable_baselines3 import PPO
+from stable_baselines3.common.vec_env import DummyVecEnv
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.evaluation import evaluate_policy
+
+from custom_cnn import CustomCNN
+from street_fighter_custom_wrapper import StreetFighterCustomWrapper
+
+def make_env(game, state):
+    def _init():
+        env = retro.make(
+            game=game,
+            state=state,
+            use_restricted_actions=retro.Actions.FILTERED,
+            obs_type=retro.Observations.IMAGE
+        )
+        env = StreetFighterCustomWrapper(env)
+        return env
+    return _init
+
+game = "StreetFighterIISpecialChampionEdition-Genesis"
+state_stages = [
+    "Champion.Level1.ChunLiVsGuile",
+    "Champion.Level2.ChunLiVsKen",
+    "Champion.Level3.ChunLiVsChunLi",
+    "Champion.Level4.ChunLiVsZangief",
+    "Champion.Level5.ChunLiVsDhalsim",
+    "Champion.Level6.ChunLiVsRyu",
+    "Champion.Level7.ChunLiVsEHonda",
+    "Champion.Level8.ChunLiVsBlanka",
+    "Champion.Level9.ChunLiVsBalrog",
+    "Champion.Level10.ChunLiVsVega",
+    "Champion.Level11.ChunLiVsSagat",
+    "Champion.Level12.ChunLiVsBison"
+    # Add other stages as necessary
+]
+
+env = make_env(game, state_stages[0])()
+
+# Wrap the environment
+# env = Monitor(env, 'logs/')
+
+policy_kwargs = {'features_extractor_class': CustomCNN}
+model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
+
+model = PPO.load(r"dummy_model_ppo_chunli")
+# model.load(r"trained_models/ppo_chunli_864000_steps")
+
+# Without return_episode_rewards=True, evaluate_policy returns (mean, std);
+# with it, it returns per-episode lists, which would break the print below.
+mean_reward, std_reward = evaluate_policy(model, env, render=True, n_eval_episodes=10, deterministic=False)
+print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")
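
Side note (not part of the patch): a quick shape sanity check for CustomCNN. The channels-first space SB3 hands the extractor is assumed to be (3, 100, 128), matching the wrapper's (100, 128, 3) observations after transposition:

    import gym
    import numpy as np
    import torch

    from custom_cnn import CustomCNN

    # Hypothetical channels-first space, as SB3's VecTransposeImage would produce.
    space = gym.spaces.Box(low=0, high=255, shape=(3, 100, 128), dtype=np.uint8)
    extractor = CustomCNN(space)
    features = extractor(torch.as_tensor(space.sample()[None]).float())
    assert features.shape == (1, 512)  # features_dim declared in __init__
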
diff --git a/004_image_stack_ram_based_reward_custom/logs/PPO_1/events.out.tfevents.1680427238.DESKTOP-9E17TO7.27420.0 b/004_image_stack_ram_based_reward_custom/logs/PPO_1/events.out.tfevents.1680427238.DESKTOP-9E17TO7.27420.0
new file mode 100644
index 0000000..eaa22b1
Binary files /dev/null and b/004_image_stack_ram_based_reward_custom/logs/PPO_1/events.out.tfevents.1680427238.DESKTOP-9E17TO7.27420.0 differ
diff --git a/004_image_stack_ram_based_reward_custom/logs/PPO_2/events.out.tfevents.1680442574.DESKTOP-9E17TO7.8472.0 b/004_image_stack_ram_based_reward_custom/logs/PPO_2/events.out.tfevents.1680442574.DESKTOP-9E17TO7.8472.0
new file mode 100644
index 0000000..cbd9cc2
Binary files /dev/null and b/004_image_stack_ram_based_reward_custom/logs/PPO_2/events.out.tfevents.1680442574.DESKTOP-9E17TO7.8472.0 differ
diff --git a/004_image_stack_ram_based_reward_custom/logs/PPO_3/events.out.tfevents.1680450538.DESKTOP-9E17TO7.4520.0 b/004_image_stack_ram_based_reward_custom/logs/PPO_3/events.out.tfevents.1680450538.DESKTOP-9E17TO7.4520.0
new file mode 100644
index 0000000..0d60db2
Binary files /dev/null and b/004_image_stack_ram_based_reward_custom/logs/PPO_3/events.out.tfevents.1680450538.DESKTOP-9E17TO7.4520.0 differ
diff --git a/004_image_stack_ram_based_reward_custom/logs/monitor.csv b/004_image_stack_ram_based_reward_custom/logs/monitor.csv
new file mode 100644
index 0000000..680323f
--- /dev/null
+++ b/004_image_stack_ram_based_reward_custom/logs/monitor.csv
@@ -0,0 +1,203 @@
+#{"t_start": 1680450538.0135336, "env_id": null}
+r,l,t
+140,2721,9.666247
+475,2051,16.014586
+659,1841,22.078956
+767,1374,25.788732
+253,2591,35.3779
+337,2369,42.159742
+280,2490,51.242315
+316,2761,59.177625
+498,1532,65.235702
+794,1294,68.963929
+758,1295,72.564837
+343,1862,78.748257
+366,1293,82.443477
+344,3356,95.206057
+655,1716,99.877595
+717,1496,105.722482
+410,2494,113.515226
+906,1496,119.400876
+230,1769,124.036038
+966,1043,127.542873
+321,2284,136.329181
+812,1519,140.423345
+519,1729,146.851307
+722,1412,150.934866
+700,1646,157.093265
+387,1965,163.827162
+343,2604,171.611118
+123,1625,177.426801
+417,3407,187.666056
+187,2616,196.882013
+26,2142,203.498234
+1056,1218,207.239627
+927,1124,210.857355
+47,1994,217.59863
+510,1774,223.863903
+377,1381,227.737895
+481,1390,233.374429
+658,1447,237.396851
+349,3437,248.157832
+718,1829,254.408287
+180,2025,261.068938
+313,1839,267.477898
+902,975,270.676313
+292,2950,280.273983
+182,2268,287.275446
+187,1953,293.603531
+362,3130,303.389702
+638,1779,309.646411
+228,2179,316.57769
+341,2109,323.229764
+85,2241,330.295613
+260,3871,343.38057
+-72,2846,353.025726
+319,2520,360.176952
+121,2334,367.161462
+-143,2696,376.380194
+-10,2199,383.130368
+254,3800,395.603224
+51,3540,406.048309
+340,2647,415.170668
+331,6905,437.351264
+741,1698,441.462111
+331,3519,454.219213
+489,2276,461.541512
+368,3599,472.906958
+-126,3151,482.855691
+-235,3220,494.629744
+-17,2824,502.354869
+-82,3238,514.654498
+-160,3254,525.209512
+-150,2923,535.200898
+-93,7337,558.950392
+-275,2822,566.908611
+245,3758,579.955454
+-117,2586,589.282403
+-230,2788,597.257352
+90,2896,607.108898
+967,1681,613.189581
+-162,3178,623.550789
+-111,2603,632.824294
+-111,2160,639.682543
+-91,3059,649.811759
+552,3472,660.344764
+156,2905,669.981916
+506,1973,676.494263
+-29,3483,687.2097
+-342,2388,694.675364
+-98,7522,720.510979
+268,3198,730.537025
+123,7866,755.786125
+41,3279,766.078602
+454,3852,778.612931
+241,3330,788.852483
+-197,2695,796.350531
+16,2708,805.798072
+-276,2921,815.975549
+230,7838,842.84947
+-343,2955,852.604732
+-124,2710,860.170467
+-318,1477,866.041014
+-37,3970,877.800582
+-26,2989,888.043861
+732,8207,914.516506
+326,3080,924.643157
+-189,2372,933.448068
+275,8785,961.099634
+-189,2881,970.522553
+108,3190,980.352367
+-351,2851,989.817974
+-287,3287,1000.238004
+-262,2847,1010.036809
+-206,7848,1034.80143
+102,7532,1058.580046
+0,8037,1084.80434
+110,4073,1097.11781
+421,3413,1107.327183
+-203,3154,1117.49272
+655,14205,1161.34724
+126,3993,1174.695066
+48,3832,1187.983629
+68,2995,1197.869093
+80,3252,1207.617986
+84,3776,1219.716212
+-192,3176,1229.660275
+-143,2819,1237.98842
+10,2730,1247.57504
+-191,2460,1256.580759
+-28,2546,1263.612297
+192,3534,1273.532951
+268,3797,1285.940749
+-98,3139,1296.872381
+75,3568,1309.84493
+-123,7274,1332.694059
+326,3440,1342.804182
+349,3737,1355.733267
+22,2943,1366.020118
+-202,3018,1375.968116
+888,1928,1382.114744
+-209,1646,1386.164465
+46,1613,1391.748134
+-318,2434,1398.519993
+-275,2288,1405.521706
+397,3578,1418.447945
+317,2150,1425.714679
+-75,2716,1435.669649
+-93,2679,1442.908609
+564,2987,1452.369631
+216,2904,1461.84269
+44,2300,1468.801676
+401,1470,1474.693752
+381,3590,1485.786877
+256,2522,1494.610591
+-141,1773,1498.837729
+335,2651,1507.840033
+860,1561,1511.859426
+357,1743,1517.856122
+846,1433,1523.547
+702,720,1524.743536
+81,3314,1535.593991
+608,1468,1541.5119
+464,2507,1549.079192
+382,1465,1554.894497
+661,2153,1561.58434
+-220,2172,1568.434792
+470,2597,1577.624233
+606,1471,1581.617123
+128,2485,1589.245833
+-151,2076,1596.189308
+-34,1775,1602.548944
+7,2518,1611.704006
+-73,1256,1615.417475
+981,952,1618.628284
+537,1555,1622.779646
+336,2464,1631.451718
+490,2070,1638.048966
+337,3439,1648.578499
+367,2505,1657.904252
+365,2554,1665.472756
+654,1061,1669.007638
+334,3193,1679.491146
+-125,1751,1685.74332
+342,2740,1695.388833
+541,1674,1699.803759
+303,3218,1709.898359
+62,2140,1716.650506
+37,1838,1722.973549
+-9,2999,1732.744556
+-47,2898,1742.188218
+462,3518,1752.401364
+206,2255,1760.763199
+494,2294,1767.674661
+198,2530,1775.444748
+149,2196,1782.305408
+593,2317,1791.355
+349,2208,1798.589246
+-74,1673,1804.620136
+41,2712,1811.981201
+432,2759,1821.619351
+75,2880,1831.46073
+397,1858,1837.717627
+-204,3008,1845.908291
diff --git a/004_image_stack_ram_based_reward_custom/street_fighter_custom_wrapper.py b/004_image_stack_ram_based_reward_custom/street_fighter_custom_wrapper.py
new file mode 100644
index 0000000..3b3e622
--- /dev/null
+++ b/004_image_stack_ram_based_reward_custom/street_fighter_custom_wrapper.py
@@ -0,0 +1,79 @@
+import collections
+
+import gym
+import numpy as np
+
+# Custom environment wrapper
+class StreetFighterCustomWrapper(gym.Wrapper):
+    def __init__(self, env, testing=False):
+        super(StreetFighterCustomWrapper, self).__init__(env)
+        self.env = env
+
+        # Use a deque to store the last 3 frames
+        self.num_frames = 3
+        self.frame_stack = collections.deque(maxlen=self.num_frames)
+
+        self.reward_coeff = 3
+
+        self.full_hp = 176
+        self.prev_player_health = self.full_hp
+        self.prev_opponent_health = self.full_hp
+
+        # Update observation space to the stacked, downsampled frames (one color channel per frame)
+        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(100, 128, 3), dtype=np.uint8)
+
+        self.testing = testing
+
+    def _preprocess_observation(self, observation):
+
+        # Stack the downsampled frames.
+        self.frame_stack.append(observation[::2, ::2, :])
+
+        # Take color channel i of frame i, so the three frames pack into one 3-channel "image".
+        stacked_image = np.stack([frame[:, :, i] for i, frame in enumerate(self.frame_stack)], axis=-1)
+        return stacked_image
+
+    def reset(self):
+        observation = self.env.reset()
+        self.prev_player_health = self.full_hp
+        self.prev_opponent_health = self.full_hp
+
+        # Clear the frame stack and add the first observation [num_frames] times
+        self.frame_stack.clear()
+        for _ in range(self.num_frames):
+            self.frame_stack.append(observation[::2, ::2, :])
+
+        return np.stack([frame[:, :, i] for i, frame in enumerate(self.frame_stack)], axis=-1)
+
+    def step(self, action):
+
+        obs, _reward, _done, info = self.env.step(action)
+        curr_player_health = info['health']
+        curr_opponent_health = info['enemy_health']
+
+        # Game is over and player loses.
+        if curr_player_health < 0:
+            custom_reward = -curr_opponent_health    # Use the remaining health points of the opponent as penalty.
+            # If the opponent also has negative health points, it's an even game and the reward is +1.
+            custom_done = True
+
+        # Game is over and player wins.
+        elif curr_opponent_health < 0:
+            custom_reward = curr_player_health * self.reward_coeff    # Use the remaining health points of the player as reward.
+            # Multiply by reward_coeff to make the reward larger than the penalty, to avoid cowardice of the agent.
+            custom_done = True
+
+        # While the fighting is still going on.
+        else:
+            custom_reward = self.reward_coeff * (self.prev_opponent_health - curr_opponent_health) - (self.prev_player_health - curr_player_health)
+            self.prev_player_health = curr_player_health
+            self.prev_opponent_health = curr_opponent_health
+            custom_done = False
+
+        # During testing, the session should always keep going.
+        if self.testing:
+            custom_done = False
+
+        # Max reward is 6 * full_hp = 1056 (damage * 3 + winning reward * 3)
+        return self._preprocess_observation(obs), custom_reward, custom_done, info
\ No newline at end of file
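
Side note (not part of the patch): the 1056 bound in the wrapper's final comment follows from reward_coeff * full_hp = 3 * 176 = 528 of damage reward accumulated over a round, plus up to another 528 for winning at full health. A standalone sketch of the same arithmetic, with hypothetical health values:

    reward_coeff = 3
    full_hp = 176

    def shaped_reward(prev_p, prev_o, curr_p, curr_o):
        # Mirrors the three branches of StreetFighterCustomWrapper.step.
        if curr_p < 0:
            return -curr_o                 # lose: opponent's remaining HP as penalty
        if curr_o < 0:
            return curr_p * reward_coeff   # win: own remaining HP times 3
        return reward_coeff * (prev_o - curr_o) - (prev_p - curr_p)

    # Flawless win: final blow pays full_hp * reward_coeff = 528, on top of
    # the ~528 already earned from damage while the opponent went 176 -> 0.
    assert shaped_reward(full_hp, 1, full_hp, -1) == full_hp * reward_coeff
    print(shaped_reward(176, 176, 176, 150))  # 78: dealt 26 damage, took none
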
diff --git a/004_image_stack_ram_based_reward_custom/test.py b/004_image_stack_ram_based_reward_custom/test.py
new file mode 100644
index 0000000..7ad01a4
--- /dev/null
+++ b/004_image_stack_ram_based_reward_custom/test.py
@@ -0,0 +1,72 @@
+import time
+
+import retro
+from stable_baselines3 import PPO
+
+from street_fighter_custom_wrapper import StreetFighterCustomWrapper
+
+def make_env(game, state):
+    def _init():
+        env = retro.make(
+            game=game,
+            state=state,
+            use_restricted_actions=retro.Actions.FILTERED,
+            obs_type=retro.Observations.IMAGE
+        )
+        env = StreetFighterCustomWrapper(env)
+        return env
+    return _init
+
+game = "StreetFighterIISpecialChampionEdition-Genesis"
+state_stages = [
+    "Champion.Level1.RyuVsGuile",
+    "Champion.Level1.ChunLiVsGuile",   # Average reward for random strategy: -102.3 | -20.4
+    "ChampionX.Level1.ChunLiVsKen",    # Average reward for random strategy: -247.6
+    "Champion.Level2.ChunLiVsKen",
+    "Champion.Level3.ChunLiVsChunLi",
+    "Champion.Level4.ChunLiVsZangief",
+    "Champion.Level5.ChunLiVsDhalsim",
+    "Champion.Level6.ChunLiVsRyu",
+    "Champion.Level7.ChunLiVsEHonda",
+    "Champion.Level8.ChunLiVsBlanka",
+    "Champion.Level9.ChunLiVsBalrog",
+    "Champion.Level10.ChunLiVsVega",
+    "Champion.Level11.ChunLiVsSagat",
+    "Champion.Level12.ChunLiVsBison"
+    # Add other stages as necessary
+]
+
+env = make_env(game, state_stages[0])()
+
+model_path = r"trained_models_level_1/ppo_ryu_000000_steps"
+# PPO.load is a class method that returns a new model; assign its result
+# (calling .load() on an existing instance does not update it in place).
+model = PPO.load(model_path, env=env)
+# Average reward for optuna/trial_1_best_model: -82.3
+# Average reward for optuna/trial_9_best_model: 36.7 | -86.23
+# Average reward for trained_models/ppo_chunli_5376000_steps: -77.8
+
+obs = env.reset()
+done = False
+
+num_episodes = 30
+episode_reward_sum = 0
+for _ in range(num_episodes):
+    done = False
+    obs = env.reset()
+    total_reward = 0
+    while not done:
+        timestamp = time.time()
+        action, _states = model.predict(obs)
+        obs, reward, done, info = env.step(action)
+
+        if reward != 0:
+            total_reward += reward
+            print("Reward: {}, playerHP: {}, enemyHP:{}".format(reward, info['health'], info['enemy_health']))
+        env.render()
+        time.sleep(0.01)
+    print("Total reward: {}".format(total_reward))
+    episode_reward_sum += total_reward
+
+env.close()
+print("Average reward for {}: {}".format(model_path, episode_reward_sum/num_episodes))
\ No newline at end of file
diff --git a/004_image_stack_ram_based_reward_custom/train.py b/004_image_stack_ram_based_reward_custom/train.py
new file mode 100644
index 0000000..510861e
--- /dev/null
+++ b/004_image_stack_ram_based_reward_custom/train.py
@@ -0,0 +1,154 @@
+import os
+import random
+
+import retro
+from stable_baselines3 import PPO
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
+
+from street_fighter_custom_wrapper import StreetFighterCustomWrapper
+
+LOG_DIR = 'logs/'
+os.makedirs(LOG_DIR, exist_ok=True)
+
+class RandomOpponentChangeCallback(BaseCallback):
+    def __init__(self, stages, opponent_interval, verbose=0):
+        super(RandomOpponentChangeCallback, self).__init__(verbose)
+        self.stages = stages
+        self.opponent_interval = opponent_interval
+
+    def _on_step(self) -> bool:
+        if self.n_calls % self.opponent_interval == 0:
+            new_state = random.choice(self.stages)
+            print("\nCurrent state:", new_state)
+            self.training_env.env_method("load_state", new_state, indices=None)
+        return True
+
+# class StageIncreaseCallback(BaseCallback):
+#     def __init__(self, stages, stage_interval, save_dir, verbose=0):
+#         super(StageIncreaseCallback, self).__init__(verbose)
+#         self.stages = stages
+#         self.stage_interval = stage_interval
+#         self.save_dir = save_dir
+#         self.current_stage = 0
+
+#     def _on_step(self) -> bool:
+#         if self.n_calls % self.stage_interval == 0 and self.current_stage < len(self.stages) - 1:
+#             self.current_stage += 1
+#             new_state = self.stages[self.current_stage]
+#             self.training_env.env_method("load_state", new_state, indices=None)
+#             self.model.save(os.path.join(self.save_dir, f"ppo_chunli_stage_{self.current_stage}.zip"))
+#         return True
+
+def make_env(game, state):
+    def _init():
+        env = retro.make(
+            game=game,
+            state=state,
+            use_restricted_actions=retro.Actions.FILTERED,
+            obs_type=retro.Observations.IMAGE
+        )
+        env = StreetFighterCustomWrapper(env)
+        return env
+    return _init
+
+def main():
+    # Set up the environment and model
+    game = "StreetFighterIISpecialChampionEdition-Genesis"
+
+    state_stages = [
+        "Champion.Level1.RyuVsGuile",
+        "Champion.Level1.ChunLiVsGuile",   # Average reward for random strategy: -102.3
+        "Champion.Level2.ChunLiVsKen",
+        "Champion.Level3.ChunLiVsChunLi",
+        "Champion.Level4.ChunLiVsZangief",
+        "Champion.Level5.ChunLiVsDhalsim",
+        "Champion.Level6.ChunLiVsRyu",
+        "Champion.Level7.ChunLiVsEHonda",
+        "Champion.Level8.ChunLiVsBlanka",
+        "Champion.Level9.ChunLiVsBalrog",
+        "Champion.Level10.ChunLiVsVega",
+        "Champion.Level11.ChunLiVsSagat",
+        "Champion.Level12.ChunLiVsBison"
+        # Add other stages as necessary
+    ]
+
+    # state_stages = [
+    #     "ChampionX.Level1.ChunLiVsKen",   # Average reward for random strategy: -247.6
+    #     "ChampionX.Level2.ChunLiVsChunLi",
+    #     "ChampionX.Level3.ChunLiVsZangief",
+    #     "ChampionX.Level4.ChunLiVsDhalsim",
+    #     "ChampionX.Level5.ChunLiVsRyu",
+    #     "ChampionX.Level6.ChunLiVsEHonda",
+    #     "ChampionX.Level7.ChunLiVsBlanka",
+    #     "ChampionX.Level8.ChunLiVsGuile",
+    #     "ChampionX.Level9.ChunLiVsBalrog",
+    #     "ChampionX.Level10.ChunLiVsVega",
+    #     "ChampionX.Level11.ChunLiVsSagat",
+    #     "ChampionX.Level12.ChunLiVsBison"
+    #     # Add other stages as necessary
+    # ]
+    # Champion is at difficulty level 4, ChampionX is at difficulty level 8.
+
+    env = make_env(game, state_stages[0])()
+
+    # Wrap env in Monitor wrapper to record training progress
+    env = Monitor(env, LOG_DIR)
+
+    model = PPO(
+        "CnnPolicy",
+        env,
+        device="cuda",
+        verbose=1,
+        n_steps=1024,
+        batch_size=64,
+        learning_rate=1e-4,
+        ent_coef=0.01,
+        clip_range=0.2,
+        gamma=0.95,
+        gae_lambda=0.81322,
+        tensorboard_log="logs/"
+    )
+
+    # Set the save directory
+    save_dir = "trained_models_ryu_level_1_reward_x3"
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Load the model from file
+    # model_path = "trained_models/ppo_chunli_1296000_steps.zip"
+
+    # Load model and modify the learning rate and entropy coefficient
+    # custom_objects = {
+    #     "learning_rate": 0.0002
+    # }
+    # model = PPO.load(model_path, env=env, device="cuda")#, custom_objects=custom_objects)
+
+    # Set up callbacks
+    # opponent_interval = 35840   # stage_interval * num_envs = total_steps_per_stage
+    checkpoint_interval = 200000   # checkpoint_interval * num_envs = total_steps_per_checkpoint (every 80 rounds)
+    checkpoint_callback = CheckpointCallback(save_freq=checkpoint_interval, save_path=save_dir, name_prefix="ppo_ryu")
+    # stage_increase_callback = RandomOpponentChangeCallback(state_stages, opponent_interval, save_dir)
+
+    # model_params = {
+    #     'n_steps': 5,
+    #     'gamma': 0.99,
+    #     'gae_lambda': 1,
+    #     'learning_rate': 7e-4,
+    #     'vf_coef': 0.5,
+    #     'ent_coef': 0.0,
+    #     'max_grad_norm': 0.5,
+    #     'rms_prop_eps': 1e-05
+    # }
+    # model = A2C('CnnPolicy', env, tensorboard_log='logs/', verbose=1, **model_params, policy_kwargs=dict(optimizer_class=RMSpropTF))
+
+    model.learn(
+        total_timesteps=int(10000000),   # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
+        callback=[checkpoint_callback]   #, stage_increase_callback]
+    )
+    env.close()
+
+    # Save the final model
+    model.save(os.path.join(save_dir, "ppo_sf2_ryu_final.zip"))
+
+if __name__ == "__main__":
+    main()