fastbook/clean/08_collab.ipynb

1686 lines
235 KiB
Plaintext
Raw Normal View History

2020-03-06 18:19:03 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Collaborative filtering deep dive"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A first look at the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.collab import *\n",
"from fastai2.tabular.all import *\n",
"path = untar_data(URLs.ML_100k)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>movie</th>\n",
" <th>rating</th>\n",
" <th>timestamp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>196</td>\n",
" <td>242</td>\n",
" <td>3</td>\n",
" <td>881250949</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>186</td>\n",
" <td>302</td>\n",
" <td>3</td>\n",
" <td>891717742</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>22</td>\n",
" <td>377</td>\n",
" <td>1</td>\n",
" <td>878887116</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>244</td>\n",
" <td>51</td>\n",
" <td>2</td>\n",
" <td>880606923</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>166</td>\n",
" <td>346</td>\n",
" <td>1</td>\n",
" <td>886397596</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user movie rating timestamp\n",
"0 196 242 3 881250949\n",
"1 186 302 3 891717742\n",
"2 22 377 1 878887116\n",
"3 244 51 2 880606923\n",
"4 166 346 1 886397596"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ratings = pd.read_csv(path/'u.data', delimiter='\\t', header=None,\n",
" names=['user','movie','rating','timestamp'])\n",
"ratings.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"last_skywalker = np.array([0.98,0.9,-0.9])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"user1 = np.array([0.9,0.8,-0.6])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.1420000000000003"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(user1*last_skywalker).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"casablanca = np.array([-0.99,-0.3,0.8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-1.611"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(user1*casablanca).sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning the latent factors"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating the DataLoaders"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>movie</th>\n",
" <th>title</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>Toy Story (1995)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>GoldenEye (1995)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>Four Rooms (1995)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>Get Shorty (1995)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>Copycat (1995)</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" movie title\n",
"0 1 Toy Story (1995)\n",
"1 2 GoldenEye (1995)\n",
"2 3 Four Rooms (1995)\n",
"3 4 Get Shorty (1995)\n",
"4 5 Copycat (1995)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"movies = pd.read_csv(path/'u.item', delimiter='|', encoding='latin-1',\n",
" usecols=(0,1), names=('movie','title'), header=None)\n",
"movies.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>movie</th>\n",
" <th>rating</th>\n",
" <th>timestamp</th>\n",
" <th>title</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>196</td>\n",
" <td>242</td>\n",
" <td>3</td>\n",
" <td>881250949</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>63</td>\n",
" <td>242</td>\n",
" <td>3</td>\n",
" <td>875747190</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>226</td>\n",
" <td>242</td>\n",
" <td>5</td>\n",
" <td>883888671</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>154</td>\n",
" <td>242</td>\n",
" <td>3</td>\n",
" <td>879138235</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>306</td>\n",
" <td>242</td>\n",
" <td>5</td>\n",
" <td>876503793</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user movie rating timestamp title\n",
"0 196 242 3 881250949 Kolya (1996)\n",
"1 63 242 3 875747190 Kolya (1996)\n",
"2 226 242 5 883888671 Kolya (1996)\n",
"3 154 242 3 879138235 Kolya (1996)\n",
"4 306 242 5 876503793 Kolya (1996)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ratings = ratings.merge(movies)\n",
"ratings.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>title</th>\n",
" <th>rating</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>207</td>\n",
" <td>Four Weddings and a Funeral (1994)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>565</td>\n",
" <td>Remains of the Day, The (1993)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>506</td>\n",
" <td>Kids (1995)</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>845</td>\n",
" <td>Chasing Amy (1997)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>798</td>\n",
" <td>Being Human (1993)</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>500</td>\n",
" <td>Down by Law (1986)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>409</td>\n",
" <td>Much Ado About Nothing (1993)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>721</td>\n",
" <td>Braveheart (1995)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>316</td>\n",
" <td>Psycho (1960)</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>883</td>\n",
" <td>Judgment Night (1993)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dls = CollabDataLoaders.from_df(ratings, item_name='title', bs=64)\n",
"dls.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_users = len(dls.classes['user'])\n",
"n_movies = len(dls.classes['title'])\n",
"n_factors = 5\n",
"\n",
"user_factors = torch.randn(n_users, n_factors)\n",
"movie_factors = torch.randn(n_movies, n_factors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4586, -0.9915, -0.4052, -0.3621, -0.5908])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"one_hot_3 = one_hot(3, n_users).float()\n",
"user_factors.t() @ one_hot_3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4586, -0.9915, -0.4052, -0.3621, -0.5908])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_factors[3]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Collaborative filtering from scratch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Example:\n",
" def __init__(self, a): self.a = a\n",
" def say(self,x): return f'Hello {self.a}, {x}.'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Hello Sylvain, nice to meet you.'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ex = Example('Sylvain')\n",
"ex.say('nice to meet you')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DotProduct(Module):\n",
"    # Scores each row of x as the dot product of a user vector and a movie vector.\n",
"    def __init__(self, n_users, n_movies, n_factors):\n",
"        self.user_factors = Embedding(n_users, n_factors)\n",
"        self.movie_factors = Embedding(n_movies, n_factors)\n",
"\n",
"    def forward(self, x):\n",
"        # Column 0 of x indexes the user table, column 1 the movie table.\n",
"        users = self.user_factors(x[:,0])\n",
"        movies = self.movie_factors(x[:,1])\n",
"        # Elementwise product summed over dim 1: one score per row of x.\n",
"        return (users * movies).sum(dim=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 2])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dls.one_batch()\n",
"x.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = DotProduct(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.326261</td>\n",
" <td>1.295701</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.091352</td>\n",
" <td>1.091475</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.961574</td>\n",
" <td>0.977690</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.829995</td>\n",
" <td>0.893122</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.781661</td>\n",
" <td>0.876511</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(5, 5e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DotProduct(Module):\n",
"    # Dot-product model whose raw score is passed through sigmoid_range.\n",
"    def __init__(self, n_users, n_movies, n_factors, y_range=(0,5.5)):\n",
"        self.user_factors = Embedding(n_users, n_factors)\n",
"        self.movie_factors = Embedding(n_movies, n_factors)\n",
"        self.y_range = y_range\n",
"\n",
"    def forward(self, x):\n",
"        # Column 0 of x indexes the user table, column 1 the movie table.\n",
"        users = self.user_factors(x[:,0])\n",
"        movies = self.movie_factors(x[:,1])\n",
"        # y_range is unpacked into the two bound arguments of sigmoid_range.\n",
"        return sigmoid_range((users * movies).sum(dim=1), *self.y_range)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.976380</td>\n",
" <td>1.001455</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.875964</td>\n",
" <td>0.919960</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.685377</td>\n",
" <td>0.870664</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.483701</td>\n",
" <td>0.874071</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.385249</td>\n",
" <td>0.878055</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = DotProduct(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DotProductBias(Module):\n",
"    # Dot-product model plus one learned bias per user and one per movie.\n",
"    def __init__(self, n_users, n_movies, n_factors, y_range=(0,5.5)):\n",
"        self.user_factors = Embedding(n_users, n_factors)\n",
"        self.user_bias = Embedding(n_users, 1)\n",
"        self.movie_factors = Embedding(n_movies, n_factors)\n",
"        self.movie_bias = Embedding(n_movies, 1)\n",
"        self.y_range = y_range\n",
"\n",
"    def forward(self, x):\n",
"        users = self.user_factors(x[:,0])\n",
"        movies = self.movie_factors(x[:,1])\n",
"        # keepdim=True keeps a trailing size-1 dim so the width-1 bias embeddings add directly.\n",
"        res = (users * movies).sum(dim=1, keepdim=True)\n",
"        res += self.user_bias(x[:,0]) + self.movie_bias(x[:,1])\n",
"        return sigmoid_range(res, *self.y_range)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.929161</td>\n",
" <td>0.936303</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.820444</td>\n",
" <td>0.861306</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.621612</td>\n",
" <td>0.865306</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.404648</td>\n",
" <td>0.886448</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.292948</td>\n",
" <td>0.892580</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = DotProductBias(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Weight decay"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAdsAAAFtCAYAAABP6cBcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9d5Qd53mn+Xw3d8453c6NQAAkAnMCs0SlUSQlWcGybI8kh7U9Y+9Zz3hm5bX3jPfIchjLiqRF5ZwoRjGDJAiQSA2gc84531z7R3U1SBEgGuiq+qrq1nMODo4Ooe97u/t2/erNQlEUXFxcXFxcXIzDI9sAFxcXFxcXp+OKrYuLi4uLi8G4Yuvi4uLi4mIwrti6uLi4uLgYjCu2Li4uLi4uBuOKrYuLi4uLi8G4Yuvi4uLi4mIwmxJbIcTTQoiIEGJ5/U+H0Ya5uLi4uLg4hUvxbD+rKEr2+p9WwyxycXFxcXFxGG4Y2cXFxcXFxWAuRWz/TggxLYR4QQhxi1EGubi4uLi4OA2xmdnIQoirgdNADPgQ8C/AHkVRen7r330a+DRAVlbW3ra2Nt0NtjzjJyGUB/m1si0xlJSS4szsGcoyyyjOKDbt3uT8PPHhEYItzYhAwLR7N8vC5BqpVIqC8izZpkhhfmIVgPyyTMmWvBklkSB6tgNfZQW+wkLT7p2NzDK2MkZrQSs+j8+0e6WwOAIr01C+C4SQbY3pHD16dFpRlJLz/bdNie2b/k9CPAL8SlGUf77Qv9m3b59y5MiRSz7b9nznPpg6C3/0mmxLDOe2H9zGNRXX8Lc3/K1pd64cOsTgJ3+X2gcfJOvqA6bdu1ke+fIpZkeXuf9vrpFtihR+8HevEMoO8I7P7ZZtyptYO3mK/ve/n+p//RdybrvNtHs//9Lnebj3YV647wWE0wXoSzdARgF87BeyLZGCEOKooij7zvffLjdnqwAO/9RcJnXXw2wvLI7JtsRw6vPq6V/oN/VOX3kFAIlxa35//SEv8WhSthnSiEeT+INe2Wacl/j6Z8ZXXm7qvf0L/dTn1TtfaNfmYPwU1N0g2xJLclGxFULkCyHuEkKEhBA+IcSHgZuAR403z4aEr1f/HnhBrh0mEM4N07fQh5lrGv0V6oMyPjZu2p2Xgj/oiq0/ZE2xTax/ZvwVFabe27fYRzgvbOqdUhh8CVDOPQNd3sBmPFs/8HlgCpgGPge8W1EUt9f2fJTvgmAu9D8v2xLDqc+rZym+xExkxrQ7PRkZePPyNrwUqxEIeolHkqa+gFiJeDRJwLKe7TgiGMRbUGDanSvxFSZXJ6nPqzftTmn0Pw/eIFSdN4qa9lw0W68oyhSw3wRbnIHHC7XXpIVnW5+rPkD6FvpMLZLyVVRseClWwx/ykkopJBMpfH5rio5RKIpCLGLdMHJifAxfeZmp4dz+xX5AjQI5noEXoHof+EOyLbEkbp+tEYRvgOlOWJqQbYmhaG/rfQt9pt7rLy8nPm5RsV0XmnQMJScTKZSUYtkwcnxsHH+5ySHk9d8Nx3u2kQUYO64++1zOiyu2RqAVCDjcuy3PKifDl2G62PoqrCy2arAoHkk/sdVeMLTvgdWIj4/jN7k4qne+F6/wUpvj7FZABl8GJaUWiLqcF1dsjaBiNwSyHS+2Qgjq8+rpXeg19V5/eQWphQVSq6um3rsZ0tmz1V4wrBhGVpJJEpOT+CrMFdu+hT5qcmrwe/2m3ms6A8+Dxw/VbsbxQrhiawReH9RcDf3OFluAhrwG88VWq0i2oHerhVDTUmyj1hXbxNQUJJOmh5F7F3ppyGsw9U4p9L8AVXshYL1hJlbBFVujCN8AU2fUaSoOpiGvgfGVcVbiK6bdqfVJxsesV5GsVeKmcxg5YMGcrfZZ8Zvo2cZTcQYXB2nId7jYRpdg9DW35eciuGJrFOEb1b/7n5Nrh8Fob+1m5m21PsmEhT3bWDQh2RLziUXUr9
mSnu36Z8Vnomc7tDhEQkk437MdfAmU5Llnnst5ccXWKCr3qHnbPmeLbX2+WmVpZijZV1YGWHOwRVrnbLUwsiU9W22ghXmerfY74Xix7XtWzdfWXC3bEkvjiq1ReP1Qe63jPduanBp8wkfvvHli6wkE8BYXW3KwhVuNbE3PNj4+hiczE09Ojml3amLr+Laf/ufUwig3X/uWuGJrJPU3Or7f1u/xU5tbK6EiudySgy3SukAqYt3Wn8TYOL6KClMHWvQu9FKeVU6m38EipPXX1rsh5Ivhiq2RpFHe1vTBFhbttfX5PSDSVGytHEaW1GPr+BDywCG1v9bN114UV2yNZGNOsrPFtj6vnqGlIeLJuGl3+sorSIyNWW4GsRBCXUaQpmFkIdZfOCxGfHzM1B7blJKif7Hf+WKrzUN2+2svivV+K5yE1wd11zm+SKohv4GkkmRgccC0O/3l5aRWV0ktLZl252YJBL1pW43sD3ott0pOicVITs+Y2mM7vjLOWmLN+fnavmeh5oA7D3kTuGJrNOEbYbYHFkdlW2IY2tu7mXlbK6/a84d8aRtGtmRx1OQkKIpbiaw3q7MwftINIW8SV2yNRisccPDKPW2jiantPxZeIp+uO23VXbZWLI4yf2m8Vp3v6IEWA4cAxS2O2iSu2BpN2RUQylfDLQ4l059JZVal69muk7Y5W4uu19MK6cxcGt+70Et+MJ/CUKFpd5pO/3Pgy1DHNLpcFFdsjcbjUTdhOL1IKr/e1IpkX0kJeDzW7LUNpbFna0Wx1QZamOjZ9i30OTuEDGotSu3V4AvKtsQWuGJrBvU3wlw/zA/JtsQwGvIa6F/oJ6WkTLlP+Hz4Skut2Wub1mFk64ltYnwMT14enkzz+l17F3qdXRy1Mg2T7W6+9hJwxdYM0qDftiGvgUgywuiyeYVgVl0i7w96N+YEpxNaNbLVUJfGm+fVzkZmmY/OO9uz1WpQ6m+Sa4eNcMXWDEq3Q0aho1uAZFQkq0vkrRdGDgTTtxo5YEWxNXmgRVoUR/U/B/4sqLxStiW2wRVbM/B41JV7Dq5IlrL9p7yCxPiE5QZbaDlbq9llNGrO1prVyGYOtEiLtp/+56H2GnUGvMumcMXWLOpvgoVBmDV3rKFZ5IfUykuzK5KVaJTk3Jxpd24Gf9ALSnqNbFRSiiVztqm1NZLz86YOtOhb6CPDl0F5lrnjIU1jaQKmzrotP5eIK7ZmoeU2HNwCVJ9Xb+r2H6sukQ9krG/+SSOxjceSoEDAYn2259p+zPVsw7lhPMKhj1et9qT+Zrl22AyHfhosSHELZJdD3zOyLTGMxrxGehd6TQufWnWJfEBbIL+WPkVSsTX1xSKQYS3PVsbS+N6FXhrzG027z3R6n4ZQHlTslm2JrXDF1iyEUL3bvmfBobm8hvwGFmOLzERmTLnPX27NwRaadxdLo8EWWvW15Txbk5fGr8ZXGV8Zd3a+tu8ZtcPCY60XK6vjiq2ZNNwMK1MweUa2JYag9RWaVSTlLSoCv99yIxs17y6d2n+0r9VqOVutWt2sUY3aZ9+xYjvXD/ODbgj5MnDF1kw28rbODCVrD5ie+R5T7hMeD/6yMst5ttp84Pha+ni28Y0wsrU828TYON6iIjyBgCn39Syon/36fIcOtOhdf3a5/bWXjCu2ZpJfCwX1ji2SKsssI9ufbZrYgjUHW5wLI6efZ2u5MLLJPbY98z34PD5qcmpMu9NU+p6F7DIoaZVtie1wxdZsGm5We9SSznsQCyFoyG/YeLs3A19FxcZWF6uQzmHkgMXCyInxMVO3/fTM9xDODeP3OLD/VFFUsa2/Sa1BcbkkXLE1m/qbILoIY8dlW2IITflN5nu2k5MoSeuEbDc827SsRraYZ2vyqMbu+W6a8ptMu89Ups7CyqSbr71MXLE1G+2D2ve0VDOMoiGvgdnILLORWVPu81WUQyJBYtqcCujN4PV58Po9GwKUDljRs00uL5NaXja1EnlkecS5YxrdfO
2WcMXWbLKKoWznuQ+uw9De6s3ybjd6bcfMW4CwGQKh9FpGEFtL4At48Hit80iJj6qfCZ9Je2y1SmTHerZ9z0BBGAr
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.linspace(-2,2,100)\n",
"a_s = [1,2,5,10,50] \n",
"ys = [a * x**2 for a in a_s]\n",
"_,ax = plt.subplots(figsize=(8,6))\n",
"for a,y in zip(a_s,ys): ax.plot(x,y, label=f'a={a}')\n",
"ax.set_ylim([0,5])\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.972090</td>\n",
" <td>0.962366</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.875591</td>\n",
" <td>0.885106</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.723798</td>\n",
" <td>0.839880</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.586002</td>\n",
" <td>0.823225</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.490980</td>\n",
" <td>0.823060</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = DotProductBias(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating our own Embedding module"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#0) []"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class T(Module):\n",
" def __init__(self): self.a = torch.ones(3)\n",
"\n",
"L(T().parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1) [Parameter containing:\n",
"tensor([1., 1., 1.], requires_grad=True)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class T(Module):\n",
" def __init__(self): self.a = nn.Parameter(torch.ones(3))\n",
"\n",
"L(T().parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1) [Parameter containing:\n",
"tensor([[-0.9595],\n",
" [-0.8490],\n",
" [ 0.8159]], requires_grad=True)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class T(Module):\n",
" def __init__(self): self.a = nn.Linear(1, 3, bias=False)\n",
"\n",
"t = T()\n",
"L(t.parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.nn.parameter.Parameter"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(t.a.weight)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def create_params(size):\n",
" return nn.Parameter(torch.zeros(*size).normal_(0, 0.01))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DotProductBias(Module):\n",
"    # Same model as before, but factors and biases are raw nn.Parameter tensors\n",
"    # from create_params, indexed directly instead of calling an Embedding layer.\n",
"    def __init__(self, n_users, n_movies, n_factors, y_range=(0,5.5)):\n",
"        self.user_factors = create_params([n_users, n_factors])\n",
"        self.user_bias = create_params([n_users])\n",
"        self.movie_factors = create_params([n_movies, n_factors])\n",
"        self.movie_bias = create_params([n_movies])\n",
"        self.y_range = y_range\n",
"\n",
"    def forward(self, x):\n",
"        users = self.user_factors[x[:,0]]\n",
"        movies = self.movie_factors[x[:,1]]\n",
"        # Biases are 1-D here, so no keepdim is needed before adding them.\n",
"        res = (users*movies).sum(dim=1)\n",
"        res += self.user_bias[x[:,0]] + self.movie_bias[x[:,1]]\n",
"        return sigmoid_range(res, *self.y_range)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.962146</td>\n",
" <td>0.936952</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.858084</td>\n",
" <td>0.884951</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.740883</td>\n",
" <td>0.838549</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.592497</td>\n",
" <td>0.823599</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.473570</td>\n",
" <td>0.824263</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = DotProductBias(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interpreting embeddings and biases"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Children of the Corn: The Gathering (1996)',\n",
" 'Lawnmower Man 2: Beyond Cyberspace (1996)',\n",
" 'Beautician and the Beast, The (1997)',\n",
" 'Crow: City of Angels, The (1996)',\n",
" 'Home Alone 3 (1997)']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# learn.model here is the DotProductBias built on create_params, so movie_bias is\n",
"# already an nn.Parameter tensor and has no .weight attribute.\n",
"movie_bias = learn.model.movie_bias.squeeze()\n",
"idxs = movie_bias.argsort()[:5]\n",
"[dls.classes['title'][i] for i in idxs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['L.A. Confidential (1997)',\n",
" 'Titanic (1997)',\n",
" 'Silence of the Lambs, The (1991)',\n",
" 'Shawshank Redemption, The (1994)',\n",
" 'Star Wars (1977)']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idxs = movie_bias.argsort(descending=True)[:5]\n",
"[dls.classes['title'][i] for i in idxs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAyQAAAKuCAYAAABQVtgOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzddXhV9R/A8fe5d9vdXfdYMdgGo3N0SwqCCCIiiCgmFiBpYiAoISqgoIgIoiJId0h3p6NjyYJ13Dq/PwZXxoIBk/mTz+t5eB7u+eY5l7HzOd84iqqqCCGEEEIIIURZ0JR1B4QQQgghhBAPLglIhBBCCCGEEGVGAhIhhBBCCCFEmZGARAghhBBCCFFmJCARQgghhBBClBmbsu7AnfDy8lIrVKhQ1t0QQgghhBD/cQcOHEhUVdW7rPvxIPi/CkgqVKjA/v37y7obQgghhBDiP05RlEtl3YcHhUzZEkIIIYQQQpQZCUiEEEIIIYQQZUYCEiGEEEIIIUSZkYBECCGEEEIIUWYkIBFCCCGEEEKUGQlIhBBCCCGEEGVGAhIhhBBCCCFEmZGARAghhBBCCFFmJCARQgghhBBClBkJSIQQQgghhBBlRgISIYQQQgghRJmRgEQIIYQQQghRZiQgEUIIIYQQQpQZCUiEEEIIIYQQZUYCEiGEEEIIIUSZkYBECCGEEEIIUWYkIBFCCCGEEEKUGQlIhBBCCCGEEGVGAhIhhBBCCCFEmZGARAghhBBCCFFmJCARQgghhBBClBkJSIQQQgghhBBlplQDEkVRPBRFWawoSqaiKJcURXmqiHw6RVG+VRQlXlGUZEVRliuKElCafRFCCCGEEEL8+5X2CMk0wAD4An2BbxRFqV5IvjeBJkAtwB9IAb4u5b4IIYQQQggh/uVsSqsiRVEcgZ5ADVVVM4DtiqIsA54GRt2SvSKwVlXV+OtlfwUml1ZfhLidmE1nOP3TPlDBYjDjGu5N/TGdAPhr1m4q92+AxlZ7T20k7LvMsS+28ND8pwEwZuSypvNMarzZkoo9awNwdv5B0s4mUO/9jvd2QkUw5RjZMWgRzab1xEZvy9n5B7m07DiZUSk0/Kwr5ZpVtOZNPh7Lia+2YcoxorXVUmvEQ7iF+2Axmtn6/G/WfOZcE1kxqXRc8QI2etsi0+xc7Nnx2iLqvN0OR3/Xf+T8hBBCCPH/r9QCEqAyYFZV9fRNx44ArQrJOwv4UlGUG6MjfYHVhVWqKMqLwIsA5cuXL8XuigdVTmImRyf9Sasf+qD3dUZVVdLOJFrTT/+wl7A+9e44ILGYLGhs/h509KjlT1ZMKjnJWdh7OJB8NAa3cB8SD0VbA5KkQ1H4tQq9p3aKc2HhEfxbh2KjtwXAs24A5VqGcGT8xnz5VFVl/zurqP9hJzzrBJB0JIaDH66lzc/90NhqaT3n79mX5347ROL+K9i52AMUmxbyRB0iZ+2h3nsd7ugchRBCCPHgKM2AxAlIveVYKuBcSN7TwGUgGjADx4DXCqtUVdWZwEyAiIgItbQ6Kx5cucmZKFottq55N82KouBa2RuAo5P+BGDby7+jKArNpvYkftdFzv9+GIvRDED111rgHREEwPqesyn/SHUSD1zBMcCVOqPbWdvR6mxwq1qOpINRBLSrnBeIPF6byNl7AVDNFpKOxlBjcF7MfmLqNpIORWMxWbBztafO2+1wKOdCVmwaWwf+SoWetUjcf4XADlXQeTnw18zdKBoF1Wyh5tDWeNULLHCul5Yep+nXPayf3av6FnpNDCnZGDNy8ayTt5TLs7Y/OQkZpEYm4FbFJ1/eK6tOUfnZhoXWc2uab9MKHPlsE6ZMAzaOdoWWEUIIIcSDrTQDkgzA5ZZjLkB6IXm/AewBTyATGEHeCEmjUuyPEIVyCfPGvZovG3rMxrNuAB61/AnqVAU7Vz213mrDxT+O0eLbXtg45N1A+zQqT0D7yiiKQsala+
x88w86LBlorS83KZNmU3sW2pZX3QASD0UT0K4ySYejCe1dl6j1p0k7n4Ql14Stow7HgLzpTGH9Iqj+WgsALi07zsnpO4j46GEADKk5OAd7UGVgYwA2PzOfWm+1xrNOAKrZginHWKDt7Ph0zDkmHMrd+mNZkM7dATtXPbHbzuHXIpS47ecxZRnJikvLF5CknIonJzGTcs0rFqijsDSNjRaXEE+Sj8Xg07jCbfshhBBCiAdPaQYkpwEbRVEqqap65vqx2sCJQvLWBt5RVTUZQFGUr4GPFEXxUlU1sZD8oox0n/YmdlpbbG1ssVgsPNusOx2qNym2zHdbF5FlzOHNtn3vqK1X5n1CXFoijnZ6AII9/Rj72Bv8cXADuSYjfRo+XGTZmJQE9lw4xmN1H7IeG/zb5wzr8AyB7nmjAksORTNhbSQxKdn4uzkw/NXmeOUYid16nnPzD9J6bl/rVKObZUan8tcHa8hJyCTWMZ3VwSdomZSJvacjCyocJ0rZS/Kn3/PnsFk42P1dfvmRLczLWUZWegaVftlH21wv7L0c8awbwKLNq1mZtA9TfSNrFiTxfteXSdt9hYuLjmLKNqKaVRYGHOPVT+exqt8XaOy0ZNeyZ/BvnzOl9wi86gVyYup2/NuE4dOkAi4hngX6nZ2Qgc7DocTXv8G4LpycvoPTP+zFvXo5nCt4FJgadnnlSQI7hqOxKTidrag0nacD2VczStwPIYQQQjxYSi0gUVU1U1GUP8gLLJ4H6gCPAk0Lyb4P6K8oymYgCxgExEgw8u80rsebhPoEERl3kRd+GkPDijVwcyhsJt69e6t9f5pXqpfvWI967YrI/bfY1ASWHNqULyCZ0nuE9e9LDkUz+o9jZF+fdhWdks3o3ZcZ16Mm3XvWZlPfuSQejMK/dViBug98sIbqr7fAr2Uob/76Gc1+DcRiyKunfpI/7w7syRNL3slX5kJiNDO2/M6PT3/E7h6/cKaCLZurRtEFyAzR8svGzYxI60DlVjXY4BrJ16vmETFLQ4vve+Po78qq9RtQVx+11mejtyXMpzy2Whv2XzxBxJstSTuXSOKBKPa/u4rQJ+sS3K1Gvj5o7Www55pue+1ucAv3oemXjwFgMZpZ2/V7nCp4WNPNuSaiN5ym2fTHC5QtPs2MVleazz6EEEII8V9S2ncJg4AfgKtAEvCKqqonFEVpAaxWVdXper5hwFfAGcAOOA48Vsp9EaUsvFwFHOz0xKRc5ff96/KNghQ1KrLi6BbWHt+Jg05P1LV4XPVOjOn2Cj7OHoU1Uahb6/5x51LWndiJomjQ2+qY2f99Jqz9kZiUBPp9P5pAd1/G9xxM92lvMqnXMEJ9gvh8zT50tlvR2+WitYB9SjWuUIEJayMZt3oEbZ0qMPvUCTJO5tAs2JP2mQbrlC1jhgEHPxfiUhM5f+UybRPCrX0LTffA3b5gcHY+IYpKvsF4uXvgXs0Xz21xLAm/AECCcza+6Y4YjiTjNSSQpoqel3d+TAObZth7OpKSmcbsvcvof6UmB71i8tXboVoTlh3ZTBXFH5dQL1xCvTBlG0k5FV8gIHEKdic3KROzwYTW7vY/6jnXR30Azszdj2edAJwC3azpsVvO4RjoVuhoTHFpGZeScanU+LbtCyGEEOLBVKoByfUpWN0LOb6NvEXvNz4nkbezlvg/sv/iCXJNBoI8yt1RuSNRkcwd+CnBnv58v20Rk9f9xPiegwvNO2n9T3y75XcAejfoRNfa+TdpW3l0K9vOHGRm/w9w0jmQmpWORtEwvOMAvto4nznPfVJovTmmTeQYw8gxhuJuTAK3jQzdbcbGbMec1lC+USU+fW4AR65EMjJtEnVf/wOtzoZmU3tS480W7Bu9kuPBSVTw88LOteC0rltV8inPqdjzxKRcxaNuAL/t3kWOxUBqdgbhfhWIcUon3cOCvpwzazctJ9uUi0ubQP7sN4/fq5yiZ7km2FtyCtRbM6ASk9fP5eTuHWReSUHRarB11l
FnVNsCebU6G7zqBZJ0KBqfRsEAnP35AOd/P4whJZvDY9ejsdPS5ud+2DrquLT0OFHrIlEtKm5VfKjzdv6RqcsrT1K
"text/plain": [
"<Figure size 864x864 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"g = ratings.groupby('title')['rating'].count()\n",
"top_movies = g.sort_values(ascending=False).index.values[:1000]\n",
"top_idxs = tensor([learn.dls.classes['title'].o2i[m] for m in top_movies])\n",
"# movie_factors is an nn.Parameter in this model, so index it directly (no .weight).\n",
"movie_w = learn.model.movie_factors[top_idxs].cpu().detach()\n",
"movie_pca = movie_w.pca(3)\n",
"fac0,fac1,fac2 = movie_pca.t()\n",
"# Plot the 50 most-rated movies on the first and third PCA components.\n",
"idxs = list(range(50))\n",
"X = fac0[idxs]\n",
"Y = fac2[idxs]\n",
"plt.figure(figsize=(12,12))\n",
"plt.scatter(X, Y)\n",
"for i, x, y in zip(top_movies[idxs], X, Y):\n",
"    plt.text(x,y,i, color=np.random.rand(3)*0.7, fontsize=11)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using fastai.collab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = collab_learner(dls, n_factors=50, y_range=(0, 5.5))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.931751</td>\n",
" <td>0.953806</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.851826</td>\n",
" <td>0.878119</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.715254</td>\n",
" <td>0.834711</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.583173</td>\n",
" <td>0.821470</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.496625</td>\n",
" <td>0.821688</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(\n",
"    5,       # number of epochs\n",
"    5e-3,    # maximum learning rate of the one-cycle schedule\n",
"    wd=0.1,  # weight decay\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"EmbeddingDotBias(\n",
" (u_weight): Embedding(944, 50)\n",
" (i_weight): Embedding(1635, 50)\n",
" (u_bias): Embedding(944, 1)\n",
" (i_bias): Embedding(1635, 1)\n",
")"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Titanic (1997)',\n",
" \"Schindler's List (1993)\",\n",
" 'Shawshank Redemption, The (1994)',\n",
" 'L.A. Confidential (1997)',\n",
" 'Silence of the Lambs, The (1991)']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Titles of the five movies with the largest learned item bias.\n",
"movie_bias = learn.model.i_bias.weight.squeeze()\n",
"titles = dls.classes['title']\n",
"idxs = movie_bias.argsort(descending=True)[:5]\n",
"[titles[i] for i in idxs]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embedding distance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Dial M for Murder (1954)'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"movie_factors = learn.model.i_weight.weight\n",
"idx = dls.classes['title'].o2i['Silence of the Lambs, The (1991)']\n",
"# Cosine similarity between every movie embedding and the query movie's embedding.\n",
"cos_sim = nn.CosineSimilarity(dim=1)\n",
"distances = cos_sim(movie_factors, movie_factors[idx][None])\n",
"# Rank by similarity, highest first; position 0 is skipped because it is\n",
"# presumably the query movie itself, so position 1 is its nearest neighbour.\n",
"idx = distances.argsort(descending=True)[1]\n",
"dls.classes['title'][idx]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bootstrapping a collaborative filtering model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deep learning for collaborative filtering"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(944, 74), (1635, 101)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embs = get_emb_sz(dls)\n",
"embs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class CollabNN(Module):\n",
"    \"Neural collaborative-filtering model: user and item embeddings are concatenated, fed to a small MLP, and squashed into `y_range`.\"\n",
"    def __init__(self, user_sz, item_sz, y_range=(0,5.5), n_act=100):\n",
"        # user_sz / item_sz are unpacked straight into Embedding; their second\n",
"        # entry is the embedding width, which sizes the first linear layer.\n",
"        self.user_factors = Embedding(*user_sz)\n",
"        self.item_factors = Embedding(*item_sz)\n",
"        n_in = user_sz[1] + item_sz[1]\n",
"        hidden = nn.Linear(n_in, n_act)\n",
"        head = nn.Linear(n_act, 1)\n",
"        self.layers = nn.Sequential(hidden, nn.ReLU(), head)\n",
"        self.y_range = y_range\n",
"\n",
"    def forward(self, x):\n",
"        # Column 0 of x is looked up in the user table, column 1 in the item table.\n",
"        user_vecs = self.user_factors(x[:,0])\n",
"        item_vecs = self.item_factors(x[:,1])\n",
"        acts = self.layers(torch.cat((user_vecs, item_vecs), dim=1))\n",
"        return sigmoid_range(acts, *self.y_range)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# embs holds one size pair for the users and one for the items, in that order.\n",
"user_sz, item_sz = embs\n",
"model = CollabNN(user_sz, item_sz)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.940104</td>\n",
" <td>0.959786</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.893943</td>\n",
" <td>0.905222</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.865591</td>\n",
" <td>0.875238</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.800177</td>\n",
" <td>0.867468</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.760255</td>\n",
" <td>0.867455</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Train the hand-written model with a mean-squared-error loss.\n",
"mse = MSELossFlat()\n",
"learn = Learner(dls, model, loss_func=mse)\n",
"learn.fit_one_cycle(\n",
"    5,        # number of epochs\n",
"    5e-3,     # maximum learning rate of the one-cycle schedule\n",
"    wd=0.01,  # weight decay\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.002747</td>\n",
" <td>0.972392</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.926903</td>\n",
" <td>0.922348</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.877160</td>\n",
" <td>0.893401</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.838334</td>\n",
" <td>0.865040</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.781666</td>\n",
" <td>0.864936</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Same task with fastai's built-in neural-net variant (use_nn=True).\n",
"learn = collab_learner(\n",
"    dls,\n",
"    use_nn=True,\n",
"    y_range=(0, 5.5),\n",
"    layers=[100,50],   # sizes of the hidden layers\n",
")\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@delegates(TabularModel)\n",
"class EmbeddingNN(TabularModel):\n",
"    \"A `TabularModel` restricted to embedding inputs: no continuous columns and a single output.\"\n",
"    def __init__(self, emb_szs, layers, **kwargs):\n",
"        # Pin the two arguments this subclass fixes (n_cont=0, out_sz=1);\n",
"        # everything else in kwargs is forwarded to TabularModel unchanged.\n",
"        fixed = dict(n_cont=0, out_sz=1)\n",
"        super().__init__(emb_szs, layers=layers, **fixed, **kwargs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sidebar: kwargs and delegates"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End sidebar"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research\n",
"\n",
"1. Take a look at all the differences between the `Embedding` version of `DotProductBias` and the `create_params` version, and try to understand why each of those changes is required. If you're not sure, try reverting each change, to see what happens. (NB: even the type of brackets used in `forward` has changed!)\n",
"1. Find three other areas where collaborative filtering is being used, and find out what the pros and cons of this approach are in those areas.\n",
"1. Complete this notebook using the full MovieLens dataset, and compare your results to online benchmarks. See if you can improve your accuracy. Look on the book website and forum for ideas. Note that there are more columns in the full dataset--see if you can use those too (the next chapter might give you ideas).\n",
"1. Create a model for MovieLens that works with cross-entropy loss, and compare it to the model in this chapter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}