{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "bfc79b8c", "metadata": {}, "outputs": [], "source": [ "import retro" ] }, { "cell_type": "code", "execution_count": 2, "id": "c24fbcab", "metadata": {}, "outputs": [], "source": [ "game = \"StreetFighterIISpecialChampionEdition-Genesis\"\n", "state = \"Champion.Level1.ChunLiVsGuile\"\n", "env = retro.make(game=game, state=state)" ] }, { "cell_type": "code", "execution_count": 6, "id": "59839d9c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1], dtype=int8)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.action_space.sample()" ] }, { "cell_type": "code", "execution_count": 9, "id": "e068cb0a", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "(200, 256, 3)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "env.observation_space.sample().shape" ] }, { "cell_type": "code", "execution_count": 16, "id": "1cb0297f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(200, 256, 3)\n", "{'enemy_matches_won': 0, 'score': 0, 'matches_won': 0, 'continuetimer': 0, 'enemy_health': 176, 'health': 176}\n" ] } ], "source": [ "observation = env.reset()\n", "print(observation.shape)\n", "\n", "action = env.action_space.sample()\n", "obs, rewards, done, info = env.step(action)\n", "print(info)" ] }, { "cell_type": "code", "execution_count": 19, "id": "0eaa5cc8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MultiBinary(12)\n" ] } ], "source": [ "from gym.spaces import Box, MultiBinary\n", "\n", "print(MultiBinary(12))" ] }, { "cell_type": "code", "execution_count": 37, "id": "49f6cf5c", "metadata": {}, "outputs": [], "source": [ "import cv2\n", "\n", "import gym\n", "import numpy as np\n", "from gym.spaces import Box, MultiBinary\n", "\n", "class StreetFighter(gym.Env):\n", " def __init__(self):\n", " super().__init__()\n", " self.observation_space = Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)\n", " self.action_space = MultiBinary(12)\n", " self.game = retro.make(game=\"StreetFighterIISpecialChampionEdition-Genesis\", use_restricted_actions=retro.Actions.FILTERED)\n", " \n", " self.full_hp = 176\n", " self.player_health = self.full_hp\n", " self.oppont_health = self.full_hp\n", " \n", " self.score = 0\n", " \n", " def __preprocess(self, observation):\n", " gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)\n", " resize = cv2.resize(gray, (84,84), interpolation=cv2.INTER_CUBIC)\n", " return resize\n", "\n", " def step(self, action):\n", "\n", " obs, reward, done, info = self.game.step(action)\n", " custom_obs = self.__preprocess(obs) # It's just frame, not frame_delta\n", "\n", " # During fighting, either player or opponent has positive health points.\n", " if info['health'] > 0 or info['enemy_health'] > 0:\n", "\n", " # Player Loses\n", " if info['health'] < 0 and info['health'] != self.player_health and info['enemy_health'] != 0:\n", " reward = (-self.full_hp) * info['enemy_health']\n", "\n", " # Player Wins\n", " elif info['enemy_health'] < 0 and info['enemy_health'] != self.oppont_health and info['health'] != 0:\n", " reward = self.full_hp * info['health']\n", "\n", " # During Fighting\n", " else:\n", " reward = (self.oppont_health - info['enemy_health']) - (self.player_health - info['health'])\n", " \n", " self.player_health = info['health']\n", " self.oppont_health = info['enemy_health']\n", " \n", " return custom_obs, reward, done, info\n", " \n", " def render(self, *args, **kwargs):\n", " self.game.render()\n", " \n", " def reset(self):\n", " obs = self.game.reset()\n", " custom_obs = self.__preprocess(obs)\n", " self.previous_frame = obs\n", " \n", " self.player_health = self.full_hp\n", " self.oppont_health = self.full_hp\n", " return custom_obs\n", "\n", " def close(self):\n", " self.game.close()\n" ] }, { "cell_type": "code", "execution_count": 38, "id": "6ec30177", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(84, 84)\n" ] } ], "source": [ "env.close()\n", "env = StreetFighter()\n", "print(env.observation_space.shape)\n", "env.close()" ] }, { "cell_type": "code", "execution_count": 39, "id": "7d9eab3a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\ProgramData\\Anaconda3\\envs\\StreetFighterAI\\lib\\site-packages\\pyglet\\image\\codecs\\wic.py:289: UserWarning: [WinError -2147417850] Cannot change thread mode after it is set\n", " warnings.warn(str(err))\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-22 154 176\n", "-32 122 176\n", "29 122 147\n", "7 122 140\n", "-31 91 140\n", "29 91 111\n", "-23 68 111\n", "-24 44 111\n", "-24 20 111\n", "31 20 80\n", "10 20 70\n", "45 20 25\n", "5 20 20\n", "-15 5 20\n", "19 5 1\n", "-176 -1 1\n", "46 176 130\n", "7 176 123\n", "-24 152 123\n", "29 152 94\n", "-24 128 94\n", "7 128 87\n", "39 128 48\n", "-31 97 48\n", "36 97 12\n", "-24 73 12\n", "-24 49 12\n", "8624 49 -1\n", "39 176 137\n", "-24 152 137\n", "-23 129 137\n", "-23 106 137\n", "-26 80 137\n", "-24 56 137\n", "-23 33 137\n", "-21 12 137\n", "-12 0 137\n", "-24112 -1 137\n" ] } ], "source": [ "## Checking Rewards functionality\n", "import time\n", "\n", "env = StreetFighter()\n", "obs = env.reset()\n", "done = False\n", "\n", "for game in range(5):\n", " while not done:\n", " if done:\n", " obs = env.reset()\n", " env.render()\n", " obs, reward, done, info = env.step(env.action_space.sample())\n", " if reward != 0:\n", " print(reward, info['health'], info['enemy_health'])\n", " time.sleep(0.01)" ] }, { "cell_type": "code", "execution_count": null, "id": "b1ae8310", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }