# Mirror of https://github.com/linyiLYi/street-fighter-ai.git
# (synced 2025-04-04 15:10:43 +00:00; 77 lines, 2.4 KiB, Python)
import time
|
|
|
|
import cv2
|
|
import torch
|
|
import gym
|
|
import retro
|
|
import numpy as np
|
|
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.vec_env import DummyVecEnv
|
|
|
|
|
|
def check_done(screen, win_template, lose_template, threshold=0.65):
    """Report whether the round has ended by looking for the win/lose banner.

    The frame is converted from RGB to grayscale and each template is searched
    for with normalized correlation-coefficient matching
    (``cv2.TM_CCOEFF_NORMED``, scores in [-1, 1]). A banner counts as present
    when its best match score is at least ``threshold``.

    Args:
        screen: RGB frame to inspect.
        win_template: grayscale image of the "win" banner.
        lose_template: grayscale image of the "lose" banner.
        threshold: minimum best-match score for a banner to count as found.

    Returns:
        True if either banner was found (a message saying which is printed),
        otherwise False.
    """
    gray_screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)

    # The win banner takes precedence: the lose template is only matched when
    # no win was detected, which also skips a redundant matchTemplate call.
    win_res = cv2.matchTemplate(gray_screen, win_template, cv2.TM_CCOEFF_NORMED)
    if np.max(win_res) >= threshold:
        print("You win!")
        return True

    lose_res = cv2.matchTemplate(gray_screen, lose_template, cv2.TM_CCOEFF_NORMED)
    if np.max(lose_res) >= threshold:
        print("You lose!")
        return True

    # Return an explicit bool rather than falling through to an implicit None.
    return False
|
|
|
|
def get_health_points(screen, threshold=129):
    """Estimate both fighters' health as the fraction of bright health-bar pixels.

    The RGB frame is converted to grayscale. Two fixed regions are sampled:
    rows 15-19 / columns 32-119 for the player, and rows 15-19 /
    columns 136-223 for the opponent. In each region, the share of pixels
    brighter than ``threshold`` is reported.
    NOTE(review): these coordinates assume the frame layout of the game used
    by this script -- confirm before reusing with another game or resolution.

    Args:
        screen: RGB frame to inspect.
        threshold: grayscale value a pixel must exceed to count as "health".

    Returns:
        Tuple ``(player_health, opponent_health)``, each a fraction in [0, 1].
    """
    gray_screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
    player_health_area = gray_screen[15:20, 32:120]
    opponent_health_area = gray_screen[15:20, 136:224]

    # Health is measured as the share of pixels in the bar brighter than threshold.
    player_health = np.sum(player_health_area > threshold) / player_health_area.size
    opponent_health = np.sum(opponent_health_area > threshold) / opponent_health_area.size

    return player_health, opponent_health
|
|
|
|
# NOTE(review): machine-specific absolute path; point this at the directory
# that contains the custom retro game integration.
rom_directory = "C:/Users/unitec/Documents/AIProjects/street-fighter-ai"
retro.data.Integrations.add_custom_path(rom_directory)

MODEL_PATH = "trained_models/ppo_sf2_chunli_final"
FRAME_TIME = 1 / 90  # per-step time budget in seconds (~90 FPS playback)

# cv2.imread returns None (instead of raising) when a file is missing or
# unreadable; fail fast here rather than inside cv2.matchTemplate later.
win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE)
lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE)
if win_template is None or lose_template is None:
    raise FileNotFoundError(
        "Could not load 'images/pattern_wins_gray.png' and/or "
        "'images/pattern_lose_gray.png'."
    )

# Available states:
#   Champion.Level1.ChunLiVsGuile
#   Champion.Level2.ChunLiVsKen
#   Champion.Level3.ChunLiVsChunLi
env = retro.RetroEnv(
    game='StreetFighterIISpecialChampionEdition-Genesis',
    state='Champion.Level3.ChunLiVsChunLi'
)

try:
    # PPO.load is a classmethod that *returns* the loaded model; its result
    # must be assigned, otherwise an untrained policy would be used.
    model = PPO.load(MODEL_PATH)

    obs = env.reset()
    game_over = False
    while not game_over:
        step_start = time.perf_counter()

        action, _ = model.predict(obs)
        obs, _reward, done, _info = env.step(action)
        env.render()

        screen = env.unwrapped.get_screen()
        # Stop on a detected win/lose banner, or when the environment itself
        # reports the episode is over (stepping a finished env is invalid).
        game_over = bool(check_done(screen, win_template, lose_template) or done)

        # Throttle playback so it runs at roughly FRAME_TIME per step.
        elapsed = time.perf_counter() - step_start
        if elapsed < FRAME_TIME:
            time.sleep(FRAME_TIME - elapsed)
finally:
    # Always release the emulator, even if loading or stepping raised.
    env.close()