zakerytclarke commited on
Commit
dc3a847
·
verified ·
1 Parent(s): cd0c9a0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +327 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,329 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
# streamlit_app.py

import streamlit as st
from datasets import load_dataset
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict, Counter
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier
import random

st.title("🧠 Language Model Explorer")

###################################
# Sidebar configuration
###################################

# Corpus to train on; both options are fetched from the HuggingFace hub.
dataset_name = st.sidebar.selectbox(
    "Choose Dataset",
    ["squad", "tiny_shakespeare"]
)

# Granularity of the token stream fed to every model below.
tokenizer_type = st.sidebar.selectbox(
    "Choose Tokenizer",
    ["character", "word"]
)

# Which of the five toy language models to train.
model_type = st.sidebar.selectbox(
    "Choose Model",
    ["N-gram", "Feed Forward NN", "Decision Tree", "Gradient Boosted Tree", "RNN"]
)

# >1.0 flattens, <1.0 sharpens the sampling distribution at generation time.
temperature = st.sidebar.slider("Sampling Temperature", 0.1, 2.0, 1.0)
train_button = st.sidebar.button("Train Model")

device = torch.device("cpu")  # force CPU usage
39
+
40
+ ###################################
41
+ # Load dataset
42
+ ###################################
43
+
44
@st.cache_data
def load_text(dataset_name):
    """Fetch the chosen corpus and collapse it into one string (cached by Streamlit)."""
    if dataset_name == "squad":
        corpus = load_dataset("squad", split="train[:1%]")
        pieces = [row['context'] for row in corpus]
    elif dataset_name == "tiny_shakespeare":
        corpus = load_dataset("tiny_shakespeare")
        pieces = [corpus['train'][0]['text']]
    else:
        # Unknown selection: fall back to a tiny stub corpus.
        pieces = ["hello world"]
    return " ".join(pieces)
55
+
56
# Single concatenated corpus string for the dataset picked in the sidebar.
text_data = load_text(dataset_name)
57
+
58
+ ###################################
59
+ # Tokenization
60
+ ###################################
61
+
62
def tokenize(text, tokenizer_type):
    """Split *text* into a list of tokens.

    tokenizer_type:
        "character" -> one token per character
        "word"      -> whitespace-separated words

    Raises:
        ValueError: for any other tokenizer_type. (The original fell
        through with `tokens` unassigned and crashed with
        UnboundLocalError on the return.)
    """
    if tokenizer_type == "character":
        return list(text)
    if tokenizer_type == "word":
        return text.split()
    raise ValueError(f"Unknown tokenizer type: {tokenizer_type!r}")
68
+
69
# Token stream and vocabulary lookup tables shared by every model below.
tokens = tokenize(text_data, tokenizer_type)
# sorted() makes token indices deterministic across interpreter runs;
# list(set(...)) ordering depends on string hashing and can reshuffle
# indices after a server restart, invalidating any persisted model.
vocab = sorted(set(tokens))
token_to_idx = {tok: i for i, tok in enumerate(vocab)}
idx_to_token = {i: tok for tok, i in token_to_idx.items()}
73
+
74
+ ###################################
75
+ # Models
76
+ ###################################
77
+
78
class NGramModel:
    """Count-based n-gram language model.

    Maps each (n-1)-token context to a Counter of the tokens observed to
    follow it in the training stream.
    """

    def __init__(self, tokens, n=3):
        self.n = n
        self.model = defaultdict(Counter)
        # +1 so the final n-gram of the stream is counted (off-by-one fix:
        # the original `range(len(tokens) - n)` dropped the last window).
        for i in range(len(tokens) - n + 1):
            context = tuple(tokens[i:i + n - 1])
            next_token = tokens[i + n - 1]
            self.model[context][next_token] += 1

    def predict(self, context, temperature=1.0):
        """Sample a continuation for *context*.

        temperature < 1 sharpens, > 1 flattens the count distribution.
        Unseen contexts fall back to a uniform draw over the module-level
        vocabulary (temperature is ignored in that case).
        """
        context = tuple(context[-(self.n - 1):])
        counts = self.model.get(context)
        if counts is None:
            return random.choice(list(token_to_idx.keys()))
        tokens_, freqs = zip(*counts.items())
        probs = np.array(freqs, dtype=float) ** (1.0 / temperature)
        probs /= probs.sum()
        return np.random.choice(tokens_, p=probs)
