new training and testing scripts based on cv
BIN
__pycache__/custom_cnn.cpython-38.pyc
Normal file
BIN
__pycache__/custom_sf2_cv_env.cpython-38.pyc
Normal file
3
add_custom_rom.py
Normal file
@ -0,0 +1,3 @@
|
||||
import retro
|
||||
# Folder that is handed to gym-retro as an additional integration path.
# NOTE(review): machine-specific absolute path; adjust per workstation.
rom_directory = "C:/Users/unitec/Documents/AIProjects/street-fighter-ai"

# Register the folder with retro's integration lookup.
retro.data.Integrations.add_custom_path(rom_directory)
|
17
convert_image_to_grayscale.py
Normal file
@ -0,0 +1,17 @@
|
||||
import cv2
|
||||
|
||||
# Convert image to grayscale
|
||||
def convert_image_to_grayscale(input_img_path, output_img_path):
    """Convert an image file to grayscale and write the result to disk.

    Args:
        input_img_path: Path of the image to read.
        output_img_path: Path the grayscale image is written to.

    Raises:
        FileNotFoundError: If the input image cannot be read.
        OSError: If the grayscale image cannot be written.
    """
    # cv2.imread does not raise for a missing or unreadable file; it returns
    # None, which would otherwise surface as an opaque cv2.cvtColor error.
    img = cv2.imread(input_img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {input_img_path}")

    # imread yields channels in BGR order, hence COLOR_BGR2GRAY.
    gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # cv2.imwrite signals failure through its boolean return value.
    if not cv2.imwrite(output_img_path, gray_img):
        raise OSError(f"Could not write image: {output_img_path}")
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: convert the fight screenshot to its grayscale twin.
    convert_image_to_grayscale(
        input_img_path=r"images/sf2screen_fight.png",
        output_img_path=r"images/sf2screen_fight_gray.png",
    )
|
23
custom_cnn.py
Normal file
@ -0,0 +1,23 @@
|
||||
import gym
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
|
||||
# Custom feature extractor (CNN)
|
||||
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for single-channel image observations.

    Three convolution + ReLU stages are followed by a flatten and one linear
    + ReLU layer that produces the feature vector.

    Args:
        observation_space: Observation space forwarded to the base extractor.
        features_dim: Size of the produced feature vector. Defaults to 512,
            the value that used to be hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512):
        super().__init__(observation_space, features_dim=features_dim)
        self.cnn = nn.Sequential(
            # The first layer takes exactly one input channel (grayscale).
            nn.Conv2d(1, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            # 3136 = 64 * 7 * 7: the flattened conv output for an 84x84 input
            # (84 -> 20 -> 9 -> 7 through the three unpadded convolutions).
            nn.Linear(3136, self.features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Run the observations through the CNN and return the features."""
        return self.cnn(observations)
|
60
custom_sf2_cv_env.py
Normal file
@ -0,0 +1,60 @@
|
||||
import gym
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
# Custom environment wrapper
|
||||
class StreetFighterCustomWrapper(gym.Wrapper):
    """Gym wrapper deriving observation, reward and done flag from the screen.

    Observations are converted to grayscale and resized to 84x84x1. The
    reward and the termination flag ignore the values returned by the wrapped
    environment and are computed from the full-resolution grayscale frame.

    Args:
        env: Environment to wrap.
        win_template: Grayscale template matched to detect a win screen.
        lose_template: Grayscale template matched to detect a lose screen.
        threshold: Minimum normalised match score that counts as a detection.
    """

    def __init__(self, env, win_template, lose_template, threshold=0.65):
        super().__init__(env)
        self.win_template = win_template
        self.lose_template = lose_template
        self.threshold = threshold

        # Full-resolution grayscale copy of the most recent frame; written by
        # _preprocess_observation, read by the reward/termination helpers.
        self.game_screen_gray = None

        # NOTE(review): not read anywhere in this class; kept so the set of
        # instance attributes stays unchanged.
        self.prev_player_health = 1.0
        self.prev_opponent_health = 1.0

        # Update observation space to single-channel grayscale image.
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def _preprocess_observation(self, observation):
        """Cache the grayscale frame and return it resized to (84, 84, 1)."""
        # NOTE(review): the standalone test script converts the retro screen
        # with COLOR_RGB2GRAY; confirm which channel order is intended here.
        self.game_screen_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        resized_image = cv2.resize(
            self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA
        )
        return np.expand_dims(resized_image, axis=-1)

    def _check_game_over(self):
        """Return True when the win or lose template matches the cached frame."""
        for template in (self.win_template, self.lose_template):
            scores = cv2.matchTemplate(
                self.game_screen_gray, template, cv2.TM_CCOEFF_NORMED
            )
            if np.max(scores) >= self.threshold:
                return True
        return False

    def _get_reward(self):
        """Return player health fraction minus opponent health fraction.

        Each fraction is the share of pixels brighter than 129 inside a fixed
        screen region of the cached grayscale frame.
        """
        player_health_area = self.game_screen_gray[15:20, 32:120]
        opponent_health_area = self.game_screen_gray[15:20, 136:224]

        # Get health points using the number of pixels above 129.
        player_health = np.sum(player_health_area > 129) / player_health_area.size
        opponent_health = np.sum(opponent_health_area > 129) / opponent_health_area.size

        return player_health - opponent_health

    def reset(self):
        """Reset the wrapped environment and return the processed observation."""
        observation = self.env.reset()
        return self._preprocess_observation(observation)

    def step(self, action):
        """Step the wrapped environment and return the custom transition.

        Returns:
            Tuple of (processed observation, custom reward, custom done, info).
        """
        observation, _, _, info = self.env.step(action)

        # Refresh the cached frame first so reward and termination describe
        # the frame produced by this action rather than the previous one.
        processed_observation = self._preprocess_observation(observation)
        custom_reward = self._get_reward()
        custom_done = self._check_game_over()
        return processed_observation, custom_reward, custom_done, info
|
BIN
images/pattern_lose.png
Normal file
After Width: | Height: | Size: 363 B |
BIN
images/pattern_lose_gray.png
Normal file
After Width: | Height: | Size: 235 B |
BIN
images/pattern_wins.png
Normal file
After Width: | Height: | Size: 331 B |
BIN
images/pattern_wins_gray.png
Normal file
After Width: | Height: | Size: 292 B |
BIN
images/sf2screen_fight.png
Normal file
After Width: | Height: | Size: 25 KiB |
BIN
images/sf2screen_fight_gray.png
Normal file
After Width: | Height: | Size: 20 KiB |
BIN
images/sf2screen_fight_labeled.png
Normal file
After Width: | Height: | Size: 33 KiB |
BIN
images/sf2screen_win.png
Normal file
After Width: | Height: | Size: 23 KiB |
BIN
images/sf2screen_you_lose.png
Normal file
After Width: | Height: | Size: 25 KiB |
62
test_cv_sf2_ai.py
Normal file
@ -0,0 +1,62 @@
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import gym
|
||||
import retro
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from custom_cnn import CustomCNN
|
||||
from custom_sf2_cv_env import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state, seed=0):
    """Return a factory that builds a wrapped, seeded Street Fighter env.

    Args:
        game: Name of the retro game integration.
        state: Name of the saved state to start from.
        seed: Seed applied to the created environment.

    Returns:
        A zero-argument callable that creates the environment.

    Raises:
        FileNotFoundError: (from the factory) if a template image cannot be
            read.
    """
    def _init():
        template_paths = (
            'images/pattern_wins_gray.png',
            'images/pattern_lose_gray.png',
        )
        templates = []
        for path in template_paths:
            # cv2.imread returns None instead of raising when loading fails;
            # catch that here rather than inside cv2.matchTemplate later on.
            template = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if template is None:
                raise FileNotFoundError(f"Could not read template image: {path}")
            templates.append(template)
        win_template, lose_template = templates

        env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
        env = StreetFighterCustomWrapper(env, win_template, lose_template)
        env.seed(seed)
        return env
    return _init
|
||||
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
state_stages = [
    "Champion.Level1.ChunLiVsGuile",
    "Champion.Level2.ChunLiVsKen",
    "Champion.Level3.ChunLiVsChunLi",
    "Champion.Level4.ChunLiVsZangief",
    # Add other stages as necessary
]

# Wrap the environment. The factory is passed directly instead of a lambda
# that closes over the `env` name being rebound on the same line.
env = DummyVecEnv([make_env(game, state_stages[0])])

# PPO.load is a classmethod that RETURNS the loaded model. Calling it on a
# freshly constructed instance and discarding the result leaves that instance
# with its random initial weights, so the loaded model is bound here instead.
# NOTE(review): presumably the checkpoint was trained with CustomCNN, in which
# case the CustomCNN import at the top of the file is needed to restore it.
model = PPO.load("ppo_sf2_cnn", env=env, device="cuda")

obs = env.reset()
done = False

while not done:
    timestamp = time.time()
    action, _ = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    # The vectorised env returns one done flag per environment; only one
    # environment is wrapped here.
    done = bool(dones[0])
    env.render()
    render_time = time.time() - timestamp
    if render_time < 0.0111:
        time.sleep(0.0111 - render_time)  # Add a delay for 90 FPS

env.close()
|
@ -1,11 +1,45 @@
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import gym
|
||||
import retro
|
||||
import numpy as np
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
|
||||
def check_done(screen, win_template, lose_template, threshold=0.65):
    """Report whether the screen shows the win or the lose pattern.

    Args:
        screen: Colour frame; converted to grayscale with COLOR_RGB2GRAY.
        win_template: Grayscale template for the win pattern.
        lose_template: Grayscale template for the lose pattern.
        threshold: Minimum normalised match score that counts as a detection.

    Returns:
        True if either template matches (a message is printed), else False.
    """
    gray_screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
    win_res = cv2.matchTemplate(gray_screen, win_template, cv2.TM_CCOEFF_NORMED)
    lose_res = cv2.matchTemplate(gray_screen, lose_template, cv2.TM_CCOEFF_NORMED)

    if np.max(win_res) >= threshold:
        print("You win!")
        return True

    if np.max(lose_res) >= threshold:
        print("You lose!")
        return True

    # Explicit result instead of falling off the end and returning None.
    return False
|
||||
|
||||
def get_health_points(screen):
    """Estimate both fighters' health from fixed regions of the screen.

    The frame is converted to grayscale and, for each region, the fraction of
    pixels brighter than 129 is taken as the health value.

    Returns:
        Tuple of (player fraction, opponent fraction).
    """
    gray = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)

    # Player region first, opponent region second.
    regions = (gray[15:20, 32:120], gray[15:20, 136:224])
    player_health, opponent_health = (
        np.sum(region > 129) / region.size for region in regions
    )
    return player_health, opponent_health
|
||||
|
||||
rom_directory = "C:/Users/unitec/Documents/AIProjects/street-fighter-ai"
|
||||
retro.data.Integrations.add_custom_path(rom_directory)
|
||||
|
||||
@ -13,26 +47,31 @@ env = retro.RetroEnv(
|
||||
game='StreetFighterIISpecialChampionEdition-Genesis',
|
||||
state='Champion.Level3.ChunLiVsChunLi'
|
||||
)
|
||||
# Champion.Level1.ChunLiVsGuile
|
||||
# Champion.Level2.ChunLiVsKen
|
||||
# Champion.Level3.ChunLiVsChunLi
|
||||
|
||||
|
||||
env = DummyVecEnv([lambda: env])
|
||||
# env = DummyVecEnv([lambda: env])
|
||||
|
||||
model = PPO("CnnPolicy", env)
|
||||
model.load("trained_models/ppo_sf2_chunli_final")
|
||||
|
||||
obs = env.reset()
|
||||
while True:
|
||||
game_over = False
|
||||
|
||||
win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE)
|
||||
lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
while not game_over:
|
||||
timestamp = time.time()
|
||||
action, _states = model.predict(obs)
|
||||
obs, rewards, done, info = env.step(action)
|
||||
env.render()
|
||||
screen = env.unwrapped.get_screen()
|
||||
get_health_points(screen)
|
||||
game_over = check_done(screen, win_template, lose_template)
|
||||
render_time = time.time() - timestamp
|
||||
if render_time < 0.0111:
|
||||
time.sleep(0.0111 - render_time) # Add a delay for 90 FPS
|
||||
if done:
|
||||
break
|
||||
obs = env.reset()
|
||||
|
||||
env.close()
|
@ -5,57 +5,23 @@ import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
|
||||
from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first
|
||||
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Custom feature extractor (CNN)
|
||||
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for single-channel image observations.

    Three convolution + ReLU stages are followed by a flatten and one linear
    + ReLU layer that produces the feature vector.

    Args:
        observation_space: Observation space forwarded to the base extractor.
        features_dim: Size of the produced feature vector. Defaults to 512,
            the value that used to be hard-coded, so existing callers are
            unaffected.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512):
        super().__init__(observation_space, features_dim=features_dim)
        self.cnn = nn.Sequential(
            # The first layer takes exactly one input channel (grayscale).
            nn.Conv2d(1, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            # 3136 = 64 * 7 * 7: the flattened conv output for an 84x84 input
            # (84 -> 20 -> 9 -> 7 through the three unpadded convolutions).
            nn.Linear(3136, self.features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Run the observations through the CNN and return the features."""
        return self.cnn(observations)
|
||||
|
||||
# Custom environment wrapper for preprocessing
|
||||
class CustomAtariWrapper(gym.Wrapper):
    """Gym wrapper that converts every observation to single-channel gray.

    Reward, done flag and info from the wrapped environment pass through
    unchanged.
    """

    def __init__(self, env):
        super().__init__(env)
        # Advertise the frames this wrapper actually emits: same height and
        # width as the wrapped env, one channel. The space inherited from the
        # wrapped env would otherwise disagree with the (H, W, 1) output.
        # NOTE(review): assumes the wrapped env exposes an (H, W, C) image
        # space - confirm for non-image observation types.
        height, width = env.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(height, width, 1), dtype=np.uint8
        )

    def _preprocess_observation(self, observation):
        """Return the observation as a grayscale array with a channel axis."""
        observation = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        return np.expand_dims(observation, axis=-1)

    def reset(self):
        """Reset the wrapped environment and return the processed observation."""
        observation = self.env.reset()
        return self._preprocess_observation(observation)

    def step(self, action):
        """Step the wrapped environment, converting only the observation."""
        observation, reward, done, info = self.env.step(action)
        return self._preprocess_observation(observation), reward, done, info
|
||||
from custom_cnn import CustomCNN
|
||||
from custom_sf2_cv_env import StreetFighterCustomWrapper
|
||||
|
||||
def make_env(game, state, seed=0):
    """Return a factory that builds a wrapped, seeded Street Fighter env.

    Args:
        game: Name of the retro game integration.
        state: Name of the saved state to start from.
        seed: Seed applied to the created environment.

    Returns:
        A zero-argument callable that creates the environment.

    Raises:
        FileNotFoundError: (from the factory) if a template image cannot be
            read.
    """
    def _init():
        template_paths = (
            'images/pattern_wins_gray.png',
            'images/pattern_lose_gray.png',
        )
        templates = []
        for path in template_paths:
            # cv2.imread returns None instead of raising when loading fails;
            # catch that here rather than inside cv2.matchTemplate later on.
            template = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if template is None:
                raise FileNotFoundError(f"Could not read template image: {path}")
            templates.append(template)
        win_template, lose_template = templates

        env = retro.RetroEnv(game=game, state=state, obs_type=retro.Observations.IMAGE)
        # StreetFighterCustomWrapper performs the colour-to-gray conversion
        # itself and needs the colour frame, so it wraps the retro env
        # directly; feeding it an already single-channel frame would make its
        # cv2.cvtColor call fail.
        env = StreetFighterCustomWrapper(env, win_template, lose_template)
        env.seed(seed)
        return env
    return _init
|
||||
|
||||
def main():
|
||||
|
||||
# Set up the environment and model
|
||||
game = "StreetFighterIISpecialChampionEdition-Genesis"
|
||||
state_stages = [
|
||||
@ -67,12 +33,11 @@ def main():
|
||||
]
|
||||
|
||||
num_envs = 8
|
||||
seed = 42
|
||||
|
||||
env = SubprocVecEnv([make_env(game, state_stages[0], seed=i) for i in range(num_envs)])
|
||||
|
||||
policy_kwargs = {
|
||||
'features_extractor_class': CustomCNN
|
||||
'features_extractor_class': CustomCNN
|
||||
}
|
||||
|
||||
model = PPO(
|
||||
@ -88,5 +53,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
# missing reward function
|