# Mirror of https://github.com/linyiLYi/street-fighter-ai.git
# Synced 2025-04-05 07:30:42 +00:00 (35 lines, 1.1 KiB, Python)
import torch
|
|
import torch.nn as nn
|
|
|
|
class CNNEncoder(nn.Module):
    """Convolutional encoder that maps a 16-channel image stack to a feature vector.

    The network is three Conv2d + ReLU stages followed by one linear
    projection. The projection expects exactly 16384 flattened activations
    (64 channels x 16 x 16), which is what the conv stack produces for
    e.g. 84x84 spatial inputs.

    Args:
        features_dim: Width of the output feature vector. Defaults to 512.
    """

    # Number of activations entering the linear layer after flattening.
    _FLATTENED_DIM = 16384

    def __init__(self, features_dim=512):
        super().__init__()
        self.conv1 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=2)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.relu3 = nn.ReLU()
        # Honour the requested output width instead of a hard-coded 512.
        self.fc = nn.Linear(self._FLATTENED_DIM, features_dim)

    def forward(self, x):
        """Encode a batch of inputs.

        Args:
            x: Tensor of shape (batch, 16, H, W) whose conv output flattens
                to 16384 values per sample.

        Returns:
            Tensor of shape (batch, features_dim).
        """
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.relu3(self.conv3(x))
        # Collapse every non-batch dimension; unlike ``view`` this does not
        # require the tensor to be contiguous.
        x = x.flatten(1)
        return self.fc(x)
|
|
|
|
class CNNLSTM(nn.Module):
    """CNN feature encoder followed by a single-layer LSTM.

    Each call to ``forward`` processes one time step: the encoded batch is
    given a leading sequence dimension of length 1 before entering the LSTM.

    Args:
        features_dim: Width requested from the encoder and used as the LSTM
            hidden size. Defaults to 512.
    """

    def __init__(self, features_dim=512):
        super().__init__()
        self.encoder = CNNEncoder(features_dim)
        # Take the LSTM input width from what the encoder actually emits so
        # the two layers can never disagree.
        self.lstm = nn.LSTM(self.encoder.fc.out_features, features_dim)

    def forward(self, x, hidden):
        """Run one recurrent step.

        Args:
            x: Batch of inputs accepted by the encoder.
            hidden: ``(h, c)`` tuple as returned by ``init_hidden`` or by a
                previous call to ``forward``.

        Returns:
            Tuple ``(output, hidden)`` where ``output`` has the length-1
            sequence dimension removed and ``hidden`` is the updated
            ``(h, c)`` state.
        """
        features = self.encoder(x)
        # nn.LSTM defaults to (seq_len, batch, features); add seq_len == 1.
        output, hidden = self.lstm(features.unsqueeze(0), hidden)
        return output.squeeze(0), hidden

    def init_hidden(self, batch_size):
        """Return a zeroed ``(h, c)`` state for ``batch_size`` sequences.

        The tensors are created with the same device and dtype as the LSTM
        weights, so the state stays usable after ``.to(device)``,
        ``.cuda()`` or ``.double()``.
        """
        reference = self.lstm.weight_hh_l0
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        return reference.new_zeros(shape), reference.new_zeros(shape)