# street-fighter-ai/test_street_fighter_ai.py
#
# 77 lines
# 2.4 KiB
# Python

import time
import cv2
import torch
import gym
import retro
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
def check_done(screen, win_template, lose_template, threshold=0.65):
    """Report whether a round-over pattern is visible on ``screen``.

    The frame is converted to grayscale and matched against each template
    with ``cv2.TM_CCOEFF_NORMED``. A peak score at or above ``threshold``
    for either template counts as the end of the round, and the outcome is
    printed.

    Args:
        screen: Frame to inspect. It is converted with
            ``cv2.COLOR_RGB2GRAY``, so it is presumably an RGB image array
            (confirm against the caller).
        win_template: Grayscale pattern matched to detect a win.
        lose_template: Grayscale pattern matched to detect a loss.
        threshold: Minimum peak match score that counts as a detection.

    Returns:
        bool: ``True`` when either template matches, ``False`` otherwise.
    """
    gray_screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)

    # The win pattern is tested first; the lose pattern is only matched when
    # the win pattern was not found.
    outcomes = (
        (win_template, "You win!"),
        (lose_template, "You lose!"),
    )
    for template, message in outcomes:
        scores = cv2.matchTemplate(gray_screen, template, cv2.TM_CCOEFF_NORMED)
        if np.max(scores) >= threshold:
            print(message)
            return True

    # Explicit False (rather than an implicit None) keeps the return type
    # consistent for callers.
    return False
def get_health_points(screen):
    """Estimate both fighters' health from fixed regions of the frame.

    The frame is converted to grayscale, and for each fighter the share of
    pixels brighter than 129 inside a fixed rectangle is measured: rows
    15-19, columns 32-119 for the player and columns 136-223 for the
    opponent.

    Args:
        screen: Frame to inspect. It is converted with
            ``cv2.COLOR_RGB2GRAY``, so it is presumably an RGB image array
            (confirm against the caller).

    Returns:
        tuple: ``(player_health, opponent_health)``, each the fraction of
        pixels in its region whose gray value exceeds 129.
    """
    gray = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)

    def bright_fraction(region):
        # NOTE(review): the cutoff 129 was presumably picked as the midpoint
        # between the brightest and darkest pixel of the bar, i.e.
        # (max + min) / 2 -- confirm before changing it.
        return np.sum(region > 129) / region.size

    player_share = bright_fraction(gray[15:20, 32:120])
    opponent_share = bright_fraction(gray[15:20, 136:224])
    return player_share, opponent_share
# Playback script: run the saved PPO agent on one save state until a win or
# lose pattern is detected on screen.

# Register the project directory as a custom integration path for retro.
rom_directory = "C:/Users/unitec/Documents/AIProjects/street-fighter-ai"
retro.data.Integrations.add_custom_path(rom_directory)

# Save states listed for this script:
#   Champion.Level1.ChunLiVsGuile
#   Champion.Level2.ChunLiVsKen
#   Champion.Level3.ChunLiVsChunLi
env = retro.RetroEnv(
    game='StreetFighterIISpecialChampionEdition-Genesis',
    state='Champion.Level3.ChunLiVsChunLi'
)

try:
    # PPO.load is a classmethod that returns a new model. Calling it on an
    # existing instance and discarding the result leaves that instance with
    # its freshly initialised weights, so the result must be kept.
    model = PPO.load("trained_models/ppo_sf2_chunli_final", env=env)

    win_template = cv2.imread('images/pattern_wins_gray.png', cv2.IMREAD_GRAYSCALE)
    lose_template = cv2.imread('images/pattern_lose_gray.png', cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None for an unreadable file instead of raising, so
    # fail here with a clear message rather than later inside matchTemplate.
    if win_template is None or lose_template is None:
        raise FileNotFoundError(
            "Could not read the win/lose template images under 'images/'."
        )

    obs = env.reset()
    game_over = False

    while not game_over:
        timestamp = time.time()

        action, _states = model.predict(obs)
        obs, rewards, done, info = env.step(action)
        env.render()

        screen = env.unwrapped.get_screen()
        # The health values are computed but not used by this script.
        get_health_points(screen)
        game_over = check_done(screen, win_template, lose_template)

        # Cap playback at roughly 90 FPS (0.0111 s per frame).
        render_time = time.time() - timestamp
        if render_time < 0.0111:
            time.sleep(0.0111 - render_time)
finally:
    # Release the emulator even if loading or playback fails.
    env.close()