import gym
import cv2
import numpy as np
import torch
from torchvision.transforms import Normalize
from gym.spaces import MultiBinary


# Custom environment wrapper
class StreetFighterCustomWrapper(gym.Wrapper):
    def __init__(self, env, win_template, lose_template, testing=False, threshold=0.65):
        super(StreetFighterCustomWrapper, self).__init__(env)
        self.action_space = MultiBinary(12)
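        # Assumes the underlying environment exposes 12 binary buttons
        # (e.g. the Genesis controller layout used by gym-retro).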

        # self.win_template = win_template
        # self.lose_template = lose_template
        self.threshold = threshold
        self.game_screen_gray = None

        self.prev_player_health = 1.0
        self.prev_opponent_health = 1.0

        # Update observation space to single-channel grayscale image
        # self.observation_space = gym.spaces.Box(
        #     low=0.0, high=1.0, shape=(84, 84, 1), dtype=np.float32
        # )

        # observation_space for mobilenet
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(3, 96, 96), dtype=np.float32
        )
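        # Note: the channel-first (3, 96, 96) shape matches the tensor produced by
        # _preprocess_observation() below.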

        self.testing = testing

        # Normalize the image for MobileNetV3Small.
        self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
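        # These mean/std values are the standard ImageNet statistics expected by
        # torchvision's pretrained backbones, including MobileNetV3Small.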

    def _preprocess_observation(self, observation):
        # Previous single-channel grayscale preprocessing (84x84):
        # self.game_screen_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        # resized_image = cv2.resize(self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0
        # return np.expand_dims(resized_image, axis=-1)

        # Using MobileNetV3Small: keep a grayscale copy of the full frame so that
        # _get_reward() can read the health bars from it.
        self.game_screen_gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)
        resized_image = cv2.resize(observation, (96, 96), interpolation=cv2.INTER_AREA).astype(np.float32) / 255.0

        # Convert the NumPy array (HWC) to a channel-first PyTorch tensor (CHW)
        resized_image = torch.from_numpy(resized_image).permute(2, 0, 1)

        # Apply normalization
        resized_image = self.normalize(resized_image)

        # Add a batch dimension to match the model input shape
        # resized_image = resized_image.unsqueeze(0)

        return resized_image

    def _get_win_or_lose_bonus(self):
        if self.prev_player_health > self.prev_opponent_health:
            # print('You win!')
            return 300
        else:
            # print('You lose!')
            return -300

    def _get_reward(self):
        player_health_area = self.game_screen_gray[15:20, 32:120]
        opponent_health_area = self.game_screen_gray[15:20, 136:224]

        # Estimate health as the fraction of health-bar pixels brighter than 129.
        player_health = np.sum(player_health_area > 129) / player_health_area.size
        opponent_health = np.sum(opponent_health_area > 129) / opponent_health_area.size

        player_health_diff = self.prev_player_health - player_health
        opponent_health_diff = self.prev_opponent_health - opponent_health

        reward = (opponent_health_diff - player_health_diff) * 200  # max would be 200
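        # With health values normalized to [0, 1], depleting a full health bar in one
        # step corresponds to +/-200 reward here.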

        # Penalty for each step without any change in health
        if opponent_health_diff <= 0.0000001:
            reward -= 12.0 / 60.0  # -12 points per second if no damage to opponent

        self.prev_player_health = player_health
        self.prev_opponent_health = opponent_health

        # Print the health values of the player and the opponent
        # print("Player health: %f Opponent health: %f" % (player_health, opponent_health))
        return reward

    def reset(self):
        observation = self.env.reset()
        self.prev_player_health = 1.0
        self.prev_opponent_health = 1.0
        return self._preprocess_observation(observation)

    def step(self, action):
        observation, _reward, _done, info = self.env.step(action)
        # For random-action debugging, the agent's action can be replaced with a sample:
        # observation, _reward, _done, info = self.env.step(self.env.action_space.sample())

        custom_reward = self._get_reward()
        custom_reward -= 1.0 / 60.0  # penalty for each step (-1 point per second)

        custom_done = False
        if self.prev_player_health <= 0.00001 or self.prev_opponent_health <= 0.00001:
            custom_reward += self._get_win_or_lose_bonus()
            if not self.testing:
                custom_done = True
            else:
                # In testing mode, keep the episode running and reset the cached health values.
                self.prev_player_health = 1.0
                self.prev_opponent_health = 1.0

        return self._preprocess_observation(observation), custom_reward, custom_done, info
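

# Minimal usage sketch: assumes the game is run through gym-retro and that the ROM
# has been imported under the integration name below; the win/lose template
# arguments are unused by the wrapper, so None is passed for both.
if __name__ == "__main__":
    import retro

    env = retro.make(game="StreetFighterIISpecialChampionEdition-Genesis")
    env = StreetFighterCustomWrapper(env, win_template=None, lose_template=None)

    obs = env.reset()
    done = False
    while not done:
        # Random actions just to exercise the wrapper; a trained agent would supply them.
        obs, reward, done, info = env.step(env.action_space.sample())
    env.close()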