# Mirror of https://github.com/linyiLYi/street-fighter-ai.git
# (synced 2025-04-04 23:20:43 +00:00; 22 lines, 873 B, Python)
import gym
|
|
import torch
|
|
import torch.nn as nn
|
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
|
from torchvision.models import mobilenet_v3_small
|
|
|
|
# Custom MobileNetV3 Feature Extractor
|
|
class MobileNetV3Extractor(BaseFeaturesExtractor):
    """Feature extractor built on torchvision's MobileNetV3-Small backbone.

    Observations are passed through the convolutional part of the backbone
    (``mobilenet.features``), globally average-pooled, flattened, and then
    projected by one linear layer to a ``features_dim``-dimensional vector.

    Assumes observations arrive as channel-first image batches
    ``(batch, channels, height, width)`` with the channel count the backbone's
    first convolution expects (3 for the stock torchvision model) -- TODO
    confirm against the environment's observation wrappers.

    NOTE(review): no ImageNet mean/std normalization is applied here, even
    though the backbone may be ImageNet-pretrained; confirm this is intended.

    Args:
        observation_space: Observation space supplied by stable-baselines3.
        features_dim: Size of the output feature vector. Defaults to 256,
            the value that was previously hard-coded.
        pretrained: Whether to load pretrained backbone weights (may trigger
            a download on first use). Defaults to True, the previously
            hard-coded behaviour.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        features_dim: int = 256,
        pretrained: bool = True,
    ) -> None:
        super().__init__(observation_space, features_dim=features_dim)
        # ``pretrained=`` (rather than the newer ``weights=`` API) is kept so
        # older torchvision releases still work; newer releases deprecate it.
        # The backbone's classifier head is never used in forward(), but it is
        # kept so the state_dict keys -- and thus saved checkpoints -- are
        # unchanged.
        self.mobilenet = mobilenet_v3_small(pretrained=pretrained)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Channel width of the backbone's final feature map, read from the
        # model instead of hard-coding 576 (the MobileNetV3-Small value).
        backbone_channels = self.mobilenet.classifier[0].in_features
        self.fc = nn.Linear(backbone_channels, self.features_dim)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to a ``(batch, features_dim)`` tensor."""
        x = self.mobilenet.features(observations)
        x = self.adaptive_pool(x)  # (batch, C, 1, 1)
        x = torch.flatten(x, 1)  # (batch, C)
        return self.fc(x)
|