diff --git a/.gitignore b/.gitignore index 52ac8d0..4c18755 100644 --- a/.gitignore +++ b/.gitignore @@ -4,16 +4,21 @@ archives/ images/ -main/logs/monitoring/ -recordings/ +/main/logs/monitoring/ -007* +# Mac system files +/007* *.DS_Store .DS_Store -main/recordings/ -main/record.py -main/update_model.py +# Recorded videos +/main/recordings/ + +# Main scripts +/main/*.py +!/main/test.py +!/main/train.py +!/main/street_fighter_custom_wrapper.py # Game Data /data/ diff --git a/README.md b/README.md index 713a1f4..1a643b9 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# SFighterAI 街头霸王游戏智能代理 +# SFighterAI 本项目基于深度强化学习训练了一个用于通关《街头霸王·二:冠军特别版》(Street Fighter II Special Champion Edition)关底 BOSS 的智能 AI 代理。该智能代理完全基于游戏画面(RGB 像素值)进行决策,在该项目给定存档中最后一关的第一轮对局可以取得 100% 胜率(实际上出现了“过拟合”现象,详见[运行测试]部分的讨论)。 @@ -71,3 +71,25 @@ python test.py cd [项目上级文件夹]/street-fighter-ai/main python train.py ``` + +### 查看曲线 + +项目中包含了训练过程的 Tensorboard 曲线图,可以使用 Tensorboard 查看其中的详细数据。推荐使用 VSCode 集成的 Tensorboard 插件直接查看(我爱你 VSCode!)。以下列出传统查看方法: + +```bash +cd [项目上级文件夹]/street-fighter-ai/main +tensorboard --logdir=logs/ +``` + +在浏览器中打开 Tensorboard 服务默认地址 `http://localhost:6006/`,即可查看训练过程的交互式曲线图。 + +## 鸣谢 +本项目使用了 [OpenAI Gym Retro](https://retro.readthedocs.io/en/latest/getting_started.html)、[Stable-Baselines3](https://stable-baselines3.readthedocs.io/en/master/) 等开源代码库。感谢各位程序工作者对开源社区的贡献! + +特别列出以下两篇对本项目启发作用很大的论文: + +[1] [DIAMBRA Arena A New Reinforcement Learning Platform for Research and Experimentation](https://arxiv.org/abs/2210.10595) +这篇论文中关于格斗游戏深度强化学习模型超参数设置的经验总结非常有价值,对本项目的训练过程有很大的帮助。 + +[2] [Mitigating Cowardice for Reinforcement Learning](https://ieee-cog.org/2022/assets/papers/paper_111.pdf) +这篇论文中提出的“惩罚衰减”机制有效地解决了本次训练中智能代理在游戏中的“怯懈”(始终回避对手,不敢尝试攻击)问题,帮助非常大。 diff --git a/main/train.py b/main/train.py index 3bf5d1a..93d9692 100644 --- a/main/train.py +++ b/main/train.py @@ -35,10 +35,6 @@ def make_env(game, state, seed=0): obs_type=retro.Observations.IMAGE ) env = StreetFighterCustomWrapper(env) - # Create log directory - # env_log_dir = os.path.join(LOG_DIR, str(seed+200)) # +100 to avoid conflict with other log dirs when fine-tuning - # os.makedirs(env_log_dir, exist_ok=True) - # env = Monitor(env, env_log_dir) env = Monitor(env) env.seed(seed) return env