street-fighter-ai/custom_cnn.py
2023-03-30 01:14:39 +08:00

24 lines
942 B
Python

import gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from torchvision.models import mobilenet_v3_small
# Custom feature extractor (CNN)
class CustomCNN(BaseFeaturesExtractor):
    """Convolutional feature extractor for Stable-Baselines3 policies.

    Three Conv2d + ReLU stages, a flatten, then a fully connected projection
    (Linear + ReLU) to ``features_dim`` features.

    ``forward`` permutes observations with ``permute(0, 3, 1, 2)``, so inputs
    are expected channels-last, i.e. ``(batch, height, width, channels)``.
    ``observation_space.shape`` is read the same way, as
    ``(height, width, channels)``.

    NOTE(review): if the training setup transposes image observations to
    channels-first before they reach this extractor, the permute in
    ``forward`` would be wrong — confirm the observation layout.

    :param observation_space: Observation space of the environment. Its shape
        determines the first convolution's input channels and the input width
        of the linear layer.
    :param features_dim: Number of output features (defaults to 512, the value
        that used to be hard-coded).
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512):
        super().__init__(observation_space, features_dim=features_dim)

        # Channels-last layout, consistent with the permute in forward().
        n_input_channels = observation_space.shape[-1]

        conv_layers = [
            nn.Conv2d(n_input_channels, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]

        # Infer the flattened conv output size from a dummy pass instead of
        # hard-coding it (16384 only matches inputs such as 84x84). Zeros are
        # used so no random numbers are consumed here.
        with torch.no_grad():
            dummy = torch.zeros(1, *observation_space.shape).permute(0, 3, 1, 2)
            n_flatten = nn.Sequential(*conv_layers)(dummy).shape[1]

        # One Sequential with the original layer order keeps state_dict keys
        # (cnn.0 ... cnn.8) identical to earlier checkpoints.
        self.cnn = nn.Sequential(
            *conv_layers,
            nn.Linear(n_flatten, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Return features for a batch of channels-last observations.

        :param observations: Tensor laid out as (batch, height, width, channels);
            it is permuted to (batch, channels, height, width) for Conv2d.
        :return: Tensor of shape (batch, features_dim).
        """
        return self.cnn(observations.permute(0, 3, 1, 2))  # Swap the channel dimension