Mirror of https://github.com/linyiLYi/street-fighter-ai.git (synced 2025-04-04 15:10:43 +00:00)
updated reward function (added bonus for hitting and penalty for getting hit)
This commit is contained in:
Parent: 7690765a4d
Commit: 41650747b0
Two changed binary files are not shown.
In the CustomCNN feature extractor, the first two conv layers switch to 5x5 kernels with stride 2, the flattened feature size grows from 3136 to 16384, and forward() now permutes channel-last observations to channel-first before the conv stack:

```diff
@@ -8,16 +8,16 @@ class CustomCNN(BaseFeaturesExtractor):
     def __init__(self, observation_space: gym.Space):
         super(CustomCNN, self).__init__(observation_space, features_dim=512)
         self.cnn = nn.Sequential(
-            nn.Conv2d(1, 32, kernel_size=8, stride=4, padding=0),
+            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=0),
             nn.ReLU(),
-            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
+            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
             nn.ReLU(),
             nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
             nn.ReLU(),
             nn.Flatten(),
-            nn.Linear(3136, self.features_dim),
+            nn.Linear(16384, self.features_dim),
             nn.ReLU()
         )
 
     def forward(self, observations: torch.Tensor) -> torch.Tensor:
-        return self.cnn(observations)
+        return self.cnn(observations.permute(0, 3, 1, 2)) # Swap the channel dimension
```
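The new `nn.Linear(16384, ...)` follows from conv arithmetic on the 84x84 input: the old 8/4, 4/2, 3/1 kernel/stride stack flattens to 64*7*7 = 3136, while the new 5/2, 5/2, 3/1 stack flattens to 64*16*16 = 16384. A quick shape check, written as a sketch rather than taken from the commit:

```python
import torch
import torch.nn as nn

# Mirror of the new conv stack; comments track the spatial size of an 84x84 frame.
cnn = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=0),   # 84 -> (84-5)//2+1 = 40
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),  # 40 -> (40-5)//2+1 = 18
    nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),  # 18 -> 18-3+1 = 16
    nn.ReLU(),
    nn.Flatten(),
)
x = torch.zeros(1, 1, 84, 84)  # batch of one single-channel frame, already CHW
print(cnn(x).shape)  # torch.Size([1, 16384]) == 64 * 16 * 16
```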
In StreetFighterCustomWrapper, one blank line is dropped from __init__ (the hunk shrinks from 7 lines to 6):

```diff
@@ -9,7 +9,6 @@ class StreetFighterCustomWrapper(gym.Wrapper):
         self.win_template = win_template
         self.lose_template = lose_template
         self.threshold = threshold
-
         self.game_screen_gray = None
 
         self.prev_player_health = 1.0
```
The observation space switches from raw uint8 pixels to normalized float32, and _preprocess_observation scales the resized frame to match:

```diff
@@ -17,7 +16,7 @@ class StreetFighterCustomWrapper(gym.Wrapper):
 
         # Update observation space to single-channel grayscale image
         self.observation_space = gym.spaces.Box(
-            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
+            low=0.0, high=1.0, shape=(84, 84, 1), dtype=np.float32
         )
 
     def _preprocess_observation(self, observation):
@@ -26,7 +25,7 @@ class StreetFighterCustomWrapper(gym.Wrapper):
         # print("self.game_screen_gray size: ", self.game_screen_gray.shape)
         # Print the size of the observation
         # print("Observation size: ", observation.shape)
-        resized_image = cv2.resize(self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA)
+        resized_image = cv2.resize(self.game_screen_gray, (84, 84), interpolation=cv2.INTER_AREA) / 255.0
         return np.expand_dims(resized_image, axis=-1)
 
     def _check_game_over(self):
```
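One caveat worth noting: dividing a uint8 array by 255.0 produces float64, while the declared Box dtype is float32, so an explicit cast keeps the two consistent. A minimal check, my sketch with a random stand-in frame:

```python
import cv2
import numpy as np

# Stand-in for the real grayscale game screen (shape chosen arbitrarily).
frame = np.random.randint(0, 256, size=(224, 320), dtype=np.uint8)

resized = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA) / 255.0
obs = np.expand_dims(resized, axis=-1).astype(np.float32)  # cast: division yields float64

assert obs.shape == (84, 84, 1)
assert obs.dtype == np.float32
assert 0.0 <= obs.min() <= obs.max() <= 1.0
```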
The reward itself changes from the absolute health gap to the per-step change in health, scaled by 100, with a flat +/-10 bonus or penalty depending on who lost more health this step; reset() reinitializes the tracked health values:

```diff
@@ -46,11 +45,26 @@ class StreetFighterCustomWrapper(gym.Wrapper):
         player_health = np.sum(player_health_area > 129) / player_health_area.size
         opponent_health = np.sum(oppoent_health_area > 129) / oppoent_health_area.size
 
-        reward = player_health - opponent_health
+        player_health_diff = self.prev_player_health - player_health
+        opponent_health_diff = self.prev_opponent_health - opponent_health
+
+        reward = (opponent_health_diff - player_health_diff) * 100
+
+        # Add bonus for successful attacks or penalize for taking damage
+        if opponent_health_diff > player_health_diff:
+            reward += 10 # Bonus for successful attacks
+        elif opponent_health_diff < player_health_diff:
+            reward -= 10 # Penalty for taking damage
+
+        self.prev_player_health = player_health
+        self.prev_opponent_health = opponent_health
+
         return reward
 
     def reset(self):
         observation = self.env.reset()
+        self.prev_player_health = 1.0
+        self.prev_opponent_health = 1.0
         return self._preprocess_observation(observation)
 
     def step(self, action):
```
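A worked example of the new shaping, with health fractions made up for illustration:

```python
# One step where the opponent lost ~5% health and the player lost ~2%.
prev_player_health, prev_opponent_health = 1.00, 0.95
player_health, opponent_health = 0.98, 0.90

player_health_diff = prev_player_health - player_health        # 0.02
opponent_health_diff = prev_opponent_health - opponent_health  # 0.05

reward = (opponent_health_diff - player_health_diff) * 100     # ~3.0
if opponent_health_diff > player_health_diff:
    reward += 10  # dealt more damage than was taken this step
elif opponent_health_diff < player_health_diff:
    reward -= 10  # took more damage than was dealt

print(round(reward, 2))  # 13.0
```

Note that whenever the net health swing in a step is under 10%, the flat +/-10 term dominates the scaled term, so the shaping mostly rewards winning each exchange rather than its margin.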
In the test script, the model is loaded from the new checkpoint name, and the playback loop now runs indefinitely instead of stopping after one episode (hence env.close() being commented out):

```diff
@@ -45,12 +45,12 @@ model = PPO(
     policy_kwargs=policy_kwargs,
     verbose=1
 )
-model.load("ppo_sf2_cnn")
+model.load("ppo_sf2_cnn_new")
 
 obs = env.reset()
 done = False
 
-while not done:
+while True:
     timestamp = time.time()
     action, _ = model.predict(obs)
     obs, rewards, done, info = env.step(action)
@@ -59,4 +59,4 @@ while not done:
     if render_time < 0.0111:
         time.sleep(0.0111 - render_time) # Add a delay for 90 FPS
 
-env.close()
+# env.close()
```
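The 0.0111 s threshold is 1/90 s, so the sleep caps playback at roughly 90 FPS. A sketch of the same pacing with the constant named explicitly (TARGET_FPS is my addition, and model, env, and obs stand in for the variables set up earlier in the script):

```python
import time

TARGET_FPS = 90
FRAME_BUDGET = 1.0 / TARGET_FPS  # ~0.0111 s per frame

while True:
    start = time.time()
    action, _ = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    env.render()
    elapsed = time.time() - start
    if elapsed < FRAME_BUDGET:
        time.sleep(FRAME_BUDGET - elapsed)  # cap playback at ~90 FPS
    if done:
        obs = env.reset()  # assumption: restart episodes so the loop keeps playing
```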
In the training script, the PPO constructor now spells out its hyperparameters, training runs for 500k steps instead of 1k, and the model is saved under the new name:

```diff
@@ -45,11 +45,23 @@ def main():
         env,
         device="cuda",
         policy_kwargs=policy_kwargs,
-        verbose=1
+        verbose=1,
+        n_steps=2048,
+        batch_size=64,
+        n_epochs=10,
+        learning_rate=0.0003,
+        ent_coef=0.01,
+        clip_range=0.2,
+        clip_range_vf=None,
+        gamma=0.99,
+        gae_lambda=0.95,
+        max_grad_norm=0.5,
+        use_sde=False,
+        sde_sample_freq=-1
     )
-    model.learn(total_timesteps=int(1000))
+    model.learn(total_timesteps=int(500000))
 
-    model.save("ppo_sf2_cnn")
+    model.save("ppo_sf2_cnn_new")
 
 if __name__ == "__main__":
     main()
```
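Unless stable-baselines3 has changed its defaults, every value spelled out here except ent_coef=0.01 (library default 0.0) matches SB3's PPO defaults, so the expanded argument list mostly pins down existing behavior; the substantive changes are the entropy bonus, the 500x longer run, and the new save name. An equivalent, more compact constructor under that assumption ("CnnPolicy", env, and policy_kwargs stand in for whatever the script actually passes):

```python
from stable_baselines3 import PPO

# Sketch: rely on SB3 defaults and set only the non-default entropy coefficient.
model = PPO(
    "CnnPolicy",            # assumption: a CNN policy using the custom extractor
    env,
    device="cuda",
    policy_kwargs=policy_kwargs,
    verbose=1,
    ent_coef=0.01,          # the only explicitly non-default hyperparameter
)
model.learn(total_timesteps=500_000)
model.save("ppo_sf2_cnn_new")
```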