Fix shapes

This commit is contained in:
Sylvain Gugger 2020-04-13 07:46:14 -07:00
parent c854e380a6
commit fb7e2b8af4
2 changed files with 4 additions and 4 deletions

View File

@ -1723,7 +1723,7 @@
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n", " self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n", " self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n", " self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" \n", " \n",
" def forward(self, x):\n", " def forward(self, x):\n",
" res,h = self.rnn(self.i_h(x), self.h)\n", " res,h = self.rnn(self.i_h(x), self.h)\n",
@ -2039,7 +2039,7 @@
" self.drop = nn.Dropout(p)\n", " self.drop = nn.Dropout(p)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n", " self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h_o.weight = self.i_h.weight\n", " self.h_o.weight = self.i_h.weight\n",
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n", " self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" \n", " \n",
" def forward(self, x):\n", " def forward(self, x):\n",
" raw,h = self.rnn(self.i_h(x), self.h)\n", " raw,h = self.rnn(self.i_h(x), self.h)\n",

View File

@ -1154,7 +1154,7 @@
" self.i_h = nn.Embedding(vocab_sz, n_hidden)\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
" self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n", " self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n", " self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n", " self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" \n", " \n",
" def forward(self, x):\n", " def forward(self, x):\n",
" res,h = self.rnn(self.i_h(x), self.h)\n", " res,h = self.rnn(self.i_h(x), self.h)\n",
@ -1362,7 +1362,7 @@
" self.drop = nn.Dropout(p)\n", " self.drop = nn.Dropout(p)\n",
" self.h_o = nn.Linear(n_hidden, vocab_sz)\n", " self.h_o = nn.Linear(n_hidden, vocab_sz)\n",
" self.h_o.weight = self.i_h.weight\n", " self.h_o.weight = self.i_h.weight\n",
" self.h = [torch.zeros(2, bs, n_hidden) for _ in range(n_layers)]\n", " self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(1)]\n",
" \n", " \n",
" def forward(self, x):\n", " def forward(self, x):\n",
" raw,h = self.rnn(self.i_h(x), self.h)\n", " raw,h = self.rnn(self.i_h(x), self.h)\n",