98
+
99
+ ###################################
100
+ # Feed Forward NN
101
+ ###################################
102
+
103
class FFNN(nn.Module):
    """Fixed-window feed-forward next-token model: embed each context
    token, concatenate the embeddings, and project to vocabulary logits."""

    def __init__(self, vocab_size, context_size, hidden_size=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.fc1 = nn.Linear(hidden_size * context_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        embedded = self.embed(x)                    # (batch, ctx, hidden)
        flat = embedded.view(embedded.size(0), -1)  # (batch, ctx*hidden)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)                     # (batch, vocab)
116
+
117
def train_ffnn(tokens, context_size=3, epochs=3):
    """Train the feed-forward next-token model on sliding windows.

    Each example pairs (context_size-1) token ids with the following
    token's id. Reports progress via a Streamlit progress bar and a
    per-epoch loss line. Returns the trained FFNN.
    """
    data = []
    # +1 so the final window of the stream is not dropped (off-by-one fix).
    for i in range(len(tokens) - context_size + 1):
        context = tokens[i:i + context_size - 1]
        target = tokens[i + context_size - 1]
        data.append((
            torch.tensor([token_to_idx[tok] for tok in context], device=device),
            token_to_idx[target]
        ))

    model = FFNN(len(vocab), context_size - 1).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    progress_bar = st.progress(0)
    total_steps = epochs * len(data)
    step = 0

    for epoch in range(epochs):
        total_loss = 0
        for x, y in data:
            x = x.unsqueeze(0)  # add a batch dimension of 1
            y = torch.tensor([y], device=device)
            out = model(x)
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            step += 1
            progress_bar.progress(step / total_steps)

        st.write(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

    progress_bar.empty()
    return model
154
+
155
def ffnn_predict(model, context, temperature=1.0, context_size=2):
    """Sample a next token from the FFNN.

    Uses the trailing *context_size* tokens (default 2, matching the
    window train_ffnn builds with its default context_size=3; the
    original hard-coded `context[-2:]`).

    NOTE(review): unknown tokens fall back to id 0, silently aliasing
    them onto a real vocabulary entry — confirm this is intended.
    """
    ids = [token_to_idx.get(tok, 0) for tok in context[-context_size:]]
    x = torch.tensor(ids, device=device).unsqueeze(0)
    with torch.no_grad():
        logits = model(x).squeeze()
    probs = torch.softmax(logits / temperature, dim=0).cpu().numpy()
    return np.random.choice(vocab, p=probs)
161
+
162
+ ###################################
163
+ # Decision Tree
164
+ ###################################
165
+
166
def train_dt(tokens, context_size=3):
    """Fit a decision tree mapping (context_size-1) token ids to the next
    token id, using sliding windows over the stream."""
    X, y = [], []
    # +1 so the final window of the stream is not dropped (off-by-one fix).
    for i in range(len(tokens) - context_size + 1):
        context = tokens[i:i + context_size - 1]
        target = tokens[i + context_size - 1]
        X.append([token_to_idx[tok] for tok in context])
        y.append(token_to_idx[target])

    with st.spinner("Training Decision Tree..."):
        model = DecisionTreeClassifier()
        model.fit(X, y)
    return model
178
+
179
def dt_predict(model, context):
    """Greedy next-token prediction: the tree's single most likely class
    for the last two tokens of *context* (unknown tokens map to id 0)."""
    features = [token_to_idx.get(tok, 0) for tok in context[-2:]]
    predicted_idx = model.predict([features])[0]
    return idx_to_token[predicted_idx]
183
+
184
+ ###################################
185
+ # Gradient Boosted Tree
186
+ ###################################
187
+
188
def train_gbt(tokens, context_size=3):
    """Fit a gradient-boosted classifier mapping (context_size-1) token ids
    to the next token id, using sliding windows over the stream."""
    X, y = [], []
    # +1 so the final window of the stream is not dropped (off-by-one fix).
    for i in range(len(tokens) - context_size + 1):
        context = tokens[i:i + context_size - 1]
        target = tokens[i + context_size - 1]
        X.append([token_to_idx[tok] for tok in context])
        y.append(token_to_idx[target])

    with st.spinner("Training Gradient Boosted Tree..."):
        model = GradientBoostingClassifier()
        model.fit(X, y)
    return model
200
+
201
def gbt_predict(model, context):
    """Greedy next-token prediction from the gradient-boosted classifier,
    fed the last two tokens of *context* (unknown tokens map to id 0)."""
    features = [token_to_idx.get(tok, 0) for tok in context[-2:]]
    predicted_idx = model.predict([features])[0]
    return idx_to_token[predicted_idx]
205
+
206
+ ###################################
207
+ # RNN
208
+ ###################################
209
+
210
class RNNModel(nn.Module):
    """Embedding -> vanilla RNN -> linear head over the vocabulary.

    forward() returns logits for the LAST position only, plus the final
    hidden state (so callers can thread state if they choose to).
    """

    def __init__(self, vocab_size, embed_size=64, hidden_size=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.RNN(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h=None):
        embedded = self.embed(x)
        rnn_out, h = self.rnn(embedded, h)
        last_step = rnn_out[:, -1, :]
        return self.fc(last_step), h
222
+
223
def train_rnn(tokens, context_size=3, epochs=3):
    """Train the RNN on sliding windows of the token stream.

    Fixes two defects in the original:
    - the hidden state `h` was carried across samples without detaching,
      so the second sample's backward() hit the autograd graph already
      freed by the first backward() and raised RuntimeError ("Trying to
      backward through the graph a second time"). The windows are
      independent, so each sample now starts from a fresh hidden state.
    - the final window was dropped (off-by-one in the range bound).
    """
    data = []
    for i in range(len(tokens) - context_size + 1):
        context = tokens[i:i + context_size - 1]
        target = tokens[i + context_size - 1]
        data.append((
            torch.tensor([token_to_idx[tok] for tok in context], device=device),
            token_to_idx[target]
        ))

    model = RNNModel(len(vocab)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    progress_bar = st.progress(0)
    total_steps = epochs * len(data)
    step = 0

    for epoch in range(epochs):
        total_loss = 0
        for x, y in data:
            x = x.unsqueeze(0)  # batch dimension of 1
            y = torch.tensor([y], device=device)
            out, _ = model(x)  # fresh hidden state per independent window
            loss = criterion(out, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            step += 1
            progress_bar.progress(step / total_steps)

        st.write(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

    progress_bar.empty()
    return model
261
+
262
def rnn_predict(model, context, temperature=1.0, context_size=2):
    """Sample a next token from the RNN using the trailing *context_size*
    tokens (default 2 preserves the original hard-coded `context[-2:]`).

    NOTE(review): unknown tokens fall back to id 0, silently aliasing
    them onto a real vocabulary entry — confirm this is intended.
    """
    ids = [token_to_idx.get(tok, 0) for tok in context[-context_size:]]
    x = torch.tensor(ids, device=device).unsqueeze(0)
    with torch.no_grad():
        logits, _ = model(x)
    probs = torch.softmax(logits.squeeze() / temperature, dim=0).cpu().numpy()
    return np.random.choice(vocab, p=probs)
268
+
269
+ ###################################
270
+ # Train and evaluate
271
+ ###################################
272
+
273
# Dispatch training for the model chosen in the sidebar, then stash the
# result in session_state so it survives Streamlit's script reruns.
if train_button:
    st.write(f"Training **{model_type}** model...")

    if model_type == "N-gram":
        with st.spinner("Training N-gram model..."):
            model = NGramModel(tokens, n=3)
    elif model_type == "Feed Forward NN":
        model = train_ffnn(tokens)
    elif model_type == "Decision Tree":
        model = train_dt(tokens)
    elif model_type == "Gradient Boosted Tree":
        model = train_gbt(tokens)
    elif model_type == "RNN":
        model = train_rnn(tokens)

    # Persist both the model and the type it was trained as; generation
    # below must dispatch on the trained type, not the live sidebar pick.
    st.session_state["model"] = model
    st.session_state["model_type"] = model_type
    st.success(f"{model_type} model trained.")
291
+
292
###################################
# Chat interface
###################################

st.header("💬 Chat with the model")

if "model" in st.session_state:
    user_input = st.text_input("Type a prompt:")

    if user_input:
        # Tokenize the prompt with the CURRENT sidebar tokenizer.
        # NOTE(review): if the tokenizer was changed after training, these
        # tokens no longer match the trained model's vocabulary — confirm.
        context = tokenize(user_input, tokenizer_type)
        generated = context.copy()

        # Autoregressive sampling: append up to 20 generated tokens.
        for _ in range(20):
            if st.session_state["model_type"] == "N-gram":
                next_tok = st.session_state["model"].predict(generated, temperature)
            elif st.session_state["model_type"] == "Feed Forward NN":
                next_tok = ffnn_predict(st.session_state["model"], generated, temperature)
            elif st.session_state["model_type"] == "Decision Tree":
                next_tok = dt_predict(st.session_state["model"], generated)
            elif st.session_state["model_type"] == "Gradient Boosted Tree":
                next_tok = gbt_predict(st.session_state["model"], generated)
            elif st.session_state["model_type"] == "RNN":
                next_tok = rnn_predict(st.session_state["model"], generated, temperature)

            generated.append(next_tok)
            # NOTE(review): nothing in this file ever inserts an "<END>"
            # token, so this early-stop is likely dead code unless "<END>"
            # happens to appear in the corpus vocabulary — confirm.
            if next_tok == "<END>":
                break

        # Characters are joined directly; word tokens get spaces restored.
        if tokenizer_type == "character":
            output = "".join(generated)
        else:
            output = " ".join(generated)

        st.write("**Model Output:**")
        st.write(output)
else:
    st.info("Train a model to begin chatting.")