diff --git a/06_multicat.ipynb b/06_multicat.ipynb index 96eadc4..6c3beb5 100644 --- a/06_multicat.ipynb +++ b/06_multicat.ipynb @@ -976,7 +976,7 @@ "source": [ "def binary_cross_entropy(inputs, targets):\n", " inputs = inputs.sigmoid()\n", - " return -torch.where(targets==1, inputs, 1-inputs).log().mean()" + " return -torch.where(targets==1, 1-inputs, inputs).log().mean()" ] }, {