import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
import os
from time import time
import pandas as pd
from huggingface_hub import login

HF_TOKEN = os.getenv("NEX_MODEL")  # Hugging Face access token, read from the NEX_MODEL environment variable

if not HF_TOKEN:
    raise ValueError("Hugging Face token not found. Please set the 'NEX_MODEL' environment variable.")



# ==============================
# Helper: Human-readable bytes
# ==============================
def sizeof_fmt(num, suffix="B"):
    """Format a byte count as a human-readable string (e.g. 1.50 GB)."""
    for unit in ["", "K", "M", "G", "T"]:
        if abs(num) < 1024.0:
            return f"{num:3.2f} {unit}{suffix}"
        num /= 1024.0
    return f"{num:.2f} P{suffix}"

# ==============================
# Core Model and Caching Logic
# ==============================

def generate(model, input_ids, past_key_values, max_new_tokens):
    """Greedy token-by-token generation, reusing the precomputed KV cache for speed."""
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = out.logits[:, -1, :]
            token = torch.argmax(logits, dim=-1, keepdim=True)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token.to(device)
            if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
                break
    # Return only the newly generated tokens (the prompt portion is stripped off)
    return output_ids[:, origin_len:]

def get_kv_cache(model, tokenizer, prompt):
    """Prepares and stores the key-value cache for the initial document/context."""
    device = model.model.embed_tokens.weight.device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    cache = DynamicCache()
    with torch.no_grad():
        _ = model(
            input_ids=input_ids,
            past_key_values=cache,
            use_cache=True
        )
    return cache, input_ids.shape[-1]

def clean_up(cache, origin_len):
    """Trims the cache to only include the original document/context tokens."""
    for i in range(len(cache.key_cache)):
        cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
        cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
    return cache
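
# Note: clean_up() is not called anywhere in this script; each query instead
# works on a fresh copy made by clone_cache() further below. An alternative
# pattern (a sketch, not part of the original flow) would reuse a single cache
# and trim it back after every answer, assuming origin_len is kept around from
# get_kv_cache():
#
#     answer_ids = generate(model, input_ids, cache, max_new_tokens=50)
#     cache = clean_up(cache, origin_len)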

def calculate_cache_size(cache):
    """Return the total memory used by the key-value cache (DynamicCache) in megabytes."""
    total_memory = 0
    for key in cache.key_cache:
        total_memory += key.element_size() * key.nelement()
    for value in cache.value_cache:
        total_memory += value.element_size() * value.nelement()
    return total_memory / (1024 * 1024)
    
@st.cache_resource
def load_model_and_tokenizer():
    model_name = "GeneZC/MiniChat-1.5-3B"

    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=False,
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    return model, tokenizer

def clone_cache(cache):
    new_cache = DynamicCache()
    for key, value in zip(cache.key_cache, cache.value_cache):
        new_cache.key_cache.append(key.clone())
        new_cache.value_cache.append(value.clone())
    return new_cache

@st.cache_resource
def load_document_and_cache(file_path):
    try:
        t2 = time()
        with open(file_path, "r", encoding="utf-8") as file:
            doc_text = file.read()
        doc_text_count = len(doc_text)

        # Rough character-to-token heuristic, capped at 16824 tokens
        max_length = int(1.3 * (doc_text_count * 0.3 + 1))
        if max_length > 16824:
            max_length = 16824
        print(f"model_max_length set to: {max_length}")

        model, tokenizer = load_model_and_tokenizer()
        tokenizer.model_max_length = max_length

        system_prompt = f"""
        <|system|>
        You are a helpful assistant. Provide concise, factual answers based only on the provided context.
        If the information is not available, respond with: "I'm sorry, I don't have enough information to answer that."
        <|user|>
        Context:
        {doc_text}
        Question:
        """.strip()
        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
        t3 = time()
        print(f"KV cache build time: {t3 - t2:.2f} s")
        return cache, doc_text, doc_text_count, model, tokenizer
    except FileNotFoundError:
        st.error(f"Document file not found at {file_path}")
        return None, None, None, None, None

 
# ==============================
# Streamlit UI
# ==============================
# Initialize token counters.
# Note: these are plain script-level variables, so they are re-initialized on
# every Streamlit rerun and only track counts for the current run.
input_tokens_count = 0
generated_tokens_count = 0
output_tokens_count = 0

# Reset counters and cached objects with a button
if st.button("πŸ”„ Reset Token Counters"):
    input_tokens_count = 0
    generated_tokens_count = 0
    output_tokens_count = 0
    doc_text = None
    cache = None
    model = None
    tokenizer = None
    st.success("Token counters have been reset.")

st.title("πŸš€ DeepSeek QA: Supercharged Caching & Memory Dashboard")

uploaded_file = st.file_uploader("πŸ“ Upload your document (.txt)", type="txt")

# Initialize variables
doc_text = None
cache = None
model = None
tokenizer = None

if uploaded_file:
    log = []

    # PART 1: File Upload & Save
    t_start1 = time()
    temp_file_path = "temp_document.txt"
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    t_end1 = time()
    log.append(f"πŸ“‚ File Upload & Save Time: {t_end1 - t_start1:.2f} s")
    print(f"πŸ“‚ File Upload & Save Time: {t_end1 - t_start1:.2f} s")

    # PART 2: Document and Cache Load
    t_start2 = time()
    cache, doc_text, doc_text_count, model, tokenizer = load_document_and_cache(temp_file_path)
    t_end2 = time()
    if doc_text is None:
        st.stop()  # Loading failed; the error message was already shown
    log.append(f"πŸ“„ Document & Cache Load Time: {t_end2 - t_start2:.2f} s")
    print(f"πŸ“„ Document & Cache Load Time: {t_end2 - t_start2:.2f} s")

    # PART 3: Document Preview Display
    t_start3 = time()
    with st.expander("πŸ“„ Document Preview"):
        preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        st.text(preview)
    t_end3 = time()
    log.append(f"πŸ‘€ Document Preview Display Time: {t_end3 - t_start3:.2f} s")
    print(f"πŸ‘€ Document Preview Display Time: {t_end3 - t_start3:.2f} s")
    # PART 4: Show Basic Info (document and cache size)
    t_start4 = time()
    s_cache = calculate_cache_size(cache)
    t_end4 = time()
    log.append(f"πŸ‘€ Cache Size Computation Time: {t_end4 - t_start4:.2f} s")
    print(f"πŸ‘€ Cache Size Computation Time: {t_end4 - t_start4:.2f} s | KV cache size: {s_cache:.2f} MB")
    st.info(f"Document Chars: {doc_text_count} | Cache Size: {s_cache:.2f} MB")

    # =========================
    # User Query and Generation
    # =========================

    query = st.text_input("πŸ”Ž Ask a question about the document:")
    if query and st.button("Generate Answer"):
        with st.spinner("Generating answer..."):
            log.append("πŸš€ Query & Generation Steps:")

            # PART 4.1: Clone Cache
            t_start5 = time()
            current_cache = clone_cache(cache)
            t_end5 = time()
            print(f"πŸ” Clone Cache Time: {t_end5 - t_start5:.2f} s")
            log.append(f"πŸ” Clone Cache Time: {t_end5 - t_start5:.2f} s")

            # PART 4.2: Tokenize Prompt
            t_start6 = time()

            full_prompt = f"""
            <|user|>
            Question: Please provide a clear and concise answer to the following question: {query}
            <|assistant|>
            """.strip()
            input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids
            input_tokens_count += input_ids.shape[-1]
            t_end6 = time()
            print(f"✍️ Tokenization Time: {t_end6 - t_start6:.2f} s")
            log.append(f"✍️ Tokenization Time: {t_end6 - t_start6:.2f} s")

            # PART 4.3: Generate Answer
            t_start7 = time()
            output_ids = generate(model, input_ids, current_cache, max_new_tokens=50)
            last_generation_time = time() - t_start7
            print(f"πŸ’‘ Generation Time: {last_generation_time:.2f} s")
            log.append(f"πŸ’‘ Generation Time: {last_generation_time:.2f} s")

            generated_tokens_count += output_ids.shape[-1]
            output_tokens_count = generated_tokens_count

            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

            st.success("Answer:")
            st.write(response)
            print(f"***************************************************************************************")
            # Final Info Display
            st.info(
                f"Cache Clone Time: {t_end5 - t_start5:.2f} s | Generation Time: {last_generation_time:.2f} s"
            )

    # =========================
    # Show Log
    # =========================
    st.sidebar.header("πŸ•’ Performance Log")
    for entry in log:
        st.sidebar.write(entry)

# =========================
# Sidebar: Cache Loader
# =========================

st.sidebar.header("πŸ› οΈ Advanced Options")
st.sidebar.write("Load a previously saved cache for instant document context reuse.")

if st.sidebar.checkbox("Load saved cache"):
    cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth")
    if cache_file:
        with open("temp_cache.pth", "wb") as f:
            f.write(cache_file.getvalue())
        try:
            loaded_cache = torch.load("temp_cache.pth")
            cache = loaded_cache
            st.sidebar.success("Cache loaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Failed to load cache file: {e}")
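
# ------------------------------------------------------------------
# Sketch (not part of the original app): the loader above expects a
# previously saved .pth cache, but this script never writes one. One
# possible counterpart, assuming pickling the DynamicCache with
# torch.save is acceptable for your deployment, is shown below. The
# file name "document_cache.pth" is illustrative only.
# ------------------------------------------------------------------
if cache is not None and st.sidebar.button("πŸ’Ύ Save current cache"):
    # Move tensors to CPU first so the file can be reloaded on CPU-only machines.
    cpu_cache = DynamicCache()
    for key, value in zip(cache.key_cache, cache.value_cache):
        cpu_cache.key_cache.append(key.detach().cpu())
        cpu_cache.value_cache.append(value.detach().cpu())
    torch.save(cpu_cache, "document_cache.pth")
    st.sidebar.success("Cache saved to document_cache.pth")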