street-fighter-ai/000_image_stack_ram_based_reward/street_fighter_notebook.ipynb

315 lines
7.7 KiB
Plaintext
Raw Normal View History

2023-03-30 18:10:25 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "bfc79b8c",
"metadata": {},
"outputs": [],
"source": [
"import retro"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c24fbcab",
"metadata": {},
"outputs": [],
"source": [
"# Game integration name and the start state handed to retro.make.\n",
"game, state = (\n",
"    \"StreetFighterIISpecialChampionEdition-Genesis\",\n",
"    \"Champion.Level1.ChunLiVsGuile\",\n",
")\n",
"env = retro.make(state=state, game=game)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "59839d9c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1], dtype=int8)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.action_space.sample()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e068cb0a",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"(200, 256, 3)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.observation_space.sample().shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "1cb0297f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(200, 256, 3)\n",
"{'enemy_matches_won': 0, 'score': 0, 'matches_won': 0, 'continuetimer': 0, 'enemy_health': 176, 'health': 176}\n"
]
}
],
"source": [
"# Reset the environment and report the shape of the first frame.\n",
"observation = env.reset()\n",
"print(observation.shape)\n",
"\n",
"# Take a single step with a randomly sampled action and show the info dict.\n",
"action = env.action_space.sample()\n",
"(obs, rewards, done, info) = env.step(action)\n",
"print(info)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "0eaa5cc8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MultiBinary(12)\n"
]
}
],
"source": [
"from gym.spaces import Box, MultiBinary\n",
"\n",
"print(MultiBinary(12))"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "49f6cf5c",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"\n",
"import gym\n",
"import numpy as np\n",
"from gym.spaces import Box, MultiBinary\n",
"\n",
"class StreetFighter(gym.Env):\n",
"    \"\"\"Gym wrapper around the Street Fighter II retro environment.\n",
"\n",
"    Observations are 84x84 grayscale frames. While a round is being fought the\n",
"    reward is rebuilt from the `health` / `enemy_health` entries of the info\n",
"    dict returned by the wrapped environment.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.observation_space = Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)\n",
"        self.action_space = MultiBinary(12)\n",
"        self.game = retro.make(game=\"StreetFighterIISpecialChampionEdition-Genesis\", use_restricted_actions=retro.Actions.FILTERED)\n",
"\n",
"        # Health value both trackers are (re)initialised to.\n",
"        self.full_hp = 176\n",
"        # Health values seen on the previous step; step() diffs against them.\n",
"        self.player_health = self.full_hp\n",
"        self.oppont_health = self.full_hp\n",
"\n",
"        self.score = 0\n",
"\n",
"    def __preprocess(self, observation):\n",
"        \"\"\"Return the frame as an 84x84 single-channel grayscale image.\"\"\"\n",
"        # gym-retro delivers frames in RGB channel order, so the RGB conversion\n",
"        # code is required; the BGR code would swap the red/blue luminance weights.\n",
"        gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)\n",
"        resize = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_CUBIC)\n",
"        return resize\n",
"\n",
"    def step(self, action):\n",
"        \"\"\"Advance the wrapped environment by one step.\n",
"\n",
"        Returns the tuple `(observation, reward, done, info)`. The observation\n",
"        is the preprocessed frame; `done` and `info` come unchanged from the\n",
"        wrapped environment. The reward is replaced by a health-based value\n",
"        whenever at least one fighter has positive health, otherwise the\n",
"        wrapped environment's reward is passed through.\n",
"        \"\"\"\n",
"        obs, reward, done, info = self.game.step(action)\n",
"        custom_obs = self.__preprocess(obs)  # It's just frame, not frame_delta\n",
"\n",
"        # During fighting, either player or opponent has positive health points.\n",
"        if info['health'] > 0 or info['enemy_health'] > 0:\n",
"\n",
"            # Player Loses: penalty scaled by the opponent's remaining health.\n",
"            # The comparison with the stored health makes this fire only on the\n",
"            # step where the health value changes.\n",
"            if info['health'] < 0 and info['health'] != self.player_health and info['enemy_health'] != 0:\n",
"                reward = (-self.full_hp) * info['enemy_health']\n",
"\n",
"            # Player Wins: bonus scaled by the player's remaining health.\n",
"            elif info['enemy_health'] < 0 and info['enemy_health'] != self.oppont_health and info['health'] != 0:\n",
"                reward = self.full_hp * info['health']\n",
"\n",
"            # During Fighting: damage dealt minus damage taken since the last step.\n",
"            else:\n",
"                reward = (self.oppont_health - info['enemy_health']) - (self.player_health - info['health'])\n",
"\n",
"        # Remember the latest health values for the next step's comparison.\n",
"        self.player_health = info['health']\n",
"        self.oppont_health = info['enemy_health']\n",
"\n",
"        return custom_obs, reward, done, info\n",
"\n",
"    def render(self, *args, **kwargs):\n",
"        \"\"\"Render via the wrapped environment, forwarding arguments and result.\"\"\"\n",
"        return self.game.render(*args, **kwargs)\n",
"\n",
"    def reset(self):\n",
"        \"\"\"Reset the wrapped environment and the health trackers.\n",
"\n",
"        Returns the first frame after preprocessing.\n",
"        \"\"\"\n",
"        obs = self.game.reset()\n",
"        custom_obs = self.__preprocess(obs)\n",
"        # Raw first frame; nothing in this class reads it back.\n",
"        self.previous_frame = obs\n",
"\n",
"        self.player_health = self.full_hp\n",
"        self.oppont_health = self.full_hp\n",
"        return custom_obs\n",
"\n",
"    def close(self):\n",
"        \"\"\"Close the wrapped retro environment.\"\"\"\n",
"        self.game.close()\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "6ec30177",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(84, 84)\n"
]
}
],
"source": [
"# Close the environment created earlier before building the wrapper, which\n",
"# opens its own emulator instance.\n",
"env.close()\n",
"\n",
"# Build the wrapper, report its observation shape, then release it again.\n",
"env = StreetFighter()\n",
"print(env.observation_space.shape)\n",
"env.close()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "7d9eab3a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\ProgramData\\Anaconda3\\envs\\StreetFighterAI\\lib\\site-packages\\pyglet\\image\\codecs\\wic.py:289: UserWarning: [WinError -2147417850] Cannot change thread mode after it is set\n",
" warnings.warn(str(err))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"-22 154 176\n",
"-32 122 176\n",
"29 122 147\n",
"7 122 140\n",
"-31 91 140\n",
"29 91 111\n",
"-23 68 111\n",
"-24 44 111\n",
"-24 20 111\n",
"31 20 80\n",
"10 20 70\n",
"45 20 25\n",
"5 20 20\n",
"-15 5 20\n",
"19 5 1\n",
"-176 -1 1\n",
"46 176 130\n",
"7 176 123\n",
"-24 152 123\n",
"29 152 94\n",
"-24 128 94\n",
"7 128 87\n",
"39 128 48\n",
"-31 97 48\n",
"36 97 12\n",
"-24 73 12\n",
"-24 49 12\n",
"8624 49 -1\n",
"39 176 137\n",
"-24 152 137\n",
"-23 129 137\n",
"-23 106 137\n",
"-26 80 137\n",
"-24 56 137\n",
"-23 33 137\n",
"-21 12 137\n",
"-12 0 137\n",
"-24112 -1 137\n"
]
}
],
"source": [
"## Checking Rewards functionality\n",
"import time\n",
"\n",
"env = StreetFighter()\n",
"try:\n",
"    for game in range(5):\n",
"        # Start every game from a fresh reset. Previously the reset sat inside\n",
"        # `while not done` behind `if done`, so it could never run and only the\n",
"        # first of the five games was actually played.\n",
"        obs = env.reset()\n",
"        done = False\n",
"        while not done:\n",
"            env.render()\n",
"            obs, reward, done, info = env.step(env.action_space.sample())\n",
"            # Only steps with a non-zero reward are worth inspecting.\n",
"            if reward != 0:\n",
"                print(reward, info['health'], info['enemy_health'])\n",
"            time.sleep(0.01)\n",
"finally:\n",
"    # Release the emulator even if the loop is interrupted.\n",
"    env.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1ae8310",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}