import streamlit as st import torch from transformers import AutoTokenizer, AutoModelForCausalLM from transformers.cache_utils import DynamicCache import os from time import time import pandas as pd from huggingface_hub import login HF_TOKEN = os.getenv("NEX_MODEL") # Updated key name for clarity if not HF_TOKEN: raise ValueError("Hugging Face token not found. Please set the 'NEX_MODEL' environment variable.") # ============================== # Helper: Human-readable bytes def sizeof_fmt(num, suffix="B"): # Formats bytes as human-readable (e.g. 1.5 GB) for unit in ["", "K", "M", "G", "T"]: if abs(num) < 1024.0: return f"{num:3.2f} {unit}{suffix}" num /= 1024.0 return f"{num:.2f} P{suffix}" # ============================== # Core Model and Caching Logic # ============================== def generate(model, input_ids, past_key_values, max_new_tokens): """Token-by-token generation using cache for speed.""" device = model.model.embed_tokens.weight.device origin_len = input_ids.shape[-1] input_ids = input_ids.to(device) output_ids = input_ids.clone() next_token = input_ids with torch.no_grad(): for _ in range(50): out = model( input_ids=next_token, past_key_values=past_key_values, use_cache=True ) logits = out.logits[:, -1, :] token = torch.argmax(logits, dim=-1, keepdim=True) output_ids = torch.cat([output_ids, token], dim=-1) past_key_values = out.past_key_values next_token = token.to(device) if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id: break return output_ids[:, origin_len:] def get_kv_cache(model, tokenizer, prompt): """Prepares and stores the key-value cache for the initial document/context.""" device = model.model.embed_tokens.weight.device input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) cache = DynamicCache() with torch.no_grad(): _ = model( input_ids=input_ids, past_key_values=cache, use_cache=True ) return cache, input_ids.shape[-1] def clean_up(cache, origin_len): """Trims the cache to only include the original document/context tokens.""" for i in range(len(cache.key_cache)): cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :] cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :] return cache def calculate_cache_size(cache): """Calculate the total memory used by the key-value cache in bytes.""" total_memory = 0 for key in cache.key_cache: total_memory += key.element_size() * key.nelement() for value in cache.value_cache: total_memory += value.element_size() * value.nelement() return total_memory /(1024*1024) @st.cache_resource def load_model_and_tokenizer(): model_name = "GeneZC/MiniChat-1.5-3B" tokenizer = AutoTokenizer.from_pretrained( model_name, use_fast=False, trust_remote_code=True ,token=HF_TOKEN ) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto", trust_remote_code=True ,token=HF_TOKEN ) return model, tokenizer def calculate_cache_size(cache): """ Calculate the total memory used by the key-value cache (past_key_values) in megabytes. Args: cache: The past_key_values object (usually a tuple of (key, value) pairs per layer). Returns: Total memory in megabytes. """ total_memory = 0 for layer_cache in cache: key_tensor, value_tensor = layer_cache total_memory += key_tensor.element_size() * key_tensor.nelement() total_memory += value_tensor.element_size() * value_tensor.nelement() return total_memory / (1024 * 1024) # Convert to MB def clone_cache(cache): new_cache = DynamicCache() for key, value in zip(cache.key_cache, cache.value_cache): new_cache.key_cache.append(key.clone()) new_cache.value_cache.append(value.clone()) return new_cache @st.cache_resource def load_document_and_cache(file_path): try: t2 = time() with open(file_path, 'r') as file: doc_text = file.read() doc_text_count = len(doc_text) max_length = int(1.3 * (doc_text_count * 0.3 + 1)) # Cap the value at 16824 if max_length > 16824: max_length = 16824 print(f" model_max_length set to: {max_length}") model, tokenizer = load_model_and_tokenizer() tokenizer.model_max_length=max_length system_prompt = f""" <|system|> You are a helpful assistant. Provide concise, factual answers based only on the provided context. If the information is not available, respond with: "I'm sorry, I don't have enough information to answer that." <|user|> Context: {doc_text} Question: """.strip() cache, origin_len = get_kv_cache(model, tokenizer, system_prompt) t3 = time() print(f"{t3-t2}") return cache,doc_text, doc_text_count, model, tokenizer except FileNotFoundError: st.error(f"Document file not found at {file_path}") return None, None, None, None # ============================== # Streamlit UI # ============================== # Initialize token counters input_tokens_count = 0 generated_tokens_count = 0 output_tokens_count = 0 # Reset counters with a button if st.button("🔄 Reset Token Counters"): input_tokens_count = 0 generated_tokens_count = 0 output_tokens_count = 0 doc_text = None cache = None model = None tokenizer = None st.success("Token counters have been reset.") st.title("🚀 DeepSeek QA: Supercharged Caching & Memory Dashboard") uploaded_file = st.file_uploader("📝 Upload your document (.txt)", type="txt") # Initialize variables doc_text = None cache = None model = None tokenizer = None if uploaded_file: log = [] # PART 1: File Upload & Save t_start1 = time() temp_file_path = "temp_document.txt" with open(temp_file_path, "wb") as f: f.write(uploaded_file.getvalue()) t_end1 = time() log.append(f"📂 File Upload & Save Time: {t_end1 - t_start1:.2f} s") print(f"📂 File Upload & Save Time: {t_end1 - t_start1:.2f} s") # PART 2: Document and Cache Load t_start2 = time() cache, doc_text,doc_text_count, model, tokenizer = load_document_and_cache(temp_file_path) t_end2 = time() log.append(f"📄 Document & Cache Load Time: {t_end2 - t_start2:.2f} s") print(f"📄 Document & Cache Load Time: {t_end2 - t_start2:.2f} s") # PART 3: Document Preview Display t_start3 = time() with st.expander("📄 Document Preview"): preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text st.text(preview) t_end3 = time() log.append(f"👀 Document Preview Display Time: {t_end3 - t_start3:.2f} s") print(f"👀 Document Preview Display Time: {t_end3 - t_start3:.2f} s") t_start4 = time() # PART 4: Show Basic Info s_cache=calculate_cache_size(cache) t_end4 = time() log.append(f"👀 doc_size_kb Preview Display Time: {t_end4 - t_start4:.2f} s") print(f"👀 doc_size_kb Preview Display Time: {t_end4 - t_start4:.2f} s||||||| size of the cache : {s_cache} MB") #st.info( # f"Document Chars: {len(doc_text)} | Size: {doc_size_kb:.2f} KB | " # f"Cache Size: {cache_size if cache_size == 'N/A' else f'{cache_size:.2f} KB'}" #) # ========================= # User Query and Generation # ========================= query = st.text_input("🔎 Ask a question about the document:") if query and st.button("Generate Answer"): with st.spinner("Generating answer..."): log.append("🚀 Query & Generation Steps:") # PART 4.1: Clone Cache t_start5 = time() current_cache = clone_cache(cache) t_end5 = time() print(f"🔁 Clone Cache Time: {t_end5 - t_start5:.2f} s") log.append(f"🔁 Clone Cache Time: {t_end5 - t_start5:.2f} s") # PART 4.2: Tokenize Prompt t_start6 = time() full_prompt = f""" <|user|> Question: Please provide a clear and concise answer to the question .{query} <|assistant|> """.strip() input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids input_tokens_count += input_ids.shape[-1] t_end6 = time() print(f"✍️ Tokenization Time: {t_end6 - t_start6:.2f} s") log.append(f"✍️ Tokenization Time: {t_end6 - t_start6:.2f} s") # PART 4.3: Generate Answer t_start7 = time() output_ids = generate(model, input_ids, current_cache, max_new_tokens=4) last_generation_time = time() - t_start7 print(f"💡 Generation Time: {last_generation_time:.2f} s") log.append(f"💡 Generation Time: {last_generation_time:.2f} s") generated_tokens_count = output_ids.shape[-1] generated_tokens_count += generated_tokens_count output_tokens_count = generated_tokens_count response = tokenizer.decode(output_ids[0], skip_special_tokens=True) st.success("Answer:") st.write(response) print(f"***************************************************************************************") # Final Info Display st.info( # f"Document Chars: {len(doc_text)} | Size: {doc_size_kb:.2f} KB | " f"Cache Clone Time: {log[-3].split(': ')[1]} | Generation Time: {last_generation_time:.2f} s" ) # ========================= # Show Log # ========================= st.sidebar.header("🕒 Performance Log") for entry in log: st.sidebar.write(entry) # ========================= # Sidebar: Cache Loader # ========================= st.sidebar.header("🛠️ Advanced Options") st.sidebar.write("Load a previously saved cache for instant document context reuse.") if st.sidebar.checkbox("Load saved cache"): cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth") if cache_file: with open("temp_cache.pth", "wb") as f: f.write(cache_file.getvalue()) try: loaded_cache = torch.load("temp_cache.pth") cache = loaded_cache st.sidebar.success("Cache loaded successfully!") except Exception as e: st.sidebar.error(f"Failed to load cache file: {e}")