mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-04 23:20:43 +00:00
60 lines
2.4 KiB
Python
60 lines
2.4 KiB
Python
import gym
|
|
import cv2
|
|
import numpy as np
|
|
|
|
# Custom environment wrapper
|
|
class StreetFighterCustomWrapper(gym.Wrapper):
    """Gym wrapper that reshapes Street Fighter observations, rewards and termination.

    * Observations are converted to single-channel grayscale and resized to
      84x84, returned with shape ``(84, 84, 1)``.
    * The reward is recomputed from the on-screen health bars of the
      full-size grayscale frame; the wrapped env's reward is discarded.
    * An episode ends when the wrapped env reports done, or when either the
      win or the lose template is found on screen with a normalized
      correlation score of at least ``threshold``.

    Args:
        env: The environment to wrap.
        win_template: Image searched for (via ``cv2.matchTemplate``) to detect
            a win screen.
        lose_template: Image searched for to detect a lose screen.
        threshold: Minimum ``TM_CCOEFF_NORMED`` score counted as a match.
    """

    # Output size (width, height) of the preprocessed observation, as
    # expected by cv2.resize.
    OBSERVATION_SIZE = (84, 84)
    # (rows, cols) of each health bar on the full-resolution gray frame.
    PLAYER_HEALTH_BAR = (slice(15, 20), slice(32, 120))
    OPPONENT_HEALTH_BAR = (slice(15, 20), slice(136, 224))
    # Gray pixels strictly brighter than this count as remaining health.
    HEALTH_PIXEL_THRESHOLD = 129

    def __init__(self, env, win_template, lose_template, threshold=0.65):
        super().__init__(env)
        self.win_template = win_template
        self.lose_template = lose_template
        self.threshold = threshold

        # Full-resolution grayscale copy of the most recent frame; read by
        # the reward and game-over checks.
        self.game_screen_gray = None

        # NOTE(review): these are not read anywhere in this class; presumably
        # intended for a per-step (delta) health reward. Kept for callers.
        self.prev_player_health = 1.0
        self.prev_opponent_health = 1.0

        # Update observation space to single-channel grayscale image
        # (numpy shape order is height, width, channels).
        width, height = self.OBSERVATION_SIZE
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(height, width, 1), dtype=np.uint8
        )

    def _preprocess_observation(self, observation):
        """Cache the full-size grayscale frame and return the resized observation.

        Side effect: overwrites ``self.game_screen_gray``.

        Returns:
            Array of shape ``(84, 84, 1)``.
        """
        # NOTE(review): COLOR_BGR2GRAY assumes BGR channel order; confirm what
        # the wrapped env emits. Changing it shifts the gray levels that the
        # health threshold and the templates were tuned against.
        self.game_screen_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        resized_image = cv2.resize(
            self.game_screen_gray, self.OBSERVATION_SIZE, interpolation=cv2.INTER_AREA
        )
        return np.expand_dims(resized_image, axis=-1)

    def _check_game_over(self):
        """Return True if the win or the lose template matches the current frame."""
        for template in (self.win_template, self.lose_template):
            scores = cv2.matchTemplate(self.game_screen_gray, template, cv2.TM_CCOEFF_NORMED)
            if np.max(scores) >= self.threshold:
                # No need to test the remaining template once one matches.
                return True
        return False

    def _health_fraction(self, region):
        """Return the fraction of pixels in ``region`` brighter than the health threshold."""
        area = self.game_screen_gray[region]
        return np.sum(area > self.HEALTH_PIXEL_THRESHOLD) / area.size

    def _get_reward(self):
        """Return player health fraction minus opponent health fraction.

        Both fractions lie in [0, 1], so the reward lies in [-1, 1]; it is
        positive when the player has more health bar left than the opponent.
        """
        player_health = self._health_fraction(self.PLAYER_HEALTH_BAR)
        opponent_health = self._health_fraction(self.OPPONENT_HEALTH_BAR)
        return player_health - opponent_health

    def reset(self, **kwargs):
        """Reset the wrapped env and return the preprocessed first observation.

        Any keyword arguments are forwarded unchanged to ``env.reset``.
        """
        observation = self.env.reset(**kwargs)
        return self._preprocess_observation(observation)

    def step(self, action):
        """Step the wrapped env, replacing its reward and extending its done flag.

        Returns:
            ``(observation, reward, done, info)`` where ``observation`` is the
            preprocessed frame, ``reward`` is the health-bar difference,
            ``done`` is True if the wrapped env finished or a win/lose screen
            was detected, and ``info`` is passed through unchanged.
        """
        observation, _, done, info = self.env.step(action)
        # Preprocess first: the reward and game-over checks read
        # self.game_screen_gray, which must describe the frame just returned,
        # not the previous one.
        processed_observation = self._preprocess_observation(observation)
        custom_reward = self._get_reward()
        # Never report "not done" after the wrapped env itself has finished.
        custom_done = bool(done) or self._check_game_over()
        return processed_observation, custom_reward, custom_done, info