Mirror of https://github.com/linyiLYi/street-fighter-ai.git (synced 2025-04-04 15:10:43 +00:00).
Python source file: 24 lines, 942 B.
import gym
|
|
import torch
|
|
import torch.nn as nn
|
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
|
from torchvision.models import mobilenet_v3_small
|
|
# Custom feature extractor (CNN)
|
|
class CustomCNN(BaseFeaturesExtractor):
    """Small convolutional feature extractor for image observations.

    Three unpadded convolutions (each followed by ReLU), a flatten, and a
    linear projection to ``features_dim`` features with a final ReLU.

    Observations are treated as channels-last: ``forward`` moves axis 3 to
    axis 1 before the convolutions, and the input channel count is taken
    from the last axis of ``observation_space.shape``.

    :param observation_space: Observation space of the environment. Its
        ``shape`` is read as ``(height, width, channels)``.
    :param features_dim: Number of output features (default 512, the value
        previously hard-coded).
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 512):
        super().__init__(observation_space, features_dim=features_dim)

        # Channels-last layout (see the permute in ``forward``): the channel
        # count is the last axis of the observation shape.
        n_input_channels = observation_space.shape[-1]

        conv_layers = [
            nn.Conv2d(n_input_channels, 32, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        ]

        # Derive the flattened size with a dry run instead of hard-coding it
        # (the old constant 16384 only matched one input resolution). A zero
        # tensor is used rather than ``observation_space.sample()`` so the
        # space's RNG is not advanced.
        with torch.no_grad():
            dummy = torch.zeros((1, *observation_space.shape))
            n_flatten = nn.Sequential(*conv_layers)(dummy.permute(0, 3, 1, 2)).shape[1]

        # Same layer order/indices as before (Linear at index 7), so saved
        # state_dicts with keys like ``cnn.7.weight`` still load.
        self.cnn = nn.Sequential(
            *conv_layers,
            nn.Linear(n_flatten, self.features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observations into feature vectors.

        :param observations: Rank-4 tensor, presumably
            ``(batch, height, width, channels)``. It is permuted to
            channels-first for ``nn.Conv2d``. NOTE(review): if observations
            are already transposed to channels-first upstream, this permute
            would be wrong; confirm against the env wrappers.
        :return: Tensor of shape ``(batch, features_dim)``.
        """
        # Swap the channel dimension: (N, H, W, C) -> (N, C, H, W).
        return self.cnn(observations.permute(0, 3, 1, 2))
|
|
|