zakerytclarke commited on
Commit
ec60e4a
·
verified ·
1 Parent(s): c361481

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +174 -0
src/streamlit_app.py CHANGED
@@ -159,3 +159,177 @@ def train_ffnn(tokens, context_size=3, epochs=3):
159
  return model
160
 
161
  def ffnn_predict(model, context, temperature=1.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  return model
160
 
161
  def ffnn_predict(model, context, temperature=1.0):
162
+ x = torch.tensor([token_to_idx.get(tok, 0) for tok in context[-2:]], device=device).unsqueeze(0)
163
+ with torch.no_grad():
164
+ logits = model(x).squeeze()
165
+ probs = torch.softmax(logits / temperature, dim=0).cpu().numpy()
166
+ return np.random.choice(vocab, p=probs)
167
+
168
+ ###################################
169
+ # Decision Tree
170
+ ###################################
171
+
172
+ def train_dt(tokens, context_size=3):
173
+ X, y = [], []
174
+ for i in range(len(tokens) - context_size):
175
+ context = tokens[i:i+context_size-1]
176
+ target = tokens[i+context_size-1]
177
+ X.append([token_to_idx[tok] for tok in context])
178
+ y.append(token_to_idx[target])
179
+
180
+ with st.spinner("Training Decision Tree..."):
181
+ model = DecisionTreeClassifier()
182
+ model.fit(X, y)
183
+ return model
184
+
185
+ def dt_predict(model, context):
186
+ x = [token_to_idx.get(tok, 0) for tok in context[-2:]]
187
+ pred = model.predict([x])[0]
188
+ return idx_to_token[pred]
189
+
190
+ ###################################
191
+ # Gradient Boosted Tree
192
+ ###################################
193
+
194
+ def train_gbt(tokens, context_size=3):
195
+ X, y = [], []
196
+ for i in range(len(tokens) - context_size):
197
+ context = tokens[i:i+context_size-1]
198
+ target = tokens[i+context_size-1]
199
+ X.append([token_to_idx[tok] for tok in context])
200
+ y.append(token_to_idx[target])
201
+
202
+ with st.spinner("Training Gradient Boosted Tree..."):
203
+ model = GradientBoostingClassifier()
204
+ model.fit(X, y)
205
+ return model
206
+
207
+ def gbt_predict(model, context):
208
+ x = [token_to_idx.get(tok, 0) for tok in context[-2:]]
209
+ pred = model.predict([x])[0]
210
+ return idx_to_token[pred]
211
+
212
+ ###################################
213
+ # RNN
214
+ ###################################
215
+
216
+ class RNNModel(nn.Module):
217
+ def __init__(self, vocab_size, embed_size=64, hidden_size=128):
218
+ super().__init__()
219
+ self.embed = nn.Embedding(vocab_size, embed_size)
220
+ self.rnn = nn.RNN(embed_size, hidden_size, batch_first=True)
221
+ self.fc = nn.Linear(hidden_size, vocab_size)
222
+
223
+ def forward(self, x, h=None):
224
+ x = self.embed(x)
225
+ out, h = self.rnn(x, h)
226
+ out = self.fc(out[:, -1, :])
227
+ return out, h
228
+
229
+ def train_rnn(tokens, context_size=3, epochs=3):
230
+ data = []
231
+ for i in range(len(tokens) - context_size):
232
+ context = tokens[i:i+context_size-1]
233
+ target = tokens[i+context_size-1]
234
+ data.append((
235
+ torch.tensor([token_to_idx[tok] for tok in context], device=device),
236
+ token_to_idx[target]
237
+ ))
238
+
239
+ model = RNNModel(len(vocab)).to(device)
240
+ optimizer = optim.Adam(model.parameters(), lr=0.01)
241
+ criterion = nn.CrossEntropyLoss()
242
+
243
+ progress_bar = st.progress(0)
244
+ total_steps = epochs * len(data)
245
+ step = 0
246
+
247
+ for epoch in range(epochs):
248
+ total_loss = 0
249
+ h = None
250
+ for x, y in data:
251
+ x = x.unsqueeze(0)
252
+ y = torch.tensor([y], device=device)
253
+ out, h = model(x, h)
254
+ loss = criterion(out, y)
255
+ optimizer.zero_grad()
256
+ loss.backward()
257
+ optimizer.step()
258
+ total_loss += loss.item()
259
+
260
+ step += 1
261
+ progress_bar.progress(step / total_steps)
262
+
263
+ st.write(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")
264
+
265
+ progress_bar.empty()
266
+ return model
267
+
268
+ def rnn_predict(model, context, temperature=1.0):
269
+ x = torch.tensor([token_to_idx.get(tok, 0) for tok in context[-2:]], device=device).unsqueeze(0)
270
+ with torch.no_grad():
271
+ logits, _ = model(x)
272
+ probs = torch.softmax(logits.squeeze() / temperature, dim=0).cpu().numpy()
273
+ return np.random.choice(vocab, p=probs)
274
+
275
+ ###################################
276
+ # Train and evaluate
277
+ ###################################
278
+
279
+ if train_button:
280
+ st.write(f"Training **{model_type}** model...")
281
+
282
+ if model_type == "N-gram":
283
+ with st.spinner("Training N-gram model..."):
284
+ model = NGramModel(tokens, n=3)
285
+ elif model_type == "Feed Forward NN":
286
+ model = train_ffnn(tokens)
287
+ elif model_type == "Decision Tree":
288
+ model = train_dt(tokens)
289
+ elif model_type == "Gradient Boosted Tree":
290
+ model = train_gbt(tokens)
291
+ elif model_type == "RNN":
292
+ model = train_rnn(tokens)
293
+
294
+ st.session_state["model"] = model
295
+ st.session_state["model_type"] = model_type
296
+ st.success(f"{model_type} model trained.")
297
+
298
+ ###################################
299
+ # Chat interface
300
+ ###################################
301
+
302
+ st.header("💬 Chat with the model")
303
+
304
+ if "model" in st.session_state:
305
+ user_input = st.text_input("Type a prompt:")
306
+
307
+ if user_input:
308
+ context = tokenize(user_input, tokenizer_type)
309
+ generated = context.copy()
310
+
311
+ for _ in range(20):
312
+ if st.session_state["model_type"] == "N-gram":
313
+ next_tok = st.session_state["model"].predict(generated, temperature)
314
+ elif st.session_state["model_type"] == "Feed Forward NN":
315
+ next_tok = ffnn_predict(st.session_state["model"], generated, temperature)
316
+ elif st.session_state["model_type"] == "Decision Tree":
317
+ next_tok = dt_predict(st.session_state["model"], generated)
318
+ elif st.session_state["model_type"] == "Gradient Boosted Tree":
319
+ next_tok = gbt_predict(st.session_state["model"], generated)
320
+ elif st.session_state["model_type"] == "RNN":
321
+ next_tok = rnn_predict(st.session_state["model"], generated, temperature)
322
+
323
+ generated.append(next_tok)
324
+ if next_tok == "<END>":
325
+ break
326
+
327
+ if tokenizer_type == "character":
328
+ output = "".join(generated)
329
+ else:
330
+ output = " ".join(generated)
331
+
332
+ st.write("**Model Output:**")
333
+ st.write(output)
334
+ else:
335
+ st.info("Train a model to begin chatting.")