kalekarnn committed · verified
Commit 03df9fc · 1 Parent(s): 0506cd6

Update app.py

Files changed (1)
  1. app.py +173 -56
app.py CHANGED
@@ -1,9 +1,83 @@
import streamlit as st
import torch
- import tiktoken
- from dataclasses import dataclass
import torch.nn as nn
from torch.nn import functional as F

@dataclass
class GPTConfig:
@@ -119,65 +193,108 @@ class GPT(nn.Module):
    return model


- # Load the trained model
@st.cache_resource
def load_model():
    config = GPTConfig()
    model = GPT(config)
-     try:
-         # Load the model with map_location to handle CPU-only environments
-         model.load_state_dict(torch.load('trained_model_quantized.pt', map_location=torch.device('cpu')), strict=False)
-         model.eval()  # Set the model to evaluation mode
-         st.success("Model loaded successfully!")
-     except Exception as e:
-         st.error(f"Error loading model: {e}")
-     return model
-
- # Load the tokenizer
- def load_tokenizer():
-     return tiktoken.get_encoding('gpt2')
-
- # Generate text function
- def generate_text(model, tokenizer, input_text, length, num_sequences):
-     # Encode the input text
-     input_ids = tokenizer.encode(input_text)
-     input_tensor = torch.tensor(input_ids).unsqueeze(0)  # Add batch dimension (shape: [1, T])
-
-     generated_sequences = []
-     for _ in range(num_sequences):
-         # Generate additional tokens
-         with torch.no_grad():
-             for _ in range(length):
-                 logits = model(input_tensor)[0]  # Get logits
-                 next_token_logits = logits[:, -1, :]  # Get the last token's logits
-                 next_token_probs = torch.softmax(next_token_logits, dim=-1)
-                 next_token = torch.multinomial(next_token_probs, num_samples=1)  # Sample from the distribution
-
-                 # Ensure the next_token has the correct shape for concatenation
-                 next_token = next_token.view(1, -1)  # Reshape to [1, 1] if necessary
-                 input_tensor = torch.cat((input_tensor, next_token), dim=1)  # Append the new token
-
-         # Decode the generated tokens
-         generated_sequences.append(tokenizer.decode(input_tensor[0].tolist()))
-
-     return generated_sequences
-
- # Streamlit app layout
st.title("GPT Text Generator")
- st.write("Enter your text and specify the length of additional text to generate.")

- input_text = st.text_area("Input Text", "Once upon a time", max_chars=512)  # Limit to 512 characters
- length = st.slider("Predict Additional Text of Length", 1, 50, 10)
- num_sequences = st.slider("Number of Sequences to Generate", 1, 5, 1)

if st.button("Generate"):
-     model = load_model()  # Load the model for inference
-     tokenizer = load_tokenizer()  # Load the tokenizer
-     st.write("Generating text...")
-     generated_texts = generate_text(model, tokenizer, input_text, length, num_sequences)
-     st.write("Text generation complete.")
-
-     st.write("Generated Texts:")
-     for i, text in enumerate(generated_texts):
-         st.subheader(f"Sequence {i + 1}")
-         st.write(text)
import streamlit as st
import torch
import torch.nn as nn
from torch.nn import functional as F
+ import tiktoken
+ import sys
+ import os
+ import logging
+ import warnings
+ from dataclasses import dataclass
+ import math
+
+ class MLP(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
+         self.gelu = nn.GELU(approximate='tanh')
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+         self.c_proj.NANOGPT_SCALE_INIT = 1
+
+     def forward(self, x):
+         x = self.c_fc(x)
+         x = self.gelu(x)
+         x = self.c_proj(x)
+         return x
+
+
+ class CausalSelfAttention(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+         self.c_proj.NANGPT_SCALE_INIT = 1
+         # regularization
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
+
+     def forward(self, x):
+         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
+         # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
+         qkv = self.c_attn(x)
+         q, k, v = qkv.split(self.n_embd, dim=2)
+         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
+
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
+         att = F.softmax(att, dim=-1)
+         y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+         # output projection
+         y = self.c_proj(y)
+         return y
+
+
+ class Block(nn.Module):
+
+     def __init__(self, config):
+         super().__init__()
+         self.ln_1 = nn.LayerNorm(config.n_embd)
+         self.attn = CausalSelfAttention(config)
+         self.ln_2 = nn.LayerNorm(config.n_embd)
+         self.mlp = MLP(config)
+
+     def forward(self, x):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+

@dataclass
class GPTConfig:
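Note on the attention code added above: the explicit mask-and-softmax in `CausalSelfAttention.forward` computes the same result as PyTorch's fused `F.scaled_dot_product_attention` with `is_causal=True` (available in torch >= 2.0). A minimal sketch of the equivalence, not part of this commit:

import math
import torch
from torch.nn import functional as F

B, nh, T, hs = 1, 12, 8, 64                       # batch, heads, sequence length, head size
q, k, v = torch.randn(3, B, nh, T, hs).unbind(0)  # random query/key/value tensors

# explicit version, as in forward() above
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = att.masked_fill(mask == 0, float('-inf'))
y_manual = F.softmax(att, dim=-1) @ v

# fused version
y_fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(y_manual, y_fused, atol=1e-5))  # expected: True, up to numerical tolerance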
 
    return model


+ # Configure logging and warnings
+ logging.getLogger('streamlit').setLevel(logging.ERROR)
+ warnings.filterwarnings('ignore', message='.*torch.classes.*')
+ warnings.filterwarnings('ignore', category=FutureWarning)
+
+ # Add the project root to Python path
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
@st.cache_resource
def load_model():
+     device = "cpu"
    config = GPTConfig()
    model = GPT(config)
+
+     # Load the trained weights from root directory
+     checkpoint = torch.load('trained_model_quantized.pt', map_location=device, weights_only=True)
+
+     # Handle pruned weights
+     state_dict = checkpoint['model_state_dict']
+     new_state_dict = {}
+
+     for key in model.state_dict().keys():
+         if key.endswith('.weight'):
+             # Check if this is a pruned weight
+             orig_key = key[:-7] + '.weight_orig' if key.endswith('.weight') else key
+             mask_key = key[:-7] + '.weight_mask' if key.endswith('.weight') else key
+
+             if orig_key in state_dict and mask_key in state_dict:
+                 # Reconstruct the pruned weight
+                 new_state_dict[key] = state_dict[orig_key] * state_dict[mask_key]
+             else:
+                 # Use the weight as is
+                 new_state_dict[key] = state_dict[key] if key in state_dict else model.state_dict()[key]
+         else:
+             # Copy non-weight parameters as is
+             new_state_dict[key] = state_dict[key] if key in state_dict else model.state_dict()[key]
+
+     # Load the processed state dict
+     model.load_state_dict(new_state_dict)
+
+     # Convert back to float32 for inference
+     model = model.float()
+     model.to(device)
+     model.eval()
+
+     return model, device
+
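For context on the `.weight_orig` / `.weight_mask` keys handled in `load_model` above: when a layer is pruned with `torch.nn.utils.prune`, its `weight` entry in the state dict is replaced by a `weight_orig` parameter and a `weight_mask` buffer, and the effective weight is their elementwise product, which is exactly what the reconstruction loop recomputes. A minimal sketch (illustrative only, assuming the checkpoint was produced this way):

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)   # zero the 50% smallest-magnitude weights

sd = layer.state_dict()
print(sorted(sd.keys()))                                  # ['bias', 'weight_mask', 'weight_orig']
assert torch.equal(sd["weight_orig"] * sd["weight_mask"], layer.weight)  # effective weight = orig * mask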
+ def generate_text(model, prompt, max_length=100, num_return_sequences=1, device='cpu'):
+     tokenizer = tiktoken.get_encoding('gpt2')
+     input_tokens = tokenizer.encode(prompt)
+     x = torch.tensor(input_tokens).unsqueeze(0).repeat(num_return_sequences, 1)
+     x = x.to(device)
+
+     # Calculate final length (input length + requested additional tokens)
+     input_length = x.size(1)
+     target_length = input_length + max_length
+
+     # Generate text
+     with torch.no_grad():
+         while x.size(1) < target_length:
+             logits = model(x)[0]
+             next_token_logits = logits[:, -1, :]
+             probs = torch.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+             x = torch.cat((x, next_token), dim=1)
+
+     # Print token information once before generating sequences
+     st.text(f"Size of Input tokens: {input_length}, Additional tokens to be predicted: {max_length}, Total tokens to be generated: {x.size(1)}")
+
+     # Decode generated sequences
+     generated_texts = []
+     for i in range(num_return_sequences):
+         tokens = x[i].tolist()
+         text = tokenizer.decode(tokens)
+         generated_texts.append(text)
+
+     return generated_texts
+
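The generation loop above draws each next token with `torch.multinomial` over the full softmax distribution, i.e. plain sampling at temperature 1.0 with no top-k filtering. A minimal, self-contained sketch of that single sampling step, using made-up logits for illustration:

import torch

next_token_logits = torch.tensor([[2.0, 0.5, -1.0, 0.0, 1.0]])  # toy logits for a 5-token vocabulary
probs = torch.softmax(next_token_logits, dim=-1)                # normalize to probabilities summing to 1
next_token = torch.multinomial(probs, num_samples=1)            # shape [1, 1]: one sampled token id per row
print(probs, next_token)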
+ # Streamlit UI
st.title("GPT Text Generator")

+ # Load model
+ model, device = load_model()
+
+ # Input form
+ prompt = st.text_area("Enter your prompt:", "Once upon a time")
+ max_length = st.slider("Predict additional text of length:", min_value=1, max_value=50, value=5)
+ num_sequences = st.slider("Number of sequences to generate:", 1, 5, 1)

if st.button("Generate"):
+     with st.spinner("Generating text..."):
+         generated_texts = generate_text(
+             model=model,
+             prompt=prompt,
+             max_length=max_length,
+             num_return_sequences=num_sequences,
+             device=device
+         )
+
+     # Display results
+     for i, text in enumerate(generated_texts, 1):
+         st.write(f"\nSequence {i}:")
+         st.write(text)
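The updated app is launched in the usual Streamlit way, e.g. `streamlit run app.py` from the repository root, so that `trained_model_quantized.pt` resolves relative to the working directory.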