Create app.py
app.py
ADDED
@@ -0,0 +1,268 @@
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os
from time import time
import pandas as pd

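# ==============================
# How the app works
# ==============================
# The uploaded document is run through the model once (get_kv_cache) to prefill a
# key-value (KV) cache. Each question then clones that document-only cache
# (clone_cache) and decodes only the new question tokens on top of it (generate),
# so the document is never re-encoded for every query.
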
# ==============================
# Helper: Human-readable bytes
def sizeof_fmt(num, suffix="B"):
    # Formats bytes as human-readable (e.g. 1.5 GB)
    for unit in ["", "K", "M", "G", "T"]:
        if abs(num) < 1024.0:
            return f"{num:3.2f} {unit}{suffix}"
        num /= 1024.0
    return f"{num:.2f} P{suffix}"

# ==============================
# Core Model and Caching Logic
# ==============================

def generate(model, input_ids, past_key_values, max_new_tokens):
    """Token-by-token greedy generation that reuses the key-value cache for speed."""
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):  # cap generation at the caller-supplied budget
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            # Only the newest token is fed back in; the cache already holds the
            # keys/values for everything before it.
            next_token = token.to(device)
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    return output_ids[:, origin_len:]

def get_kv_cache(model, tokenizer, prompt):
    """Prepares and stores the key-value cache for the initial document/context."""
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()
    with torch.no_grad():
        # A single forward pass over the full prompt fills the cache ("prefill").
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]

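# Illustrative usage (a sketch, not executed by the app): build the cache once for
# a document prompt, then reuse it for any number of questions.
#
#   cache, origin_len = get_kv_cache(model, tokenizer, document_prompt)
#   question_ids = tokenizer("Question: ...", return_tensors="pt").input_ids
#   answer_ids = generate(model, question_ids, clone_cache(cache), max_new_tokens=50)
#   print(tokenizer.decode(answer_ids[0], skip_special_tokens=True))
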
def clean_up(cache, origin_len):
    """Trims the cache back to only the original document/context tokens."""
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
    return cache

def calculate_cache_size(cache):
    """Calculate the total memory used by the key-value cache, in megabytes."""
    total_memory = 0
    for key in cache.key_cache:
        total_memory += key.element_size() * key.nelement()
    for value in cache.value_cache:
        total_memory += value.element_size() * value.nelement()
    return total_memory / (1024 * 1024)

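# Note: sizeof_fmt, clean_up and calculate_cache_size are utilities that the UI
# below does not currently call. clean_up is the in-place alternative to
# clone_cache: instead of copying the cache for each question, the question and
# answer entries could be trimmed off afterwards to restore the document-only prefix.
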
@st.cache_resource
def load_model_and_tokenizer(doc_text_count):
    model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        # Rough chars-to-tokens estimate (~0.3 tokens per character) plus ~30% headroom;
        # model_max_length expects an int, so the float result is truncated.
        model_max_length=int(1.3 * round(doc_text_count * 0.3 + 1))
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True
    )
    return model, tokenizer

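# Note: @st.cache_resource keys the cached result on the function arguments, so a
# document of a different length (a different doc_text_count) rebuilds the tokenizer
# with a new model_max_length and reloads the model instead of reusing the old instance.
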
def clone_cache(cache):
    new_cache = DynamicCache()
    for key, value in zip(cache.key_cache, cache.value_cache):
        new_cache.key_cache.append(key.clone())
        new_cache.value_cache.append(value.clone())
    return new_cache

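# Why clone: generate() appends the question and answer keys/values to whatever cache
# it is given, mutating it in place. Handing each query a deep copy keeps the original
# document-only cache intact for the next question.
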
@st.cache_resource
def load_document_and_cache(file_path):
    try:
        t2 = time()
        with open(file_path, 'r') as file:
            doc_text = file.read()
        doc_text_count = len(doc_text)
        model, tokenizer = load_model_and_tokenizer(doc_text_count)
        system_prompt = f"""
<|system|>
Answer concisely and precisely. You are an assistant who provides concise factual answers.
<|user|>
Context:
{doc_text}
Question:
""".strip()
        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
        t3 = time()
        print(f"KV cache prefill time: {t3 - t2:.2f} s")
        return cache, doc_text, doc_text_count, model, tokenizer
    except FileNotFoundError:
        st.error(f"Document file not found at {file_path}")
        # Match the five values unpacked at the call site.
        return None, None, None, None, None

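# Note: load_document_and_cache is memoized by file_path, and every upload is written
# to the same "temp_document.txt" path, so uploading a different document in the same
# session can return the cache built for the previous one. Calling
# st.cache_resource.clear() on upload, or keying the function on the file contents,
# would avoid serving a stale cache.
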
# ==============================
# Streamlit UI
# ==============================
# Initialize token counters
input_tokens_count = 0
generated_tokens_count = 0
output_tokens_count = 0

# Reset counters with a button
if st.button("Reset Token Counters"):
    input_tokens_count = 0
    generated_tokens_count = 0
    output_tokens_count = 0
    doc_text = None
    cache = None
    model = None
    tokenizer = None
    st.success("Token counters have been reset.")

st.title("DeepSeek QA: Supercharged Caching & Memory Dashboard")

uploaded_file = st.file_uploader("Upload your document (.txt)", type="txt")

# Initialize variables
doc_text = None
cache = None
model = None
tokenizer = None

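# Note: these are plain module-level variables, and Streamlit reruns the whole script
# on every interaction, so they reset on each rerun regardless of the button above.
# Counters that persist across interactions would need st.session_state, for example:
#
#   if "input_tokens_count" not in st.session_state:
#       st.session_state["input_tokens_count"] = 0
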
if uploaded_file:
    log = []

    # PART 1: File Upload & Save
    t_start1 = time()
    temp_file_path = "temp_document.txt"
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    t_end1 = time()
    log.append(f"File Upload & Save Time: {t_end1 - t_start1:.2f} s")
    print(f"File Upload & Save Time: {t_end1 - t_start1:.2f} s")

    # PART 2: Document and Cache Load
    t_start2 = time()
    cache, doc_text, doc_text_count, model, tokenizer = load_document_and_cache(temp_file_path)
    t_end2 = time()
    log.append(f"Document & Cache Load Time: {t_end2 - t_start2:.2f} s")
    print(f"Document & Cache Load Time: {t_end2 - t_start2:.2f} s")

    # PART 3: Document Preview Display
    t_start3 = time()
    with st.expander("Document Preview"):
        preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        st.text(preview)
    t_end3 = time()
    log.append(f"Document Preview Display Time: {t_end3 - t_start3:.2f} s")
    print(f"Document Preview Display Time: {t_end3 - t_start3:.2f} s")

    # PART 4: Show Basic Info (currently disabled)
    t_start4 = time()
    #doc_size_kb = os.path.getsize(temp_file_path) / 1024
    #cache_size = os.path.getsize("temp_cache.pth") / 1024 if os.path.exists("temp_cache.pth") else "N/A"
    t_end4 = time()
    log.append(f"Basic Info Display Time: {t_end4 - t_start4:.2f} s")
    print(f"Basic Info Display Time: {t_end4 - t_start4:.2f} s")
    #st.info(
    #    f"Document Chars: {len(doc_text)} | Size: {doc_size_kb:.2f} KB | "
    #    f"Cache Size: {cache_size if cache_size == 'N/A' else f'{cache_size:.2f} KB'}"
    #)

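    # Note: the disabled info block above could report the in-memory cache size via the
    # calculate_cache_size(cache) helper defined earlier (it returns MB) instead of
    # checking for a "temp_cache.pth" file, which only exists once a saved cache has
    # been uploaded through the sidebar loader.
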
    # =========================
    # User Query and Generation
    # =========================

    query = st.text_input("Ask a question about the document:")
    if query and st.button("Generate Answer"):
        with st.spinner("Generating answer..."):
            log.append("Query & Generation Steps:")

            # PART 4.1: Clone Cache
            t_start5 = time()
            current_cache = clone_cache(cache)
            t_end5 = time()
            print(f"Clone Cache Time: {t_end5 - t_start5:.2f} s")
            log.append(f"Clone Cache Time: {t_end5 - t_start5:.2f} s")

            # PART 4.2: Tokenize Prompt
            t_start6 = time()
            model, tokenizer = load_model_and_tokenizer(doc_text_count)
            full_prompt = f"""
<|user|>
Question: {query}
<|assistant|>
""".strip()
            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
            input_tokens_count += input_ids.shape[-1]
            t_end6 = time()
            print(f"Tokenization Time: {t_end6 - t_start6:.2f} s")
            log.append(f"Tokenization Time: {t_end6 - t_start6:.2f} s")

            # PART 4.3: Generate Answer
            t_start7 = time()
            # Cap the answer at 50 new tokens; generate() stops early on EOS.
            output_ids = generate(model, input_ids, current_cache, max_new_tokens=50)
            last_generation_time = time() - t_start7
            print(f"Generation Time: {last_generation_time:.2f} s")
            log.append(f"Generation Time: {last_generation_time:.2f} s")

            generated_tokens_count += output_ids.shape[-1]
            output_tokens_count = generated_tokens_count

            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

            st.success("Answer:")
            st.write(response)

            # Final Info Display
            st.info(
                # f"Document Chars: {len(doc_text)} | Size: {doc_size_kb:.2f} KB | "
                f"Cache Clone Time: {log[-3].split(': ')[1]} | Generation Time: {last_generation_time:.2f} s"
            )

    # =========================
    # Show Log
    # =========================
    st.sidebar.header("Performance Log")
    for entry in log:
        st.sidebar.write(entry)

# =========================
# Sidebar: Cache Loader
# =========================

st.sidebar.header("Advanced Options")
st.sidebar.write("Load a previously saved cache for instant document context reuse.")

if st.sidebar.checkbox("Load saved cache"):
    cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth")
    if cache_file:
        with open("temp_cache.pth", "wb") as f:
            f.write(cache_file.getvalue())
        try:
            loaded_cache = torch.load("temp_cache.pth")
            cache = loaded_cache
            st.sidebar.success("Cache loaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Failed to load cache file: {e}")
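
# Note: this loader expects a DynamicCache that was pickled with torch.save; nothing
# in this app saves one, so the .pth file has to come from elsewhere. On recent
# PyTorch versions torch.load defaults to weights_only=True, which rejects a pickled
# DynamicCache; for a trusted file, torch.load("temp_cache.pth", weights_only=False)
# is the usual workaround.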