street-fighter-ai/003_frame_delta_ram_based/custom_cnn.py

25 lines
892 B
Python
Raw Normal View History

2023-03-29 17:14:39 +00:00
import gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for single-channel 2-D observations.

    Each observation is treated as a one-channel image: ``forward`` inserts a
    channel axis at dim 1, runs three ``Conv2d`` + ``ReLU`` stages, flattens,
    and projects to ``features_dim`` with a ``Linear`` + ``ReLU`` head.

    :param observation_space: The observation space. Its ``shape`` must be
        ``(height, width)``; the channel axis is added in ``forward``.
    :param features_dim: Size of the output feature vector (default 512,
        the value that was previously hard-coded).
    :raises ValueError: If ``observation_space.shape`` is not 2-D.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512):
        super().__init__(observation_space, features_dim=features_dim)

        shape = observation_space.shape
        if shape is None or len(shape) != 2:
            raise ValueError(
                "CustomCNN expects a 2-D (height, width) observation space, "
                f"got shape {shape}"
            )

        conv_layers = [
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]

        # Infer the flattened conv output size from the observation shape
        # instead of hard-coding it: the old constant 16384 (= 64 * 16 * 16)
        # tied the head to a single input resolution (e.g. 84x84 frames).
        # A zero tensor is used so no RNG is consumed and parameter
        # initialisation order (convs, then linear) is unchanged.
        with torch.no_grad():
            dummy = torch.zeros(1, 1, *shape)
            n_flatten = nn.Sequential(*conv_layers)(dummy).shape[1]

        # Keep one flat Sequential with the same layer order as before so the
        # state_dict keys (cnn.0 ... cnn.8) remain compatible with previously
        # saved models.
        self.cnn = nn.Sequential(
            *conv_layers,
            nn.Linear(n_flatten, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observations into feature vectors.

        :param observations: Expected shape ``(batch, height, width)``; a
            channel axis is inserted at dim 1 before the convolutions.
        :return: Tensor of shape ``(batch, features_dim)``.
        """
        return self.cnn(observations.unsqueeze(1))