mirror of
https://github.com/linyiLYi/street-fighter-ai.git
synced 2025-04-04 23:20:43 +00:00
19 lines
768 B
Python
19 lines
768 B
Python
|
import gym
|
||
|
import torch
|
||
|
import torchvision
|
||
|
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
||
|
|
||
|
# Custom MobileNetV3 Feature Extractor
class MobileNetV3Extractor(BaseFeaturesExtractor):
    """Stable-Baselines3 features extractor built on an ImageNet-pretrained MobileNetV3-Small.

    The torchvision classification head is dropped. The remaining convolutional
    trunk plus global average pooling produces one 576-dimensional feature
    vector per observation.

    NOTE(review): the pretrained trunk's first convolution expects 3 input
    channels, so observations presumably must be 3-channel, channels-first
    image tensors. Confirm this against the environment's observation space.
    No ImageNet mean/std normalization is applied here.
    """

    # Channel count of MobileNetV3-Small's final feature map; this is the
    # features_dim reported to Stable-Baselines3.
    _FEATURES_DIM = 576

    def __init__(self, observation_space: gym.Space):
        """Build the pretrained trunk.

        Args:
            observation_space: space of the observations fed to ``forward``;
                forwarded unchanged to ``BaseFeaturesExtractor``.
        """
        super().__init__(observation_space, features_dim=self._FEATURES_DIM)
        backbone = self._load_pretrained_backbone()
        # Per torchvision's MobileNetV3 definition, children() is
        # (features, avgpool, classifier). Dropping the last child removes the
        # classifier head but keeps the trunk and its pooling. The Sequential
        # wrapper (and therefore the "mobilenet.0.*" / "mobilenet.1" state_dict
        # keys) is kept as-is so previously saved checkpoints still load.
        self.mobilenet = torch.nn.Sequential(*list(backbone.children())[:-1])
        # The trunk already ends in a 1x1 average pool, so this is effectively
        # a no-op. It has no parameters and is kept only so code referencing
        # `adaptive_pool` keeps working.
        self.adaptive_pool = torch.nn.AdaptiveAvgPool2d((1, 1))

    @staticmethod
    def _load_pretrained_backbone() -> torch.nn.Module:
        """Return MobileNetV3-Small with ImageNet weights on old and new torchvision."""
        weights_enum = getattr(torchvision.models, "MobileNet_V3_Small_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            return torchvision.models.mobilenet_v3_small(pretrained=True)
        # torchvision >= 0.13 deprecates `pretrained=` (it emits a warning).
        # IMAGENET1K_V1 is the checkpoint that `pretrained=True` resolves to,
        # so the loaded weights are unchanged.
        return torchvision.models.mobilenet_v3_small(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observations into ``(batch, 576)`` feature vectors."""
        x = self.mobilenet(observations)
        x = self.adaptive_pool(x)
        # Collapse the trailing (C, 1, 1) dims into a flat per-sample vector.
        return torch.flatten(x, 1)
|