street-fighter-ai/mobilenet_extractor_no_condensing.py
2023-03-30 01:14:39 +08:00

19 lines
768 B
Python

import gym
import torch
import torchvision
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# Custom MobileNetV3 Feature Extractor
class MobileNetV3Extractor(BaseFeaturesExtractor):
    """Stable-Baselines3 feature extractor backed by a pretrained MobileNetV3-Small.

    The torchvision MobileNetV3-Small model (ImageNet weights) is used with its
    classification head removed, so each observation is mapped to a flat
    576-dimensional feature vector.

    :param observation_space: The policy's observation space. Observations are
        passed to MobileNetV3 unchanged, so they must be channel-first images
        with 3 channels.
    :raises ValueError: If ``observation_space`` has a 3-D shape whose leading
        (channel) dimension is not 3 -- the pretrained stem only accepts
        3-channel input, so such a space could never run through ``forward``.
    """

    def __init__(self, observation_space: gym.Space):
        shape = getattr(observation_space, "shape", None)
        if shape is not None and len(shape) == 3 and shape[0] != 3:
            raise ValueError(
                "MobileNetV3Extractor expects channel-first observations with 3 "
                f"channels, got observation shape {tuple(shape)}"
            )

        # 576 = channel count of MobileNetV3-Small's final feature map.
        super().__init__(observation_space, features_dim=576)

        weights_enum = getattr(torchvision.models, "MobileNet_V3_Small_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            backbone = torchvision.models.mobilenet_v3_small(pretrained=True)
        else:
            # Same checkpoint that ``pretrained=True`` resolves to, without the
            # deprecation warning torchvision >= 0.13 emits for that flag.
            backbone = torchvision.models.mobilenet_v3_small(
                weights=weights_enum.IMAGENET1K_V1
            )

        # Keep every top-level child except the last (the classifier head).
        # The Sequential wrapper is kept exactly as before so that the
        # state_dict layout -- and therefore saved checkpoints -- is unchanged.
        self.mobilenet = torch.nn.Sequential(*list(backbone.children())[:-1])
        # The retained children already end in MobileNetV3's own global average
        # pool, so this extra pool is effectively a no-op; kept for compatibility.
        self.adaptive_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        # NOTE(review): inputs are not ImageNet mean/std-normalized here, which
        # the pretrained weights were trained with; presumably SB3 only rescales
        # uint8 images to [0, 1] before this point -- confirm this is intended.

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observations into flat feature vectors.

        :param observations: Batch of images; presumably ``(B, 3, H, W)`` as
            prepared by SB3's preprocessing -- TODO confirm.
        :return: Tensor of shape ``(B, 576)``.
        """
        x = self.mobilenet(observations)
        x = self.adaptive_pool(x)
        # Collapse the (576, 1, 1) trailing dims into one feature axis.
        return torch.flatten(x, 1)