mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-05 07:30:42 +00:00
315 lines
7.7 KiB
Plaintext
315 lines
7.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "bfc79b8c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import retro"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "c24fbcab",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"game = \"StreetFighterIISpecialChampionEdition-Genesis\"\n",
|
|
"state = \"Champion.Level1.ChunLiVsGuile\"\n",
|
|
"env = retro.make(game=game, state=state)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "59839d9c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1], dtype=int8)"
|
|
]
|
|
},
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"env.action_space.sample()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "e068cb0a",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(200, 256, 3)"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"env.observation_space.sample().shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "1cb0297f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(200, 256, 3)\n",
|
|
"{'enemy_matches_won': 0, 'score': 0, 'matches_won': 0, 'continuetimer': 0, 'enemy_health': 176, 'health': 176}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"observation = env.reset()\n",
|
|
"print(observation.shape)\n",
|
|
"\n",
|
|
"action = env.action_space.sample()\n",
|
|
"obs, rewards, done, info = env.step(action)\n",
|
|
"print(info)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "0eaa5cc8",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"MultiBinary(12)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from gym.spaces import Box, MultiBinary\n",
|
|
"\n",
|
|
"print(MultiBinary(12))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"id": "49f6cf5c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import cv2\n",
|
|
"\n",
|
|
"import gym\n",
|
|
"import numpy as np\n",
|
|
"from gym.spaces import Box, MultiBinary\n",
|
|
"\n",
|
|
"class StreetFighter(gym.Env):\n",
|
|
" def __init__(self):\n",
|
|
" super().__init__()\n",
|
|
" self.observation_space = Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)\n",
|
|
" self.action_space = MultiBinary(12)\n",
|
|
" self.game = retro.make(game=\"StreetFighterIISpecialChampionEdition-Genesis\", use_restricted_actions=retro.Actions.FILTERED)\n",
|
|
" \n",
|
|
" self.full_hp = 176\n",
|
|
" self.player_health = self.full_hp\n",
|
|
" self.oppont_health = self.full_hp\n",
|
|
" \n",
|
|
" self.score = 0\n",
|
|
" \n",
|
|
" def __preprocess(self, observation):\n",
|
|
" gray = cv2.cvtColor(observation, cv2.COLOR_BGR2GRAY)\n",
|
|
" resize = cv2.resize(gray, (84,84), interpolation=cv2.INTER_CUBIC)\n",
|
|
" return resize\n",
|
|
"\n",
|
|
" def step(self, action):\n",
|
|
"\n",
|
|
" obs, reward, done, info = self.game.step(action)\n",
|
|
" custom_obs = self.__preprocess(obs) # Current frame only (grayscale, resized to 84x84), not a frame delta\n",
|
|
"\n",
|
|
" # During fighting, either player or opponent has positive health points.\n",
|
|
" if info['health'] > 0 or info['enemy_health'] > 0:\n",
|
|
"\n",
|
|
" # Player Loses\n",
|
|
" if info['health'] < 0 and info['health'] != self.player_health and info['enemy_health'] != 0:\n",
|
|
" reward = (-self.full_hp) * info['enemy_health']\n",
|
|
"\n",
|
|
" # Player Wins\n",
|
|
" elif info['enemy_health'] < 0 and info['enemy_health'] != self.oppont_health and info['health'] != 0:\n",
|
|
" reward = self.full_hp * info['health']\n",
|
|
"\n",
|
|
" # During Fighting\n",
|
|
" else:\n",
|
|
" reward = (self.oppont_health - info['enemy_health']) - (self.player_health - info['health'])\n",
|
|
" \n",
|
|
" self.player_health = info['health']\n",
|
|
" self.oppont_health = info['enemy_health']\n",
|
|
" \n",
|
|
" return custom_obs, reward, done, info\n",
|
|
" \n",
|
|
" def render(self, *args, **kwargs):\n",
|
|
" self.game.render()\n",
|
|
" \n",
|
|
" def reset(self):\n",
|
|
" obs = self.game.reset()\n",
|
|
" custom_obs = self.__preprocess(obs)\n",
|
|
" self.previous_frame = obs\n",
|
|
" \n",
|
|
" self.player_health = self.full_hp\n",
|
|
" self.oppont_health = self.full_hp\n",
|
|
" return custom_obs\n",
|
|
"\n",
|
|
" def close(self):\n",
|
|
" self.game.close()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "6ec30177",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(84, 84)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"env.close()\n",
|
|
"env = StreetFighter()\n",
|
|
"print(env.observation_space.shape)\n",
|
|
"env.close()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "7d9eab3a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"C:\\ProgramData\\Anaconda3\\envs\\StreetFighterAI\\lib\\site-packages\\pyglet\\image\\codecs\\wic.py:289: UserWarning: [WinError -2147417850] Cannot change thread mode after it is set\n",
|
|
" warnings.warn(str(err))\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"-22 154 176\n",
|
|
"-32 122 176\n",
|
|
"29 122 147\n",
|
|
"7 122 140\n",
|
|
"-31 91 140\n",
|
|
"29 91 111\n",
|
|
"-23 68 111\n",
|
|
"-24 44 111\n",
|
|
"-24 20 111\n",
|
|
"31 20 80\n",
|
|
"10 20 70\n",
|
|
"45 20 25\n",
|
|
"5 20 20\n",
|
|
"-15 5 20\n",
|
|
"19 5 1\n",
|
|
"-176 -1 1\n",
|
|
"46 176 130\n",
|
|
"7 176 123\n",
|
|
"-24 152 123\n",
|
|
"29 152 94\n",
|
|
"-24 128 94\n",
|
|
"7 128 87\n",
|
|
"39 128 48\n",
|
|
"-31 97 48\n",
|
|
"36 97 12\n",
|
|
"-24 73 12\n",
|
|
"-24 49 12\n",
|
|
"8624 49 -1\n",
|
|
"39 176 137\n",
|
|
"-24 152 137\n",
|
|
"-23 129 137\n",
|
|
"-23 106 137\n",
|
|
"-26 80 137\n",
|
|
"-24 56 137\n",
|
|
"-23 33 137\n",
|
|
"-21 12 137\n",
|
|
"-12 0 137\n",
|
|
"-24112 -1 137\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
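|
"## Check the reward function. NOTE(review): done is never reset to False after an episode ends, so only the first of the 5 games actually runs.\n",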
|
|
"import time\n",
|
|
"\n",
|
|
"env = StreetFighter()\n",
|
|
"obs = env.reset()\n",
|
|
"done = False\n",
|
|
"\n",
|
|
"for game in range(5):\n",
|
|
" while not done:\n",
|
|
" if done:\n",
|
|
" obs = env.reset()\n",
|
|
" env.render()\n",
|
|
" obs, reward, done, info = env.step(env.action_space.sample())\n",
|
|
" if reward != 0:\n",
|
|
" print(reward, info['health'], info['enemy_health'])\n",
|
|
" time.sleep(0.01)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b1ae8310",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|