street-fighter-ai/mobilenet_extractor.py

22 lines
873 B
Python
Raw Normal View History

2023-03-29 17:14:39 +00:00
import gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from torchvision.models import mobilenet_v3_small
# Custom MobileNetV3 Feature Extractor
class MobileNetV3Extractor(BaseFeaturesExtractor):
    """Feature extractor built on an ImageNet-pretrained MobileNetV3-Small trunk.

    Observations are passed through the MobileNetV3-Small convolutional trunk
    (``mobilenet.features``). The result is global-average-pooled to one vector
    per sample and projected to ``features_dim`` by a linear layer.

    Only ``mobilenet.features`` is used in ``forward``. The backbone's own
    pooling/classifier head stays registered, so saved ``state_dict`` keys
    are unchanged.

    NOTE(review): the pretrained stem presumably expects channel-first,
    3-channel float input ``(batch, 3, H, W)``. Nothing here permutes the
    layout or applies ImageNet mean/std normalization, so confirm the
    env/policy supplies observations in that form.

    :param observation_space: the environment's observation space, forwarded
        unchanged to ``BaseFeaturesExtractor``.
    :param features_dim: size of the output feature vector (defaults to 256,
        the value this class always used).
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 256):
        super().__init__(observation_space, features_dim=features_dim)
        self.mobilenet = self._load_pretrained_backbone()
        # Collapse whatever spatial size the trunk produces down to 1x1.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        # The trunk's output channel count (576 for MobileNetV3-Small) is
        # read from the model rather than hard-coded, so the projection
        # always matches the backbone actually loaded.
        trunk_channels = self._trunk_out_channels(self.mobilenet.features)
        self.fc = nn.Linear(trunk_channels, self.features_dim)

    @staticmethod
    def _load_pretrained_backbone() -> nn.Module:
        """Return MobileNetV3-Small with ImageNet weights on any torchvision version."""
        try:
            from torchvision.models import MobileNet_V3_Small_Weights
        except ImportError:  # torchvision < 0.13 has no weights enums
            return mobilenet_v3_small(pretrained=True)
        # IMAGENET1K_V1 is what the deprecated ``pretrained=True`` resolves to,
        # so the loaded weights are identical while avoiding the deprecation.
        return mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1)

    @staticmethod
    def _trunk_out_channels(trunk: nn.Module) -> int:
        """Return ``out_channels`` of the last ``nn.Conv2d`` registered in *trunk*."""
        convs = [m for m in trunk.modules() if isinstance(m, nn.Conv2d)]
        if not convs:
            raise ValueError("backbone trunk contains no Conv2d layer")
        return convs[-1].out_channels

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map a batch of observations to ``(batch, features_dim)`` features."""
        x = self.mobilenet.features(observations)
        x = self.adaptive_pool(x)  # (batch, C, 1, 1)
        x = torch.flatten(x, 1)  # (batch, C)
        return self.fc(x)