{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Collaborative filtering deep dive"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A first look at the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.collab import *\n",
"from fastai2.tabular.all import *\n",
"path = untar_data(URLs.ML_100k)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>movie</th>\n",
" <th>rating</th>\n",
" <th>timestamp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>196</td>\n",
" <td>242</td>\n",
" <td>3</td>\n",
" <td>881250949</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>186</td>\n",
" <td>302</td>\n",
" <td>3</td>\n",
" <td>891717742</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>22</td>\n",
" <td>377</td>\n",
" <td>1</td>\n",
" <td>878887116</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>244</td>\n",
" <td>51</td>\n",
" <td>2</td>\n",
" <td>880606923</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>166</td>\n",
" <td>346</td>\n",
" <td>1</td>\n",
" <td>886397596</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user movie rating timestamp\n",
"0 196 242 3 881250949\n",
"1 186 302 3 891717742\n",
"2 22 377 1 878887116\n",
"3 244 51 2 880606923\n",
"4 166 346 1 886397596"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ratings = pd.read_csv(path/'u.data', delimiter='\\t', header=None,\n",
" names=['user','movie','rating','timestamp'])\n",
"ratings.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"last_skywalker = np.array([0.98,0.9,-0.9])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"user1 = np.array([0.9,0.8,-0.6])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.1420000000000003"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(user1*last_skywalker).sum()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"casablanca = np.array([-0.99,-0.3,0.8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-1.611"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(user1*casablanca).sum()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning the latent factors"
]
},
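{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a minimal sketch added for illustration. It is not code from the original chapter. It learns latent factors by plain gradient descent on a tiny synthetic ratings matrix, using only PyTorch. Each prediction is the dot product of a user's factor vector and a movie's factor vector, and both sets of factors are updated to reduce the mean squared error. The 6x4 toy matrix, the 2 factors, the learning rate of 0.05 and the 2,000 steps are arbitrary assumptions. The variable names end in `_toy` or `_f` so they do not overwrite `ratings` or `user_factors`, which are used later in the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Private generator so this cell leaves the global random state untouched\n",
"gen_toy = torch.Generator().manual_seed(0)\n",
"# Hypothetical toy data: 6 users x 4 movies, built from 2 hidden factors\n",
"ratings_toy = torch.randn(6, 2, generator=gen_toy) @ torch.randn(4, 2, generator=gen_toy).t()\n",
"\n",
"# Randomly initialised factors that will be learned\n",
"user_f = torch.randn(6, 2, generator=gen_toy).requires_grad_()\n",
"movie_f = torch.randn(4, 2, generator=gen_toy).requires_grad_()\n",
"opt = torch.optim.SGD([user_f, movie_f], lr=0.05)\n",
"\n",
"for _ in range(2000):\n",
"    # Predicted rating = dot product of a user row and a movie row\n",
"    preds = user_f @ movie_f.t()\n",
"    loss = ((preds - ratings_toy)**2).mean()\n",
"    opt.zero_grad()\n",
"    loss.backward()\n",
"    opt.step()\n",
"\n",
"print(f'MSE after training: {loss.item():.4f}')"
]
},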
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating the DataLoaders"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>movie</th>\n",
" <th>title</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>Toy Story (1995)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>GoldenEye (1995)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>Four Rooms (1995)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>Get Shorty (1995)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>Copycat (1995)</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" movie title\n",
"0 1 Toy Story (1995)\n",
"1 2 GoldenEye (1995)\n",
"2 3 Four Rooms (1995)\n",
"3 4 Get Shorty (1995)\n",
"4 5 Copycat (1995)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"movies = pd.read_csv(path/'u.item', delimiter='|', encoding='latin-1',\n",
" usecols=(0,1), names=('movie','title'), header=None)\n",
"movies.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>movie</th>\n",
" <th>rating</th>\n",
" <th>timestamp</th>\n",
" <th>title</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>196</td>\n",
" <td>242</td>\n",
" <td>3</td>\n",
" <td>881250949</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>63</td>\n",
" <td>242</td>\n",
" <td>3</td>\n",
" <td>875747190</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>226</td>\n",
" <td>242</td>\n",
" <td>5</td>\n",
" <td>883888671</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>154</td>\n",
" <td>242</td>\n",
" <td>3</td>\n",
" <td>879138235</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>306</td>\n",
" <td>242</td>\n",
" <td>5</td>\n",
" <td>876503793</td>\n",
" <td>Kolya (1996)</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" user movie rating timestamp title\n",
"0 196 242 3 881250949 Kolya (1996)\n",
"1 63 242 3 875747190 Kolya (1996)\n",
"2 226 242 5 883888671 Kolya (1996)\n",
"3 154 242 3 879138235 Kolya (1996)\n",
"4 306 242 5 876503793 Kolya (1996)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ratings = ratings.merge(movies)\n",
"ratings.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>user</th>\n",
" <th>title</th>\n",
" <th>rating</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>542</td>\n",
" <td>My Left Foot (1989)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>422</td>\n",
" <td>Event Horizon (1997)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>311</td>\n",
" <td>African Queen, The (1951)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>595</td>\n",
" <td>Face/Off (1997)</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>617</td>\n",
" <td>Evil Dead II (1987)</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>158</td>\n",
" <td>Jurassic Park (1993)</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>836</td>\n",
" <td>Chasing Amy (1997)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>474</td>\n",
" <td>Emma (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>466</td>\n",
" <td>Jackie Chan's First Strike (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>554</td>\n",
" <td>Scream (1996)</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"dls = CollabDataLoaders.from_df(ratings, item_name='title', bs=64)\n",
"dls.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'user': (#944) ['#na#',1,2,3,4,5,6,7,8,9...],\n",
" 'title': (#1635) ['#na#',\"'Til There Was You (1997)\",'1-900 (1994)','101 Dalmatians (1996)','12 Angry Men (1957)','187 (1997)','2 Days in the Valley (1996)','20,000 Leagues Under the Sea (1954)','2001: A Space Odyssey (1968)','3 Ninjas: High Noon At Mega Mountain (1998)'...]}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.classes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_users = len(dls.classes['user'])\n",
"n_movies = len(dls.classes['title'])\n",
"n_factors = 5\n",
"\n",
"user_factors = torch.randn(n_users, n_factors)\n",
"movie_factors = torch.randn(n_movies, n_factors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"one_hot_3 = one_hot(3, n_users).float()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([944, 5])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_factors.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4586, -0.9915, -0.4052, -0.3621, -0.5908])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_factors.t() @ one_hot_3"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4586, -0.9915, -0.4052, -0.3621, -0.5908])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"user_factors[3]"
]
},
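{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a small standalone check added for illustration, using only PyTorch rather than fastai. It confirms that looking up row 3 of an `nn.Embedding` returns the same vector as multiplying the transposed weight matrix by a one-hot vector, the same equivalence the two cells above demonstrate with `user_factors`. The sizes (10 items, 4 factors) are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"emb = nn.Embedding(10, 4)  # 10 items, 4 latent factors each\n",
"idx = torch.tensor([3])\n",
"\n",
"# Lookup by index: what an embedding layer does internally\n",
"by_index = emb(idx)[0]\n",
"# The same vector obtained through a matrix product with a one-hot vector\n",
"one_hot_vec = F.one_hot(idx, num_classes=10).float()[0]\n",
"by_matmul = emb.weight.t() @ one_hot_vec\n",
"\n",
"print(torch.allclose(by_index, by_matmul))  # should print True"
]
},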
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Collaborative filtering from scratch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Example:\n",
" def __init__(self, a): self.a = a\n",
" def say(self,x): return f'Hello {self.a}, {x}.'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Hello Sylvain, nice to meet you.'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ex = Example('Sylvain')\n",
"ex.say('nice to meet you')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DotProduct(Module):\n",
" def __init__(self, n_users, n_movies, n_factors):\n",
" self.user_factors = Embedding(n_users, n_factors)\n",
" self.movie_factors = Embedding(n_movies, n_factors)\n",
" \n",
" def forward(self, x):\n",
" users = self.user_factors(x[:,0])\n",
" movies = self.movie_factors(x[:,1])\n",
" return (users * movies).sum(dim=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([64, 2])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dls.one_batch()\n",
"x.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = DotProduct(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.993168</td>\n",
" <td>0.990168</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.884821</td>\n",
" <td>0.911269</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.671865</td>\n",
" <td>0.875679</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.471727</td>\n",
" <td>0.878200</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.361314</td>\n",
" <td>0.884209</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(5, 5e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DotProduct(Module):\n",
" def __init__(self, n_users, n_movies, n_factors, y_range=(0,5.5)):\n",
" self.user_factors = Embedding(n_users, n_factors)\n",
" self.movie_factors = Embedding(n_movies, n_factors)\n",
" self.y_range = y_range\n",
" \n",
" def forward(self, x):\n",
" users = self.user_factors(x[:,0])\n",
" movies = self.movie_factors(x[:,1])\n",
" return sigmoid_range((users * movies).sum(dim=1), *self.y_range)"
]
},
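{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a small illustration of what `sigmoid_range` does to the raw dot products in the model above. It defines a local helper, `my_sigmoid_range` (a hypothetical name, not part of fastai), assuming fastai's `sigmoid_range` computes `torch.sigmoid(x) * (high - low) + low`. Because a sigmoid never actually reaches its asymptotes, using 5.5 rather than 5 as the upper bound lets the model predict ratings close to 5."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def my_sigmoid_range(x, low, high):\n",
"    # Squash any real number into the open interval (low, high)\n",
"    return torch.sigmoid(x) * (high - low) + low\n",
"\n",
"raw = torch.tensor([-10.0, 0.0, 10.0])\n",
"print(my_sigmoid_range(raw, 0, 5.5))  # roughly 0, 2.75 and just under 5.5"
]
},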
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.973745</td>\n",
" <td>0.993206</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.869132</td>\n",
" <td>0.914323</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.676553</td>\n",
" <td>0.870192</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.485377</td>\n",
" <td>0.873865</td>\n",
" <td>00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.377866</td>\n",
" <td>0.877610</td>\n",
" <td>00:11</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = DotProduct(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DotProductBias(Module):\n",
" def __init__(self, n_users, n_movies, n_factors, y_range=(0,5.5)):\n",
" self.user_factors = Embedding(n_users, n_factors)\n",
" self.user_bias = Embedding(n_users, 1)\n",
" self.movie_factors = Embedding(n_movies, n_factors)\n",
" self.movie_bias = Embedding(n_movies, 1)\n",
" self.y_range = y_range\n",
" \n",
" def forward(self, x):\n",
" users = self.user_factors(x[:,0])\n",
" movies = self.movie_factors(x[:,1])\n",
" res = (users * movies).sum(dim=1, keepdim=True)\n",
" res += self.user_bias(x[:,0]) + self.movie_bias(x[:,1])\n",
" return sigmoid_range(res, *self.y_range)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.929161</td>\n",
" <td>0.936303</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.820444</td>\n",
" <td>0.861306</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.621612</td>\n",
" <td>0.865306</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.404648</td>\n",
" <td>0.886448</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.292948</td>\n",
" <td>0.892580</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = DotProductBias(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Weight decay"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.linspace(-2,2,100)\n",
"a_s = [1,2,5,10,50] \n",
"ys = [a * x**2 for a in a_s]\n",
"_,ax = plt.subplots(figsize=(8,6))\n",
"for a,y in zip(a_s,ys): ax.plot(x,y, label=f'a={a}')\n",
"ax.set_ylim([0,5])\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.972090</td>\n",
" <td>0.962366</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.875591</td>\n",
" <td>0.885106</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.723798</td>\n",
" <td>0.839880</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.586002</td>\n",
" <td>0.823225</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.490980</td>\n",
" <td>0.823060</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = DotProductBias(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
]
},
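{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a minimal PyTorch sketch of the idea behind weight decay (L2 regularization), added for illustration. Adding `wd` times the sum of squared parameters to the loss adds `2 * wd * p` to each parameter's gradient, which pulls the weights toward zero. The parameter values and `wd = 0.1` are arbitrary. This shows only the loss-based formulation; fastai's optimizers may instead apply `wd` directly to the weights during the update step, which this cell does not model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"p = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)\n",
"wd = 0.1\n",
"\n",
"# Penalty term on its own (the data loss is omitted for clarity)\n",
"penalty = wd * (p**2).sum()\n",
"penalty.backward()\n",
"\n",
"# The gradient of wd * sum(p**2) is 2 * wd * p\n",
"print(p.grad)\n",
"print(torch.allclose(p.grad, 2 * wd * p.detach()))  # should print True"
]
},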
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating our own Embedding module"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#0) []"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class T(Module):\n",
" def __init__(self): self.a = torch.ones(3)\n",
"\n",
"L(T().parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1) [Parameter containing:\n",
"tensor([1., 1., 1.], requires_grad=True)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class T(Module):\n",
" def __init__(self): self.a = nn.Parameter(torch.ones(3))\n",
"\n",
"L(T().parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#1) [Parameter containing:\n",
"tensor([[-0.9595],\n",
" [-0.8490],\n",
" [ 0.8159]], requires_grad=True)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class T(Module):\n",
" def __init__(self): self.a = nn.Linear(1, 3, bias=False)\n",
"\n",
"t = T()\n",
"L(t.parameters())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.nn.parameter.Parameter"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(t.a.weight)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def create_params(size):\n",
" return nn.Parameter(torch.zeros(*size).normal_(0, 0.01))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class DotProductBias(Module):\n",
" def __init__(self, n_users, n_movies, n_factors, y_range=(0,5.5)):\n",
" self.user_factors = create_params([n_users, n_factors])\n",
" self.user_bias = create_params([n_users])\n",
" self.movie_factors = create_params([n_movies, n_factors])\n",
" self.movie_bias = create_params([n_movies])\n",
" self.y_range = y_range\n",
" \n",
" def forward(self, x):\n",
" users = self.user_factors[x[:,0]]\n",
" movies = self.movie_factors[x[:,1]]\n",
" res = (users*movies).sum(dim=1)\n",
" res += self.user_bias[x[:,0]] + self.movie_bias[x[:,1]]\n",
" return sigmoid_range(res, *self.y_range)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.962146</td>\n",
" <td>0.936952</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.858084</td>\n",
" <td>0.884951</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.740883</td>\n",
" <td>0.838549</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.592497</td>\n",
" <td>0.823599</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.473570</td>\n",
" <td>0.824263</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = DotProductBias(n_users, n_movies, 50)\n",
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Interpreting embeddings and biases"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Children of the Corn: The Gathering (1996)',\n",
" 'Lawnmower Man 2: Beyond Cyberspace (1996)',\n",
" 'Beautician and the Beast, The (1997)',\n",
" 'Crow: City of Angels, The (1996)',\n",
" 'Home Alone 3 (1997)']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"movie_bias = learn.model.movie_bias.squeeze()\n",
"idxs = movie_bias.argsort()[:5]\n",
"[dls.classes['title'][i] for i in idxs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['L.A. Confidential (1997)',\n",
" 'Titanic (1997)',\n",
" 'Silence of the Lambs, The (1991)',\n",
" 'Shawshank Redemption, The (1994)',\n",
" 'Star Wars (1977)']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"idxs = movie_bias.argsort(descending=True)[:5]\n",
"[dls.classes['title'][i] for i in idxs]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x864 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"g = ratings.groupby('title')['rating'].count()\n",
"top_movies = g.sort_values(ascending=False).index.values[:1000]\n",
"top_idxs = tensor([learn.dls.classes['title'].o2i[m] for m in top_movies])\n",
"movie_w = learn.model.movie_factors[top_idxs].cpu().detach()\n",
"movie_pca = movie_w.pca(3)\n",
"fac0,fac1,fac2 = movie_pca.t()\n",
"# Plot the 50 most-rated movies (top_movies is sorted by rating count)\n",
"idxs = list(range(50))\n",
"X = fac0[idxs]\n",
"Y = fac2[idxs]\n",
"plt.figure(figsize=(12,12))\n",
"plt.scatter(X, Y)\n",
"for i, x, y in zip(top_movies[idxs], X, Y):\n",
" plt.text(x,y,i, color=np.random.rand(3)*0.7, fontsize=11)\n",
"plt.show()"
]
},
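{
"cell_type": "markdown",
"metadata": {},
"source": [
"`movie_w.pca(3)` uses a fastai helper to project each movie's latent factors down to 3 dimensions. As a hedged sketch of the same idea in plain NumPy, assuming the helper centers the factors and projects them onto their top principal directions (fastai's exact implementation may differ), the hypothetical function `pca_sketch` below uses the singular value decomposition on random stand-in data:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def pca_sketch(w, k=3):\n",
"    # Center each column so the principal directions pass through the mean\n",
"    centered = w - w.mean(axis=0)\n",
"    # Rows of vt are the principal directions, ordered by decreasing singular value\n",
"    _, _, vt = np.linalg.svd(centered, full_matrices=False)\n",
"    # Project every row onto the top-k directions\n",
"    return centered @ vt[:k].T\n",
"\n",
"rng = np.random.default_rng(0)\n",
"w = rng.normal(size=(1000, 50))  # random stand-in for 1,000 movies x 50 factors\n",
"proj = pca_sketch(w, 3)\n",
"print(proj.shape)  # (1000, 3)\n",
"```\n",
"\n",
"The first output column plays the role of `fac0` and the third that of `fac2` in the plot above."
]
},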
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using fastai.collab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = collab_learner(dls, n_factors=50, y_range=(0, 5.5))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.931751</td>\n",
" <td>0.953806</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.851826</td>\n",
" <td>0.878119</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.715254</td>\n",
" <td>0.834711</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.583173</td>\n",
" <td>0.821470</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.496625</td>\n",
" <td>0.821688</td>\n",
" <td>00:13</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
]
},
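{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `wd` argument to `fit_one_cycle` sets the amount of weight decay. The plain-Python sketch below shows the classic L2 formulation: the penalty added to the loss, and its gradient, which pulls each weight toward zero in proportion to its size. The helper names are hypothetical, and fastai's optimizers may apply the decay directly to the weights instead of through the loss; the two are equivalent (up to a rescaling of `wd`) for plain SGD, but not for adaptive optimizers such as Adam.\n",
"\n",
"```python\n",
"def loss_with_wd(loss, params, wd):\n",
"    # L2 penalty: add wd times the sum of the squared parameters to the loss\n",
"    return loss + wd * sum(p * p for p in params)\n",
"\n",
"def wd_grad(params, wd):\n",
"    # d(wd * p**2)/dp = 2 * wd * p, so larger weights are pushed harder toward zero\n",
"    return [2 * wd * p for p in params]\n",
"\n",
"params = [0.5, -2.0, 3.0]\n",
"print(loss_with_wd(1.0, params, 0.1))\n",
"print(wd_grad(params, 0.1))\n",
"```"
]
},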
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"EmbeddingDotBias(\n",
" (u_weight): Embedding(944, 50)\n",
" (i_weight): Embedding(1635, 50)\n",
" (u_bias): Embedding(944, 1)\n",
" (i_bias): Embedding(1635, 1)\n",
")"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model"
]
},
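{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model printed above holds four embeddings: a weight vector and a bias for each user (`u_weight`, `u_bias`) and for each item (`i_weight`, `i_bias`). Below is a hedged, plain-Python sketch of how such a model could turn one user/movie pair into a prediction, assuming (as with `y_range` elsewhere in this notebook) that the raw score is squashed with a scaled sigmoid; `predict_rating` and the toy numbers are hypothetical:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def sigmoid_range(x, lo, hi):\n",
"    # Scaled sigmoid: maps any real number into the open interval (lo, hi)\n",
"    return 1 / (1 + math.exp(-x)) * (hi - lo) + lo\n",
"\n",
"def predict_rating(user_vec, item_vec, user_bias, item_bias, y_range=(0, 5.5)):\n",
"    # Dot product of the latent factors, plus one bias per user and one per item\n",
"    raw = sum(u * i for u, i in zip(user_vec, item_vec)) + user_bias + item_bias\n",
"    return sigmoid_range(raw, *y_range)\n",
"\n",
"print(predict_rating([0.1, -0.2, 0.3], [0.4, 0.1, -0.5], 0.2, 0.1))\n",
"```"
]
},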
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Titanic (1997)',\n",
" \"Schindler's List (1993)\",\n",
" 'Shawshank Redemption, The (1994)',\n",
" 'L.A. Confidential (1997)',\n",
" 'Silence of the Lambs, The (1991)']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"movie_bias = learn.model.i_bias.weight.squeeze()\n",
"idxs = movie_bias.argsort(descending=True)[:5]\n",
"[dls.classes['title'][i] for i in idxs]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embedding distance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Dial M for Murder (1954)'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"movie_factors = learn.model.i_weight.weight\n",
"idx = dls.classes['title'].o2i['Silence of the Lambs, The (1991)']\n",
"distances = nn.CosineSimilarity(dim=1)(movie_factors, movie_factors[idx][None])\n",
"idx = distances.argsort(descending=True)[1]\n",
"dls.classes['title'][idx]"
]
},
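{
"cell_type": "markdown",
"metadata": {},
"source": [
"`nn.CosineSimilarity` measures the angle between two embedding vectors and ignores their lengths, so a value of 1 means the vectors point the same way. The NumPy sketch below repeats the nearest-neighbour lookup on a made-up 4 x 2 factor matrix; because every vector is most similar to itself, the closest *other* movie sits at position 1 of the descending sort:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def cosine_sim(mat, vec):\n",
"    # Dot product of each row with vec, divided by the product of their norms\n",
"    return mat @ vec / (np.linalg.norm(mat, axis=1) * np.linalg.norm(vec))\n",
"\n",
"factors = np.array([[1.0, 0.0],\n",
"                    [0.9, 0.1],\n",
"                    [0.0, 1.0],\n",
"                    [-1.0, 0.0]])\n",
"sims = cosine_sim(factors, factors[0])\n",
"# Position 0 of the descending sort is the query itself, so take position 1\n",
"nearest = np.argsort(-sims)[1]\n",
"print(nearest)  # 1\n",
"```"
]
},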
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bootstrapping a collaborative filtering model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deep learning for collaborative filtering"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(944, 74), (1635, 101)]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"embs = get_emb_sz(dls)\n",
"embs"
]
},
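{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sizes printed above (74 factors for 944 users and 101 for 1,635 movies) are consistent with the rule of thumb `min(600, round(1.6 * n**0.56))`, where `n` is the number of categories. The sketch below assumes that is the heuristic `get_emb_sz` applies; `emb_sz_rule_sketch` is a hypothetical name:\n",
"\n",
"```python\n",
"def emb_sz_rule_sketch(n_cat):\n",
"    # Width grows sub-linearly with the number of categories and is capped at 600\n",
"    return min(600, round(1.6 * n_cat ** 0.56))\n",
"\n",
"print([(n, emb_sz_rule_sketch(n)) for n in (944, 1635)])  # [(944, 74), (1635, 101)]\n",
"```"
]
},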
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class CollabNN(Module):\n",
" def __init__(self, user_sz, item_sz, y_range=(0,5.5), n_act=100):\n",
" self.user_factors = Embedding(*user_sz)\n",
" self.item_factors = Embedding(*item_sz)\n",
" self.layers = nn.Sequential(\n",
" nn.Linear(user_sz[1]+item_sz[1], n_act),\n",
" nn.ReLU(),\n",
" nn.Linear(n_act, 1))\n",
" self.y_range = y_range\n",
" \n",
" def forward(self, x):\n",
" embs = self.user_factors(x[:,0]),self.item_factors(x[:,1])\n",
" x = self.layers(torch.cat(embs, dim=1))\n",
" return sigmoid_range(x, *self.y_range)"
]
},
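{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `CollabNN` concatenates the user and item embeddings rather than taking their dot product, the two can have different widths. The PyTorch sketch below traces the shapes through the same kind of layers, using plain `nn.Embedding` with the sizes returned by `get_emb_sz` and made-up index pairs; it mirrors `sigmoid_range` by scaling a sigmoid into (0, 5.5):\n",
"\n",
"```python\n",
"import torch\n",
"from torch import nn\n",
"\n",
"user_emb = nn.Embedding(944, 74)\n",
"item_emb = nn.Embedding(1635, 101)\n",
"layers = nn.Sequential(nn.Linear(74 + 101, 100), nn.ReLU(), nn.Linear(100, 1))\n",
"\n",
"# A batch of 3 made-up (user index, item index) pairs, one pair per row\n",
"x = torch.tensor([[196, 242], [186, 302], [22, 377]])\n",
"# Look up each column in its own table, then join the vectors side by side\n",
"embs = torch.cat([user_emb(x[:, 0]), item_emb(x[:, 1])], dim=1)\n",
"print(embs.shape)  # torch.Size([3, 175])\n",
"# Squash the single output per row into the range (0, 5.5)\n",
"out = torch.sigmoid(layers(embs)) * 5.5\n",
"print(out.shape)  # torch.Size([3, 1])\n",
"```"
]
},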
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = CollabNN(*embs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.940104</td>\n",
" <td>0.959786</td>\n",
" <td>00:15</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.893943</td>\n",
" <td>0.905222</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.865591</td>\n",
" <td>0.875238</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.800177</td>\n",
" <td>0.867468</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.760255</td>\n",
" <td>0.867455</td>\n",
" <td>00:14</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, model, loss_func=MSELossFlat())\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.01)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>1.002747</td>\n",
" <td>0.972392</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.926903</td>\n",
" <td>0.922348</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.877160</td>\n",
" <td>0.893401</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.838334</td>\n",
" <td>0.865040</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.781666</td>\n",
" <td>0.864936</td>\n",
" <td>00:16</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = collab_learner(dls, use_nn=True, y_range=(0, 5.5), layers=[100,50])\n",
"learn.fit_one_cycle(5, 5e-3, wd=0.1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@delegates(TabularModel)\n",
"class EmbeddingNN(TabularModel):\n",
" def __init__(self, emb_szs, layers, **kwargs):\n",
" super().__init__(emb_szs, layers=layers, n_cont=0, out_sz=1, **kwargs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sidebar: kwargs and delegates"
]
},
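{
"cell_type": "markdown",
"metadata": {},
"source": [
"When a subclass forwards `**kwargs` to its parent, as `EmbeddingNN` does, tools such as `inspect.signature` only report `**kwargs`, hiding which arguments are actually accepted. fastai's `delegates` decorator addresses this. The plain-Python sketch below illustrates the idea with a hypothetical decorator, `delegates_sketch`, that copies the parent's defaulted parameters into the child's signature; it is not fastai's implementation:\n",
"\n",
"```python\n",
"import inspect\n",
"\n",
"class Base:\n",
"    def __init__(self, a, b=1, c=2):\n",
"        self.a, self.b, self.c = a, b, c\n",
"\n",
"def delegates_sketch(parent):\n",
"    # Hypothetical decorator: swap **kwargs in cls.__init__'s signature\n",
"    # for the defaulted parameters of parent.__init__\n",
"    def _wrap(cls):\n",
"        sig = inspect.signature(cls.__init__)\n",
"        params = [p for p in sig.parameters.values()\n",
"                  if p.kind != inspect.Parameter.VAR_KEYWORD]\n",
"        extra = [p for name, p in inspect.signature(parent.__init__).parameters.items()\n",
"                 if p.default is not inspect.Parameter.empty and name not in sig.parameters]\n",
"        cls.__init__.__signature__ = sig.replace(parameters=params + extra)\n",
"        return cls\n",
"    return _wrap\n",
"\n",
"@delegates_sketch(Base)\n",
"class Child(Base):\n",
"    def __init__(self, a, **kwargs):\n",
"        # Everything except `a` is forwarded to Base unchanged\n",
"        super().__init__(a, **kwargs)\n",
"\n",
"print(inspect.signature(Child))  # (a, b=1, c=2)\n",
"```"
]
},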
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### End sidebar"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Questionnaire"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. What problem does collaborative filtering solve?\n",
"1. How does it solve it?\n",
"1. Why might a collaborative filtering predictive model fail to be a very useful recommendation system?\n",
"1. What does a crosstab representation of collaborative filtering data look like?\n",
"1. Write the code to create a crosstab representation of the MovieLens data (you might need to do some web searching!)\n",
"1. What is a latent factor? Why is it \"latent\"?\n",
"1. What is a dot product? Calculate a dot product manually using pure Python with lists.\n",
"1. What does `pandas.DataFrame.merge` do?\n",
"1. What is an embedding matrix?\n",
"1. What is the relationship between an embedding and a matrix of one-hot encoded vectors?\n",
"1. Why do we need `Embedding` if we could use one-hot encoded vectors for the same thing?\n",
"1. What does an embedding contain before we start training (assuming we're not using a pretrained model)?\n",
"1. Create a class (without peeking, if possible!) and use it.\n",
"1. What does `x[:,0]` return?\n",
"1. Rewrite the `DotProduct` class (without peeking, if possible!) and train a model with it.\n",
"1. What is a good loss function to use for MovieLens? Why?\n",
"1. What would happen if we used `CrossEntropy` loss with MovieLens? How would we need to change the model?\n",
"1. What is the use of bias in a dot product model?\n",
"1. What is another name for weight decay?\n",
"1. Write the equation for weight decay (without peeking!)\n",
"1. Write the equation for the gradient of weight decay. Why does it help reduce weights?\n",
"1. Why does reducing weights lead to better generalization?\n",
"1. What does `argsort` do in PyTorch?\n",
"1. Does sorting the movie biases give the same result as averaging the ratings for each movie? Why or why not?\n",
"1. How do you print the names and details of the layers in a model?\n",
"1. What is the \"bootstrapping problem\" in collaborative filtering?\n",
"1. How could you deal with the bootstrapping problem for new users? For new movies?\n",
"1. How can feedback loops impact collaborative filtering systems?\n",
"1. When using a neural network in collaborative filtering, why can we have a different number of factors for movies and users?\n",
"1. Why is there a `nn.Sequential` in the `CollabNN` model?\n",
"1. What kind of model should we use if we want to add metadata about users and items, or information such as date and time, to a collaborative filtering model?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further research\n",
"\n",
"1. Take a look at all the differences between the `Embedding` version of `DotProductBias` and the `create_params` version, and try to understand why each of those changes is required. If you're not sure, try reverting each change, to see what happens. (NB: even the type of brackets used in `forward` has changed!)\n",
"1. Find three other areas where collaborative filtering is being used, and find out the pros and cons of this approach in those areas.\n",
"1. Complete this notebook using the full MovieLens dataset, and compare your results to online benchmarks. See if you can improve your accuracy. Look on the book website and forum for ideas. Note that there are more columns in the full dataset; see if you can use those too (the next chapter might give you ideas).\n",
"1. Create a model for MovieLens that works with CrossEntropy loss, and compare it to the model in this chapter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}