mirror of
https://github.com/fastai/fastbook.git
synced 2025-04-04 01:40:44 +00:00
945 lines
24 KiB
Plaintext
945 lines
24 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#hide\n",
|
|
"! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab\n",
|
|
"import fastbook\n",
|
|
"fastbook.setup_book()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"#hide\n",
|
|
"from fastbook import *"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# A fastai Learner from Scratch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"path = untar_data(URLs.IMAGENETTE_160)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"t = get_image_files(path)\n",
|
|
"t[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from glob import glob\n",
|
|
"files = L(glob(f'{path}/**/*.JPEG', recursive=True)).map(Path)\n",
|
|
"files[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"im = Image.open(files[0])\n",
|
|
"im"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"im_t = tensor(im)\n",
|
|
"im_t.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"lbls = files.map(Self.parent.name()).unique(); lbls"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"v2i = lbls.val2idx(); v2i"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Dataset"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Dataset:\n",
|
|
" def __init__(self, fns): self.fns=fns\n",
|
|
" def __len__(self): return len(self.fns)\n",
|
|
" def __getitem__(self, i):\n",
|
|
" im = Image.open(self.fns[i]).resize((64,64)).convert('RGB')\n",
|
|
" y = v2i[self.fns[i].parent.name]\n",
|
|
" return tensor(im).float()/255, tensor(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_filt = L(o.parent.parent.name=='train' for o in files)\n",
|
|
"train,valid = files[train_filt],files[~train_filt]\n",
|
|
"len(train),len(valid)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_ds,valid_ds = Dataset(train),Dataset(valid)\n",
|
|
"x,y = train_ds[0]\n",
|
|
"x.shape,y"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"show_image(x, title=lbls[y]);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def collate(idxs, ds): \n",
|
|
" xb,yb = zip(*[ds[i] for i in idxs])\n",
|
|
" return torch.stack(xb),torch.stack(yb)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"x,y = collate([1,2], train_ds)\n",
|
|
"x.shape,y"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class DataLoader:\n",
|
|
" def __init__(self, ds, bs=128, shuffle=False, n_workers=1):\n",
|
|
" self.ds,self.bs,self.shuffle,self.n_workers = ds,bs,shuffle,n_workers\n",
|
|
"\n",
|
|
" def __len__(self): return (len(self.ds)-1)//self.bs+1\n",
|
|
"\n",
|
|
" def __iter__(self):\n",
|
|
" idxs = L.range(self.ds)\n",
|
|
" if self.shuffle: idxs = idxs.shuffle()\n",
|
|
" chunks = [idxs[n:n+self.bs] for n in range(0, len(self.ds), self.bs)]\n",
|
|
" with ProcessPoolExecutor(self.n_workers) as ex:\n",
|
|
" yield from ex.map(collate, chunks, ds=self.ds)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"n_workers = min(16, defaults.cpus)\n",
|
|
"train_dl = DataLoader(train_ds, bs=128, shuffle=True, n_workers=n_workers)\n",
|
|
"valid_dl = DataLoader(valid_ds, bs=256, shuffle=False, n_workers=n_workers)\n",
|
|
"xb,yb = first(train_dl)\n",
|
|
"xb.shape,yb.shape,len(train_dl)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"stats = [xb.mean((0,1,2)),xb.std((0,1,2))]\n",
|
|
"stats"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Normalize:\n",
|
|
" def __init__(self, stats): self.stats=stats\n",
|
|
" def __call__(self, x):\n",
|
|
" if x.device != self.stats[0].device:\n",
|
|
" self.stats = to_device(self.stats, x.device)\n",
|
|
" return (x-self.stats[0])/self.stats[1]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"norm = Normalize(stats)\n",
|
|
"def tfm_x(x): return norm(x).permute((0,3,1,2))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"t = tfm_x(x)\n",
|
|
"t.mean((0,2,3)),t.std((0,2,3))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Module and Parameter"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Parameter(Tensor):\n",
|
|
" def __new__(self, x): return Tensor._make_subclass(Parameter, x, True)\n",
|
|
" def __init__(self, *args, **kwargs): self.requires_grad_()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"Parameter(tensor(3.))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Module:\n",
|
|
" def __init__(self):\n",
|
|
" self.hook,self.params,self.children,self._training = None,[],[],False\n",
|
|
" \n",
|
|
" def register_parameters(self, *ps): self.params += ps\n",
|
|
" def register_modules (self, *ms): self.children += ms\n",
|
|
" \n",
|
|
" @property\n",
|
|
" def training(self): return self._training\n",
|
|
" @training.setter\n",
|
|
" def training(self,v):\n",
|
|
" self._training = v\n",
|
|
" for m in self.children: m.training=v\n",
|
|
" \n",
|
|
" def parameters(self):\n",
|
|
" return self.params + sum([m.parameters() for m in self.children], [])\n",
|
|
"\n",
|
|
" def __setattr__(self,k,v):\n",
|
|
" super().__setattr__(k,v)\n",
|
|
" if isinstance(v,Parameter): self.register_parameters(v)\n",
|
|
" if isinstance(v,Module): self.register_modules(v)\n",
|
|
" \n",
|
|
" def __call__(self, *args, **kwargs):\n",
|
|
" res = self.forward(*args, **kwargs)\n",
|
|
" if self.hook is not None: self.hook(res, args)\n",
|
|
" return res\n",
|
|
" \n",
|
|
" def cuda(self):\n",
|
|
" for p in self.parameters(): p.data = p.data.cuda()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class ConvLayer(Module):\n",
|
|
" def __init__(self, ni, nf, stride=1, bias=True, act=True):\n",
|
|
" super().__init__()\n",
|
|
" self.w = Parameter(torch.zeros(nf,ni,3,3))\n",
|
|
" self.b = Parameter(torch.zeros(nf)) if bias else None\n",
|
|
" self.act,self.stride = act,stride\n",
|
|
" init = nn.init.kaiming_normal_ if act else nn.init.xavier_normal_\n",
|
|
" init(self.w)\n",
|
|
" \n",
|
|
" def forward(self, x):\n",
|
|
" x = F.conv2d(x, self.w, self.b, stride=self.stride, padding=1)\n",
|
|
" if self.act: x = F.relu(x)\n",
|
|
" return x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"l = ConvLayer(3, 4)\n",
|
|
"len(l.parameters())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"xbt = tfm_x(xb)\n",
|
|
"r = l(xbt)\n",
|
|
"r.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Linear(Module):\n",
|
|
" def __init__(self, ni, nf):\n",
|
|
" super().__init__()\n",
|
|
" self.w = Parameter(torch.zeros(nf,ni))\n",
|
|
" self.b = Parameter(torch.zeros(nf))\n",
|
|
" nn.init.xavier_normal_(self.w)\n",
|
|
" \n",
|
|
" def forward(self, x): return x@self.w.t() + self.b"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"l = Linear(4,2)\n",
|
|
"r = l(torch.ones(3,4))\n",
|
|
"r.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class T(Module):\n",
|
|
" def __init__(self):\n",
|
|
" super().__init__()\n",
|
|
" self.c,self.l = ConvLayer(3,4),Linear(4,2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"t = T()\n",
|
|
"len(t.parameters())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"t.cuda()\n",
|
|
"t.l.w.device"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Simple CNN"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Sequential(Module):\n",
|
|
" def __init__(self, *layers):\n",
|
|
" super().__init__()\n",
|
|
" self.layers = layers\n",
|
|
" self.register_modules(*layers)\n",
|
|
"\n",
|
|
" def forward(self, x):\n",
|
|
" for l in self.layers: x = l(x)\n",
|
|
" return x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class AdaptivePool(Module):\n",
|
|
" def forward(self, x): return x.mean((2,3))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def simple_cnn():\n",
|
|
" return Sequential(\n",
|
|
" ConvLayer(3 ,16 ,stride=2), #32\n",
|
|
" ConvLayer(16,32 ,stride=2), #16\n",
|
|
" ConvLayer(32,64 ,stride=2), # 8\n",
|
|
" ConvLayer(64,128,stride=2), # 4\n",
|
|
" AdaptivePool(),\n",
|
|
" Linear(128, 10)\n",
|
|
" )"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"m = simple_cnn()\n",
|
|
"len(m.parameters())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def print_stats(outp, inp): print (outp.mean().item(),outp.std().item())\n",
|
|
"for i in range(4): m.layers[i].hook = print_stats\n",
|
|
"\n",
|
|
"r = m(xbt)\n",
|
|
"r.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def nll(input, target): return -input[range(target.shape[0]), target].mean()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def log_softmax(x): return (x.exp()/(x.exp().sum(-1,keepdim=True))).log()\n",
|
|
"\n",
|
|
"sm = log_softmax(r); sm[0][0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"loss = nll(sm, yb)\n",
|
|
"loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def log_softmax(x): return x - x.exp().sum(-1,keepdim=True).log()\n",
|
|
"sm = log_softmax(r); sm[0][0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"x = torch.rand(5)\n",
|
|
"a = x.max()\n",
|
|
"x.exp().sum().log() == a + (x-a).exp().sum().log()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def logsumexp(x):\n",
|
|
" m = x.max(-1)[0]\n",
|
|
" return m + (x-m[:,None]).exp().sum(-1).log()\n",
|
|
"\n",
|
|
"logsumexp(r)[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def log_softmax(x): return x - x.logsumexp(-1,keepdim=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"sm = log_softmax(r); sm[0][0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def cross_entropy(preds, yb): return nll(log_softmax(preds), yb).mean()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Learner"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class SGD:\n",
|
|
" def __init__(self, params, lr, wd=0.): store_attr()\n",
|
|
" def step(self):\n",
|
|
" for p in self.params:\n",
|
|
" p.data -= (p.grad.data + p.data*self.wd) * self.lr\n",
|
|
" p.grad.data.zero_()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class DataLoaders:\n",
|
|
" def __init__(self, *dls): self.train,self.valid = dls\n",
|
|
"\n",
|
|
"dls = DataLoaders(train_dl,valid_dl)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Learner:\n",
|
|
" def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):\n",
|
|
" store_attr()\n",
|
|
" for cb in cbs: cb.learner = self\n",
|
|
"\n",
|
|
" def one_batch(self):\n",
|
|
" self('before_batch')\n",
|
|
" xb,yb = self.batch\n",
|
|
" self.preds = self.model(xb)\n",
|
|
" self.loss = self.loss_func(self.preds, yb)\n",
|
|
" if self.model.training:\n",
|
|
" self.loss.backward()\n",
|
|
" self.opt.step()\n",
|
|
" self('after_batch')\n",
|
|
"\n",
|
|
" def one_epoch(self, train):\n",
|
|
" self.model.training = train\n",
|
|
" self('before_epoch')\n",
|
|
" dl = self.dls.train if train else self.dls.valid\n",
|
|
" for self.num,self.batch in enumerate(progress_bar(dl, leave=False)):\n",
|
|
" self.one_batch()\n",
|
|
" self('after_epoch')\n",
|
|
" \n",
|
|
" def fit(self, n_epochs):\n",
|
|
" self('before_fit')\n",
|
|
" self.opt = self.opt_func(self.model.parameters(), self.lr)\n",
|
|
" self.n_epochs = n_epochs\n",
|
|
" try:\n",
|
|
" for self.epoch in range(n_epochs):\n",
|
|
" self.one_epoch(True)\n",
|
|
" self.one_epoch(False)\n",
|
|
" except CancelFitException: pass\n",
|
|
" self('after_fit')\n",
|
|
" \n",
|
|
" def __call__(self,name):\n",
|
|
" for cb in self.cbs: getattr(cb,name,noop)()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Callbacks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class Callback(GetAttr): _default='learner'"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class SetupLearnerCB(Callback):\n",
|
|
" def before_batch(self):\n",
|
|
" xb,yb = to_device(self.batch)\n",
|
|
" self.learner.batch = tfm_x(xb),yb\n",
|
|
"\n",
|
|
" def before_fit(self): self.model.cuda()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class TrackResults(Callback):\n",
|
|
" def before_epoch(self): self.accs,self.losses,self.ns = [],[],[]\n",
|
|
" \n",
|
|
" def after_epoch(self):\n",
|
|
" n = sum(self.ns)\n",
|
|
" print(self.epoch, self.model.training,\n",
|
|
" sum(self.losses).item()/n, sum(self.accs).item()/n)\n",
|
|
" \n",
|
|
" def after_batch(self):\n",
|
|
" xb,yb = self.batch\n",
|
|
" acc = (self.preds.argmax(dim=1)==yb).float().sum()\n",
|
|
" self.accs.append(acc)\n",
|
|
" n = len(xb)\n",
|
|
" self.losses.append(self.loss*n)\n",
|
|
" self.ns.append(n)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"cbs = [SetupLearnerCB(),TrackResults()]\n",
|
|
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs)\n",
|
|
"learn.fit(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Scheduling the Learning Rate"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class LRFinder(Callback):\n",
|
|
" def before_fit(self):\n",
|
|
" self.losses,self.lrs = [],[]\n",
|
|
" self.learner.lr = 1e-6\n",
|
|
" \n",
|
|
" def before_batch(self):\n",
|
|
" if not self.model.training: return\n",
|
|
" self.opt.lr *= 1.2\n",
|
|
"\n",
|
|
" def after_batch(self):\n",
|
|
" if not self.model.training: return\n",
|
|
" if self.opt.lr>10 or torch.isnan(self.loss): raise CancelFitException\n",
|
|
" self.losses.append(self.loss.item())\n",
|
|
" self.lrs.append(self.opt.lr)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"lrfind = LRFinder()\n",
|
|
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[lrfind])\n",
|
|
"learn.fit(2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"plt.plot(lrfind.lrs[:-2],lrfind.losses[:-2])\n",
|
|
"plt.xscale('log')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class OneCycle(Callback):\n",
|
|
" def __init__(self, base_lr): self.base_lr = base_lr\n",
|
|
" def before_fit(self): self.lrs = []\n",
|
|
"\n",
|
|
" def before_batch(self):\n",
|
|
" if not self.model.training: return\n",
|
|
" n = len(self.dls.train)\n",
|
|
" bn = self.epoch*n + self.num\n",
|
|
" mn = self.n_epochs*n\n",
|
|
" pct = bn/mn\n",
|
|
" pct_start,div_start = 0.25,10\n",
|
|
" if pct<pct_start:\n",
|
|
" pct /= pct_start\n",
|
|
" lr = (1-pct)*self.base_lr/div_start + pct*self.base_lr\n",
|
|
" else:\n",
|
|
" pct = (pct-pct_start)/(1-pct_start)\n",
|
|
" lr = (1-pct)*self.base_lr\n",
|
|
" self.opt.lr = lr\n",
|
|
" self.lrs.append(lr)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"onecyc = OneCycle(0.1)\n",
|
|
"learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[onecyc])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"learn.fit(8)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"plt.plot(onecyc.lrs);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Conclusion"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Questionnaire"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"> tip: Experiments: For the questions here that ask you to explain what some function or class is, you should also complete your own code experiments."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"1. What is `glob`?\n",
|
|
"1. How do you open an image with the Python imaging library?\n",
|
|
"1. What does `L.map` do?\n",
|
|
"1. What does `Self` do?\n",
|
|
"1. What is `L.val2idx`?\n",
|
|
"1. What methods do you need to implement to create your own `Dataset`?\n",
|
|
"1. Why do we call `convert` when we open an image from Imagenette?\n",
|
|
"1. What does `~` do? How is it useful for splitting training and validation sets?\n",
|
|
"1. Does `~` work with the `L` or `Tensor` classes? What about NumPy arrays, Python lists, or pandas DataFrames?\n",
|
|
"1. What is `ProcessPoolExecutor`?\n",
|
|
"1. How does `L.range(self.ds)` work?\n",
|
|
"1. What is `__iter__`?\n",
|
|
"1. What is `first`?\n",
|
|
"1. What is `permute`? Why is it needed?\n",
|
|
"1. What is a recursive function? How does it help us define the `parameters` method?\n",
|
|
"1. Write a recursive function that returns the first 20 items of the Fibonacci sequence.\n",
|
|
"1. What is `super`?\n",
|
|
"1. Why do subclasses of `Module` need to override `forward` instead of defining `__call__`?\n",
|
|
"1. In `ConvLayer`, why does `init` depend on `act`?\n",
|
|
"1. Why does `Sequential` need to call `register_modules`?\n",
|
|
"1. Write a hook that prints the shape of every layer's activations.\n",
|
|
"1. What is \"LogSumExp\"?\n",
|
|
"1. Why is `log_softmax` useful?\n",
|
|
"1. What is `GetAttr`? How is it helpful for callbacks?\n",
|
|
"1. Reimplement one of the callbacks in this chapter without inheriting from `Callback` or `GetAttr`.\n",
|
|
"1. What does `Learner.__call__` do?\n",
|
|
"1. What is `getattr`? (Note the case difference to `GetAttr`!)\n",
|
|
"1. Why is there a `try` block in `fit`?\n",
|
|
"1. Why do we check for `model.training` in `one_batch`?\n",
|
|
"1. What is `store_attr`?\n",
|
|
"1. What is the purpose of `TrackResults.before_epoch`?\n",
|
|
"1. What does `model.cuda` do? How does it work?\n",
|
|
"1. Why do we need to check `model.training` in `LRFinder` and `OneCycle`?\n",
|
|
"1. Use cosine annealing in `OneCycle`."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Further Research"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"1. Write `resnet18` from scratch (refer to <<chapter_resnet>> as needed), and train it with the `Learner` in this chapter.\n",
|
|
"1. Implement a batchnorm layer from scratch and use it in your `resnet18`.\n",
|
|
"1. Write a Mixup callback for use in this chapter.\n",
|
|
"1. Add momentum to SGD.\n",
|
|
"1. Pick a few features that you're interested in from fastai (or any other library) and implement them in this chapter.\n",
|
|
"1. Pick a research paper that's not yet implemented in fastai or PyTorch and implement it in this chapter.\n",
|
|
" - Port it over to fastai.\n",
|
|
" - Submit a pull request to fastai, or create your own extension module and release it. \n",
|
|
" - Hint: you may find it helpful to use [`nbdev`](https://nbdev.fast.ai/) to create and deploy your package."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"split_at_heading": true
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|