{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "976fbfea",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.insert(0,'..')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4b164f6b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"from tqdm import tqdm\n",
"\n",
"\n",
"from train import train\n",
"import priors\n",
"import encoders\n",
"import positional_encodings\n",
"import utils\n",
"import bar_distribution\n",
"\n",
"\n",
"from samlib.utils import chunker"
]
},
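{
"cell_type": "code",
"execution_count": null,
"id": "random-hypers-sketch",
"metadata": {},
"outputs": [],
"source": [
"# `random_hypers` is consumed by the prior-search cell below but is not defined anywhere in\n",
"# this notebook. The sketch here is purely illustrative: it only pins down the tuple layout\n",
"# that cell expects, i.e. (max_strokes, min_len, max_len, min_width, max_width, max_offset,\n",
"# max_target_offset); all sampling ranges are assumptions, not the original search space.\n",
"import random\n",
"\n",
"\n",
"def sample_stroke_hypers(rng):\n",
"    min_len = rng.uniform(0.1, 0.5)\n",
"    min_width = rng.uniform(0.01, 0.1)\n",
"    return (rng.randint(1, 5),           # max_strokes\n",
"            min_len,                      # min_len\n",
"            rng.uniform(min_len, 1.0),    # max_len\n",
"            min_width,                    # min_width\n",
"            rng.uniform(min_width, 0.3),  # max_width\n",
"            rng.uniform(0.0, 0.5),        # max_offset\n",
"            rng.uniform(0.0, 0.5))        # max_target_offset\n",
"\n",
"\n",
"rng = random.Random(0)\n",
"random_hypers = [sample_stroke_hypers(rng) for _ in range(100)]"
]
},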
{
"cell_type": "code",
"execution_count": null,
"id": "29d423b4",
"metadata": {},
"outputs": [],
"source": [
"mykwargs = \\\n",
"{\n",
" 'bptt': 5*5+1,\n",
"'nlayers': 6,\n",
" 'dropout': 0.0, 'steps_per_epoch': 100,\n",
" 'batch_size': 100}\n",
"mnist_jobs_5shot_pi_prior_search = [\n",
" pretrain_and_eval( {'num_features': 28 * 28, 'fuse_x_y': False, 'num_outputs': 5,\n",
" 'translations': False, 'jonas_style': True}, priors.stroke.DataLoader, Losses.ce, enc, emsize=emsize, nhead=nhead, warmup_epochs=warmup_epochs, nhid=nhid, y_encoder_generator=encoders.get_Canonical(5), lr=lr, epochs=epochs, single_eval_pos_gen=mykwargs['bptt']-1,\n",
" extra_prior_kwargs_dict={'num_features': 28*28, 'fuse_x_y': False, 'num_outputs':5, 'only_train_for_last_idx': True,\n",
" 'min_max_strokes': (1,max_strokes), 'min_max_len': (min_len, max_len), 'min_max_width': (min_width, max_width), 'max_offset': max_offset, 'max_target_offset': max_target_offset},\n",
" **mykwargs)\n",
" for max_strokes, min_len, max_len, min_width, max_width, max_offset, max_target_offset in random_hypers\n",
" for enc in [encoders.Linear] for emsize in [1024] for nhead in [4] for nhid in [emsize*2] for warmup_epochs in [5] for lr in [.00001] for epochs in [128,1024] for _ in range(1)]\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "deb93d1e",
"metadata": {},
"outputs": [],
"source": [
"\n",
"@torch.inference_mode()\n",
"def get_acc(finetuned_model, eval_pos, device='cpu', steps=100, train_mode=False, **mykwargs):\n",
" finetuned_model.to(device)\n",
" finetuned_model.eval()\n",
"\n",
" t_dl = priors.omniglot.DataLoader(steps, batch_size=1000, seq_len=mykwargs['bptt'], train=train_mode,\n",
" **mykwargs['extra_prior_kwargs_dict'])\n",
"\n",
" ps = []\n",
" ys = []\n",
" for x, y in tqdm(t_dl):\n",
" p = finetuned_model(tuple(e.to(device) for e in x), single_eval_pos=eval_pos)\n",
" ps.append(p)\n",
" ys.append(y)\n",
"\n",
" ps = torch.cat(ps, 1)\n",
" ys = torch.cat(ys, 1)\n",
"\n",
" def acc(ps, ys):\n",
" return (ps.argmax(-1) == ys.to(ps.device)).float().mean()\n",
"\n",
" a = acc(ps[eval_pos], ys[eval_pos]).cpu()\n",
" print(a.item())\n",
" return a\n",
"\n",
"\n",
"def train_and_eval(*args, **kwargs):\n",
" r = train(*args, **kwargs)\n",
" model = r[-1]\n",
" acc = get_acc(model, -1, device='cuda:0', **kwargs).cpu()\n",
" model.to('cpu')\n",
" return [acc]\n",
"\n",
"def pretrain_and_eval(extra_prior_kwargs_dict_eval,*args, **kwargs):\n",
" r = train(*args, **kwargs)\n",
" model = r[-1]\n",
" kwargs['extra_prior_kwargs_dict'] = extra_prior_kwargs_dict_eval\n",
" acc = get_acc(model, -1, device='cuda:0', **kwargs).cpu()\n",
" model.to('cpu')\n",
" return r, acc"
]
},
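{
"cell_type": "code",
"execution_count": null,
"id": "inspect-prior-search",
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch, assuming the prior-search cell above has been executed: each entry of\n",
"# `mnist_jobs_5shot_pi_prior_search` is a `(train_result, accuracy)` pair returned by\n",
"# `pretrain_and_eval`, so the strongest configuration can be located like this; the\n",
"# hard-coded index used in the fine-tuning cell below refers to one such entry.\n",
"accs = [float(acc) for _, acc in mnist_jobs_5shot_pi_prior_search]\n",
"best_idx = max(range(len(accs)), key=accs.__getitem__)\n",
"print(f'best prior-search run: index {best_idx}, eval accuracy {accs[best_idx]:.3f}')"
]
},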
{
"cell_type": "code",
"execution_count": null,
"id": "706ecbb7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"emsize = 1024\n",
"# mnist_jobs_5shot_pi[20].result()[-1].state_dict()\n",
"mykwargs = \\\n",
" {'bptt': 5 * 5 + 1,\n",
" 'nlayers': 6,\n",
" 'nhead': 4, 'emsize': emsize,\n",
" 'encoder_generator': encoders.Linear, 'nhid': emsize * 2}\n",
"results = train_and_eval(priors.omniglot.DataLoader, Losses.ce, y_encoder_generator=encoders.get_Canonical(5),\n",
" load_weights_from_this_state_dict=mnist_jobs_5shot_pi_prior_search[67][0][-1].state_dict(), epochs=32, lr=.00001, dropout=dropout,\n",
" single_eval_pos_gen=mykwargs['bptt'] - 1,\n",
" extra_prior_kwargs_dict={'num_features': 28 * 28, 'fuse_x_y': False, 'num_outputs': 5,\n",
" 'translations': True, 'jonas_style': True},\n",
" batch_size=100, steps_per_epoch=200, **mykwargs)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "611554b2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}