ryanjg committed on
Commit 43b05b0 · verified · 1 Parent(s): 34fb0ec

setting map location to cpu
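
The practical effect of the change: a checkpoint saved on a GPU machine can now be deserialized on a CPU-only host. A minimal sketch of the failure mode this fixes (filename hypothetical):

import torch

# Without map_location, torch.load tries to restore each tensor onto the
# device it was saved from, and raises a RuntimeError on machines where
# CUDA is unavailable. Mapping to CPU makes loading device-independent.
state = torch.load("state_dict.pth", map_location=torch.device("cpu"))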

Files changed (1)
  1. training/k_sparse_autoencoder.py +246 -246
training/k_sparse_autoencoder.py CHANGED
@@ -1,247 +1,247 @@
import os
import json
import torch
from torch import nn


class SparseAutoencoder(nn.Module):

    def __init__(
        self,
        n_dirs_local: int,
        d_model: int,
        k: int,
        auxk: int,  # int | None; None disables the AuxK loss
        dead_steps_threshold: int,
        auxk_coef: float
    ):
        super().__init__()
        self.n_dirs_local = n_dirs_local
        self.d_model = d_model
        self.k = k
        self.auxk = auxk
        self.dead_steps_threshold = dead_steps_threshold
        self.auxk_coef = auxk_coef
        self.encoder = nn.Linear(d_model, n_dirs_local, bias=False)
        self.decoder = nn.Linear(n_dirs_local, d_model, bias=False)

        self.pre_bias = nn.Parameter(torch.zeros(d_model))
        self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local))

        # steps since each latent last fired; used to flag dead latents for AuxK
        self.stats_last_nonzero: torch.Tensor
        self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))

        def auxk_mask_fn(x):
            dead_mask = self.stats_last_nonzero > dead_steps_threshold
            x.data *= dead_mask  # inplace to save memory
            return x

        self.auxk_mask_fn = auxk_mask_fn

        ## initialization

        # "tied" init
        self.decoder.weight.data = self.encoder.weight.data.T.clone()

        # store decoder in column major layout for kernel
        self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T
        self.mse_scale = 1
        unit_norm_decoder_(self)

    def save_to_disk(self, path: str):
        PATH_TO_CFG = 'config.json'
        PATH_TO_WEIGHTS = 'state_dict.pth'

        cfg = {
            "n_dirs_local": self.n_dirs_local,
            "d_model": self.d_model,
            "k": self.k,
            "auxk": self.auxk,
            "dead_steps_threshold": self.dead_steps_threshold,
            "auxk_coef": self.auxk_coef
        }

        os.makedirs(path, exist_ok=True)

        with open(os.path.join(path, PATH_TO_CFG), 'w') as f:
            json.dump(cfg, f)

        torch.save({
            "state_dict": self.state_dict(),
        }, os.path.join(path, PATH_TO_WEIGHTS))

    @classmethod
    def load_from_disk(cls, path: str):
        PATH_TO_CFG = 'config.json'
        PATH_TO_WEIGHTS = 'state_dict.pth'

        with open(os.path.join(path, PATH_TO_CFG), 'r') as f:
            cfg = json.load(f)

        ae = cls(
            n_dirs_local=cfg["n_dirs_local"],
            d_model=cfg["d_model"],
            k=cfg["k"],
            auxk=cfg["auxk"],
            dead_steps_threshold=cfg["dead_steps_threshold"],
            auxk_coef=cfg.get("auxk_coef", 1 / 32)  # default for configs that predate this key
        )

-        state_dict = torch.load(os.path.join(path, PATH_TO_WEIGHTS))["state_dict"]
+        state_dict = torch.load(os.path.join(path, PATH_TO_WEIGHTS), map_location=torch.device('cpu'))["state_dict"]
        ae.load_state_dict(state_dict)

        return ae

    @property
    def n_dirs(self):
        return self.n_dirs_local

    def encode(self, x):
        x = x - self.pre_bias
        latents_pre_act = self.encoder(x) + self.latent_bias

        vals, inds = torch.topk(
            latents_pre_act,
            k=self.k,
            dim=-1
        )

        latents = torch.zeros_like(latents_pre_act)
        latents.scatter_(-1, inds, torch.relu(vals))

        return latents

    def encode_with_k(self, x, k):
        x = x - self.pre_bias
        latents_pre_act = self.encoder(x) + self.latent_bias

        vals, inds = torch.topk(
            latents_pre_act,
            k=k,
            dim=-1
        )

        latents = torch.zeros_like(latents_pre_act)
        latents.scatter_(-1, inds, torch.relu(vals))

        return latents

    def encode_without_topk(self, x):
        x = x - self.pre_bias
        latents_pre_act = torch.relu(self.encoder(x) + self.latent_bias)
        return latents_pre_act

    def forward(self, x):
        x = x - self.pre_bias
        latents_pre_act = self.encoder(x) + self.latent_bias
        l0 = (latents_pre_act > 0).float().sum(-1).mean()
        vals, inds = torch.topk(
            latents_pre_act,
            k=self.k,
            dim=-1
        )
        with torch.no_grad():  # Disable gradients for statistics
            ## set num nonzero stat ##
            tmp = torch.zeros_like(self.stats_last_nonzero)
            tmp.scatter_add_(
                0,
                inds.reshape(-1),
                (vals > 1e-3).to(tmp.dtype).reshape(-1),
            )
            self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
            self.stats_last_nonzero += 1

            del tmp

        ## auxk
        if self.auxk is not None:  # for auxk
            auxk_vals, auxk_inds = torch.topk(
                self.auxk_mask_fn(latents_pre_act),
                k=self.auxk,
                dim=-1
            )
        else:
            auxk_inds = None
            auxk_vals = None
        ## end auxk

        vals = torch.relu(vals)
        if auxk_vals is not None:
            auxk_vals = torch.relu(auxk_vals)

        rows, cols = latents_pre_act.size()
        row_indices = torch.arange(rows).unsqueeze(1).expand(-1, self.k).reshape(-1)
        vals = vals.reshape(-1)
        inds = inds.reshape(-1)

        indices = torch.stack([row_indices.to(inds.device), inds])

        sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))

        recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias

        mse_loss = self.mse_scale * self.mse(recons, x)

        ## Calculate AuxK loss if applicable
        if auxk_vals is not None:
            auxk_recons = self.decode_sparse(auxk_inds, auxk_vals)
            auxk_loss = self.auxk_coef * self.normalized_mse(
                auxk_recons, x - recons.detach() + self.pre_bias.detach()
            ).nan_to_num(0)
        else:
            auxk_loss = 0.0

        total_loss = mse_loss + auxk_loss

        return recons, total_loss, {
            "inds": inds,
            "vals": vals,
            "auxk_inds": auxk_inds,
            "auxk_vals": auxk_vals,
            "l0": l0,
            "train_recons": mse_loss,
            "train_maxk_recons": auxk_loss
        }

    def decode_sparse(self, inds, vals):
        rows, cols = inds.shape[0], self.n_dirs

        row_indices = torch.arange(rows).unsqueeze(1).expand(-1, inds.shape[1]).reshape(-1)
        vals = vals.reshape(-1)
        inds = inds.reshape(-1)

        indices = torch.stack([row_indices.to(inds.device), inds])

        sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))

        recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
        return recons

    @property
    def device(self):
        return next(self.parameters()).device

    def mse(self, recons, x):
        # return ((recons - x) ** 2).sum(dim=-1).mean()
        return ((recons - x) ** 2).mean()

    def normalized_mse(self, recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
        # only used for auxk
        xs_mu = xs.mean(dim=0)

        loss = self.mse(recon, xs) / self.mse(
            xs_mu[None, :].broadcast_to(xs.shape), xs
        )

        return loss


def unit_norm_decoder_(autoencoder: SparseAutoencoder) -> None:
    # normalize each decoder column (latent direction) to unit L2 norm, in place
    autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0)


def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
    # remove the gradient component parallel to each decoder direction, so
    # optimizer updates preserve the unit-norm constraint on the columns
    assert autoencoder.decoder.weight.grad is not None

    autoencoder.decoder.weight.grad += \
        torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad) * \
        autoencoder.decoder.weight.data * -1
 
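
For reference, a minimal usage sketch under the new behavior (checkpoint path hypothetical; import path assumed from this repo's layout): the autoencoder now loads onto CPU regardless of where it was trained, and can be moved to an accelerator afterwards with .to(device).

import torch
from training.k_sparse_autoencoder import SparseAutoencoder

ae = SparseAutoencoder.load_from_disk("checkpoints/sae")  # tensors land on CPU
x = torch.randn(8, ae.d_model)       # stand-in activations
latents = ae.encode(x)               # k-sparse codes, shape (8, n_dirs)
recons, loss, info = ae(x)           # forward pass with MSE + AuxK losses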