kouki321 committed · verified
Commit 3dd5a8e · 1 Parent(s): ff48ce9

Create app.py

Files changed (1):
  1. app.py (+268, -0)
app.py ADDED
@@ -0,0 +1,268 @@
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os
from time import time
import pandas as pd

# ==============================
# Helper: Human-readable bytes
def sizeof_fmt(num, suffix="B"):
    # Formats bytes as human-readable (e.g. 1.5 GB)
    for unit in ["", "K", "M", "G", "T"]:
        if abs(num) < 1024.0:
            return f"{num:3.2f} {unit}{suffix}"
        num /= 1024.0
    return f"{num:.2f} P{suffix}"

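# Illustrative usage (not called anywhere in the app as written):
#     sizeof_fmt(1_572_864)  ->  "1.50 MB"
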
# ==============================
# Core Model and Caching Logic
# ==============================

def generate(model, input_ids, past_key_values, max_new_tokens):
    """Token-by-token greedy generation that reuses the key-value cache for speed."""
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    return output_ids[:, origin_len:]

def get_kv_cache(model, tokenizer, prompt):
    """Prepares and stores the key-value cache for the initial document/context."""
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]

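# A minimal persistence sketch (not part of the current flow): the sidebar's
# "Load saved cache" option further below expects a .pth file, but nothing in this
# app writes one. Assuming DynamicCache pickles cleanly, the prefilled cache could
# be saved right after get_kv_cache() returns:
#     torch.save(cache, "temp_cache.pth")
# so it can be reloaded on a later run without re-reading the document.
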
def clean_up(cache, origin_len):
    """Trims the cache to only include the original document/context tokens."""
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
    return cache

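# clean_up is not wired into the UI below, which clones the cache per question
# instead. An alternative sketch would be to reuse a single cache and trim it back
# to the document prefix after each answer:
#     cache = clean_up(cache, origin_len)
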
def calculate_cache_size(cache):
    """Calculate the total memory used by the key-value cache, in megabytes."""
    total_memory = 0
    for key in cache.key_cache:
        total_memory += key.element_size() * key.nelement()
    for value in cache.value_cache:
        total_memory += value.element_size() * value.nelement()
    return total_memory / (1024 * 1024)

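# Illustrative use (not shown in the UI below): the cache footprint could be
# surfaced next to the performance log, e.g.
#     st.sidebar.metric("KV cache size", f"{calculate_cache_size(cache):.2f} MB")
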
@st.cache_resource
def load_model_and_tokenizer(doc_text_count):
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # Rough heuristic: ~0.3 tokens per character, with ~30% headroom
        model_max_length=int(1.3 * round(doc_text_count * 0.3 + 1))
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True
    )
    return model, tokenizer

def clone_cache(cache):
    """Deep-copies the key/value tensors so each question starts from a fresh cache."""
    new_cache = DynamicCache()
    for key, value in zip(cache.key_cache, cache.value_cache):
        new_cache.key_cache.append(key.clone())
        new_cache.value_cache.append(value.clone())
    return new_cache

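# Informal note: cloning copies every layer's K/V tensors, so each question briefly
# holds roughly twice the prefix cache in memory. calculate_cache_size (above) can
# be used on either copy to see the cost, e.g.
#     print(f"per-question cache copy: {calculate_cache_size(cache):.2f} MB")
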
@st.cache_resource
def load_document_and_cache(file_path):
    try:
        t2 = time()
        with open(file_path, 'r') as file:
            doc_text = file.read()
        doc_text_count = len(doc_text)
        model, tokenizer = load_model_and_tokenizer(doc_text_count)
        system_prompt = f"""
<|system|>
Answer concisely and precisely. You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()
        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
        t3 = time()
        print(f"KV cache prefill time: {t3 - t2:.2f} s")
        return cache, doc_text, doc_text_count, model, tokenizer
    except FileNotFoundError:
        st.error(f"Document file not found at {file_path}")
        return None, None, None, None, None


# ==============================
# Streamlit UI
# ==============================
# Initialize token counters
# Note: these are plain module-level variables, so Streamlit re-initializes them
# on every rerun of the script.
input_tokens_count = 0
generated_tokens_count = 0
output_tokens_count = 0

# Reset counters with a button
if st.button("🔄 Reset Token Counters"):
    input_tokens_count = 0
    generated_tokens_count = 0
    output_tokens_count = 0
    doc_text = None
    cache = None
    model = None
    tokenizer = None
    st.success("Token counters have been reset.")

st.title("🚀 DeepSeek QA: Supercharged Caching & Memory Dashboard")

uploaded_file = st.file_uploader("📝 Upload your document (.txt)", type="txt")

# Initialize variables
doc_text = None
cache = None
model = None
tokenizer = None

if uploaded_file:
    log = []

    # PART 1: File Upload & Save
    t_start1 = time()
    temp_file_path = "temp_document.txt"
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    t_end1 = time()
    log.append(f"📂 File Upload & Save Time: {t_end1 - t_start1:.2f} s")
    print(f"📂 File Upload & Save Time: {t_end1 - t_start1:.2f} s")

    # PART 2: Document and Cache Load
    t_start2 = time()
    cache, doc_text, doc_text_count, model, tokenizer = load_document_and_cache(temp_file_path)
    t_end2 = time()
    log.append(f"📄 Document & Cache Load Time: {t_end2 - t_start2:.2f} s")
    print(f"📄 Document & Cache Load Time: {t_end2 - t_start2:.2f} s")

    # PART 3: Document Preview Display
    t_start3 = time()
    with st.expander("📄 Document Preview"):
        preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        st.text(preview)
    t_end3 = time()
    log.append(f"👀 Document Preview Display Time: {t_end3 - t_start3:.2f} s")
    print(f"👀 Document Preview Display Time: {t_end3 - t_start3:.2f} s")

    # PART 4: Show Basic Info (currently disabled)
    t_start4 = time()
    # doc_size_kb = os.path.getsize(temp_file_path) / 1024
    # cache_size = os.path.getsize("temp_cache.pth") / 1024 if os.path.exists("temp_cache.pth") else "N/A"
    t_end4 = time()
    log.append(f"👀 Basic Info Display Time: {t_end4 - t_start4:.2f} s")
    print(f"👀 Basic Info Display Time: {t_end4 - t_start4:.2f} s")
    # st.info(
    #     f"Document Chars: {len(doc_text)} | Size: {doc_size_kb:.2f} KB | "
    #     f"Cache Size: {cache_size if cache_size == 'N/A' else f'{cache_size:.2f} KB'}"
    # )

    # =========================
    # User Query and Generation
    # =========================

    query = st.text_input("🔎 Ask a question about the document:")
    if query and st.button("Generate Answer"):
        with st.spinner("Generating answer..."):
            log.append("🚀 Query & Generation Steps:")

            # PART 4.1: Clone Cache
            t_start5 = time()
            current_cache = clone_cache(cache)
            t_end5 = time()
            print(f"🔁 Clone Cache Time: {t_end5 - t_start5:.2f} s")
            log.append(f"🔁 Clone Cache Time: {t_end5 - t_start5:.2f} s")

            # PART 4.2: Tokenize Prompt
            t_start6 = time()
            model, tokenizer = load_model_and_tokenizer(doc_text_count)
            full_prompt = f"""
<|user|>
Question: {query}
<|assistant|>
""".strip()
            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
            input_tokens_count += input_ids.shape[-1]
            t_end6 = time()
            print(f"✍️ Tokenization Time: {t_end6 - t_start6:.2f} s")
            log.append(f"✍️ Tokenization Time: {t_end6 - t_start6:.2f} s")

            # PART 4.3: Generate Answer
            t_start7 = time()
            output_ids = generate(model, input_ids, current_cache, max_new_tokens=50)
            last_generation_time = time() - t_start7
            print(f"💡 Generation Time: {last_generation_time:.2f} s")
            log.append(f"💡 Generation Time: {last_generation_time:.2f} s")

            generated_tokens_count += output_ids.shape[-1]
            output_tokens_count = generated_tokens_count

            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

            st.success("Answer:")
            st.write(response)

            # Final Info Display
            st.info(
                # f"Document Chars: {len(doc_text)} | Size: {doc_size_kb:.2f} KB | "
                f"Cache Clone Time: {log[-3].split(': ')[1]} | Generation Time: {last_generation_time:.2f} s"
            )

    # =========================
    # Show Log
    # =========================
    st.sidebar.header("🕒 Performance Log")
    for entry in log:
        st.sidebar.write(entry)

# =========================
# Sidebar: Cache Loader
# =========================

st.sidebar.header("🛠️ Advanced Options")
st.sidebar.write("Load a previously saved cache for instant document context reuse.")

if st.sidebar.checkbox("Load saved cache"):
    cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth")
    if cache_file:
        with open("temp_cache.pth", "wb") as f:
            f.write(cache_file.getvalue())
        try:
            loaded_cache = torch.load("temp_cache.pth")
            cache = loaded_cache
            st.sidebar.success("Cache loaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Failed to load cache file: {e}")
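
# A hedged note on the loader above: torch.load only returns a usable cache if the
# file was produced by pickling a DynamicCache (e.g. the torch.save sketch next to
# get_kv_cache). On recent PyTorch releases, where torch.load defaults to
# weights_only=True, loading a pickled cache object may additionally require
#     loaded_cache = torch.load("temp_cache.pth", weights_only=False)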