Spaces:
Sleeping
Sleeping
import copy
import io
import os
from time import time

import pandas as pd
import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
# Hugging Face access token, read from the NEX_MODEL environment variable and
# passed to every from_pretrained call. The app refuses to start without it.
HF_TOKEN = os.environ.get("NEX_MODEL")
if not HF_TOKEN:
    raise ValueError(
        "Hugging Face token not found. Please set the 'NEX_MODEL' environment variable."
    )
# ==============================
# Helper: Human-readable bytes
def sizeof_fmt(num, suffix="B"):
    """Render a byte count with a 1024-step unit prefix.

    Divides by 1024 through the prefixes "", K, M, G, T until the magnitude
    drops below 1024; anything still that large after T is reported in P.
    Two decimals are always shown, e.g. ``sizeof_fmt(1536) == "1.50 KB"``.
    """
    scaled = num
    for prefix in ("", "K", "M", "G", "T"):
        if abs(scaled) < 1024.0:
            break
        scaled /= 1024.0
    else:
        # Ran out of prefixes up to T without dropping below 1024.
        return f"{scaled:.2f} P{suffix}"
    return f"{scaled:3.2f} {prefix}{suffix}"
# ==============================
# Core Model and Caching Logic
# ==============================
def generate(model, input_ids, past_key_values, max_new_tokens):
    """Greedily decode up to ``max_new_tokens`` tokens on top of a KV cache.

    The first forward pass feeds all of ``input_ids`` after whatever is already
    in ``past_key_values``; each later pass feeds only the token just chosen
    (argmax of the last-position logits). Decoding stops early when the model
    emits one of its configured EOS token ids.

    Args:
        model: causal LM whose ``model.model.embed_tokens`` weight determines
            the input device and whose ``config.eos_token_id`` (an int, a list
            of ints, or None) ends decoding.
        input_ids: prompt token ids shaped ``(batch, seq)``. The EOS check uses
            ``token.item()``, so it only works for batch size 1.
        past_key_values: cache to continue from; the model may extend it in
            place, so pass a copy if it must be reused.
        max_new_tokens: maximum number of tokens to produce; ``<= 0`` returns
            an empty ``(batch, 0)`` tensor.

    Returns:
        Tensor ``(batch, n_generated)`` with only the newly generated ids.
    """
    device = model.model.embed_tokens.weight.device
    input_ids = input_ids.to(device)

    # Normalise eos_token_id: configs use None, a single id, or a list of ids.
    eos = model.config.eos_token_id
    if eos is None:
        eos_ids = set()
    elif isinstance(eos, (list, tuple)):
        eos_ids = set(eos)
    else:
        eos_ids = {eos}

    generated = []
    next_token = input_ids
    with torch.no_grad():
        # Honour the caller's budget (previously hard-coded to 50 steps).
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True,
            )
            token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
            generated.append(token)
            past_key_values = out.past_key_values
            next_token = token.to(device)
            # Only call .item() when there is something to compare against,
            # so batch > 1 still works when the model has no EOS configured.
            if eos_ids and token.item() in eos_ids:
                break

    if not generated:
        return input_ids[:, :0]
    return torch.cat(generated, dim=-1)
def get_kv_cache(model, tokenizer, prompt):
    """Prefill a brand-new DynamicCache by running ``prompt`` through ``model`` once.

    The prompt is tokenized, moved to the device of the model's input
    embeddings, and passed forward with ``use_cache=True`` so the model fills
    the cache object in place. The forward output itself is discarded.

    Returns:
        ``(cache, prompt_len)`` where ``prompt_len`` is the number of prompt
        tokens that are now stored in the cache.
    """
    target_device = model.model.embed_tokens.weight.device
    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_ids = encoded.input_ids.to(target_device)
    kv_cache = DynamicCache()
    with torch.no_grad():
        model(input_ids=prompt_ids, past_key_values=kv_cache, use_cache=True)
    prompt_len = prompt_ids.shape[-1]
    return kv_cache, prompt_len
def clean_up(cache, origin_len):
    """Crop every cached layer back to its first ``origin_len`` positions.

    Slices axis 2 of each tensor in ``cache.key_cache`` / ``cache.value_cache``
    (assumed to be the sequence axis of ``(batch, heads, seq, head_dim)``
    tensors -- TODO confirm for this model), replacing the list entries in
    place. The same cache object is returned. Not called elsewhere in this
    file as shown.
    """
    keys, values = cache.key_cache, cache.value_cache
    for layer_idx, layer_key in enumerate(keys):
        keys[layer_idx] = layer_key[:, :, :origin_len, :]
        values[layer_idx] = values[layer_idx][:, :, :origin_len, :]
    return cache
def calculate_cache_size(cache):
    """Return the memory held by a cache's key and value tensors, in MiB.

    Sums ``element_size() * nelement()`` over every tensor in
    ``cache.key_cache`` and ``cache.value_cache`` and divides by 1024 * 1024.
    (The previous docstring said the result was in bytes; it never was.)

    NOTE(review): a second ``calculate_cache_size`` later in this module, which
    iterates ``(key, value)`` pairs instead, rebinds this name at import time,
    so callers of the module get that version, not this one.
    """
    tensors = [*cache.key_cache, *cache.value_cache]
    total_bytes = sum(t.element_size() * t.nelement() for t in tensors)
    return total_bytes / (1024 * 1024)
@st.cache_resource
def load_model_and_tokenizer():
    """Load the GeneZC/MiniChat-1.5-3B tokenizer and model once per server process.

    Streamlit re-executes the script on every widget interaction. Without
    ``st.cache_resource`` each rerun reloaded the full 3B-parameter checkpoint;
    with it, the same ``(model, tokenizer)`` objects are reused across reruns
    and sessions. Because they are shared, per-call tweaks (such as
    ``tokenizer.model_max_length``) should be set right before use.

    The model is loaded in float16 when CUDA is available and float32
    otherwise, with ``device_map="auto"`` placement. Both downloads authenticate
    with ``HF_TOKEN``.

    Returns:
        ``(model, tokenizer)``
    """
    model_name = "GeneZC/MiniChat-1.5-3B"
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        use_fast=False,
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # Half precision only helps on GPU; CPU inference stays in float32.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True,
        token=HF_TOKEN,
    )
    return model, tokenizer
def calculate_cache_size(cache):
    """Return the footprint of a KV cache's key/value tensors in megabytes (MiB).

    Args:
        cache: an iterable yielding one ``(key, value)`` tensor pair per layer,
            such as a legacy ``past_key_values`` tuple. Callers in this module
            pass the DynamicCache built by ``get_kv_cache``, which is assumed to
            iterate the same way -- confirm for the installed transformers.

    Returns:
        Total bytes of all key and value tensors divided by 1024 * 1024.
    """
    def _nbytes(tensor):
        # bytes per element times number of elements
        return tensor.element_size() * tensor.nelement()

    total_bytes = sum(_nbytes(key) + _nbytes(value) for key, value in cache)
    return total_bytes / (1024 * 1024)
def clone_cache(cache):
    """Return an independent deep copy of ``cache``.

    ``generate`` extends whatever cache it is handed, so each question works on
    a copy and the prefilled document cache stays reusable for the next one.

    ``copy.deepcopy`` (the approach shown in the transformers cache docs for
    reusing a prefilled prompt cache) copies every tensor plus any other state
    the cache object carries. It also avoids hand-building a new cache through
    the ``key_cache``/``value_cache`` lists, which are an internal detail that
    newer transformers releases have restructured.
    """
    return copy.deepcopy(cache)
def load_document_and_cache(file_path):
    """Read a text document, load the model, and prefill a KV cache with it.

    The document is embedded in a system/user prompt that ends with
    "Question:", and that prompt is run through the model once via
    ``get_kv_cache`` so later questions only process their own tokens.

    Args:
        file_path: path of the text file to read (decoded as UTF-8; undecodable
            bytes are replaced rather than aborting the upload).

    Returns:
        ``(cache, doc_text, doc_text_count, model, tokenizer)`` where
        ``doc_text_count`` is the document length in characters. If the file
        does not exist, an error is shown on the page and a 5-tuple of ``None``
        is returned, so callers can always unpack five values.
    """
    t2 = time()
    # Keep the try narrow: only a missing *document* should be reported as
    # such; errors from model loading or prefill propagate unchanged.
    try:
        with open(file_path, "r", encoding="utf-8", errors="replace") as file:
            doc_text = file.read()
    except FileNotFoundError:
        st.error(f"Document file not found at {file_path}")
        return None, None, None, None, None

    doc_text_count = len(doc_text)
    # Length budget that scales with the character count (x0.3, +1, x1.3),
    # presumably a rough chars-to-tokens estimate with headroom. Capped at
    # 16824. NOTE(review): 16824 may be a typo for 16384 -- confirm.
    max_length = min(int(1.3 * (doc_text_count * 0.3 + 1)), 16824)
    print(f" model_max_length set to: {max_length}")

    model, tokenizer = load_model_and_tokenizer()
    tokenizer.model_max_length = max_length
    system_prompt = (
        "<|system|>\n"
        "You are a helpful assistant. Provide concise, factual answers based only on the provided context.\n"
        "If the information is not available, respond with: \"I'm sorry, I don't have enough information to answer that.\"\n"
        "<|user|>\n"
        "Context:\n"
        f"{doc_text}\n"
        "Question:"
    )
    cache, _prompt_len = get_kv_cache(model, tokenizer, system_prompt)
    t3 = time()
    print(f"{t3-t2}")
    return cache, doc_text, doc_text_count, model, tokenizer
# ==============================
# Streamlit UI
# ==============================
# Streamlit re-executes this whole script on every widget interaction (typing a
# question, clicking a button). State that has to survive those reruns -- the
# token counters and the prefilled document cache -- therefore lives in
# st.session_state. Plain module-level variables were reset on every click, and
# the document was re-prefilled each time.
TOKEN_COUNTER_KEYS = ("input_tokens_count", "generated_tokens_count", "output_tokens_count")
# Answer-length budget. `generate` historically ignored the value passed here
# and always ran 50 steps, so 50 keeps answers as long as they have been.
MAX_NEW_TOKENS = 50

# Initialize token counters (once per browser session)
for counter_key in TOKEN_COUNTER_KEYS:
    if counter_key not in st.session_state:
        st.session_state[counter_key] = 0

# Reset counters with a button
if st.button("π Reset Token Counters"):
    for counter_key in TOKEN_COUNTER_KEYS:
        st.session_state[counter_key] = 0
    # Also forget the prefilled document so it is rebuilt from scratch below.
    if "doc_state" in st.session_state:
        del st.session_state["doc_state"]
    st.success("Token counters have been reset.")

st.title("π DeepSeek QA: Supercharged Caching & Memory Dashboard")
uploaded_file = st.file_uploader("π Upload your document (.txt)", type="txt")

# Initialize variables
doc_text = None
cache = None
model = None
tokenizer = None

if uploaded_file:
    # Identify the upload by its content so the expensive prefill only runs
    # again when a different document is uploaded.
    doc_key = hash(uploaded_file.getvalue())
    doc_state = st.session_state.get("doc_state")
    if doc_state is None or doc_state["doc_key"] != doc_key:
        build_log = []
        # PART 1: File Upload & Save
        # NOTE(review): this fixed path is shared by every session on the server.
        t_start1 = time()
        temp_file_path = "temp_document.txt"
        with open(temp_file_path, "wb") as f:
            f.write(uploaded_file.getvalue())
        t_end1 = time()
        build_log.append(f"π File Upload & Save Time: {t_end1 - t_start1:.2f} s")
        print(f"π File Upload & Save Time: {t_end1 - t_start1:.2f} s")
        # PART 2: Document and Cache Load (prefill the document once)
        t_start2 = time()
        loaded = load_document_and_cache(temp_file_path)
        t_end2 = time()
        if loaded[0] is None:
            # The loader returns None placeholders (after showing st.error) when
            # it cannot read the file; nothing below can run without a document.
            st.stop()
        build_log.append(f"π Document & Cache Load Time: {t_end2 - t_start2:.2f} s")
        print(f"π Document & Cache Load Time: {t_end2 - t_start2:.2f} s")
        doc_state = {"doc_key": doc_key, "loaded": loaded, "log": build_log}
        st.session_state["doc_state"] = doc_state

    cache, doc_text, doc_text_count, model, tokenizer = doc_state["loaded"]
    # This run's log starts with the timings recorded when the document was
    # last prefilled.
    log = list(doc_state["log"])

    # PART 3: Document Preview Display
    t_start3 = time()
    with st.expander("π Document Preview"):
        preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        st.text(preview)
    t_end3 = time()
    log.append(f"π Document Preview Display Time: {t_end3 - t_start3:.2f} s")
    print(f"π Document Preview Display Time: {t_end3 - t_start3:.2f} s")

    # PART 4: Show Basic Info
    t_start4 = time()
    s_cache = calculate_cache_size(cache)
    t_end4 = time()
    log.append(f"π doc_size_kb Preview Display Time: {t_end4 - t_start4:.2f} s")
    print(f"π doc_size_kb Preview Display Time: {t_end4 - t_start4:.2f} s||||||| size of the cache : {s_cache} MB")

    # =========================
    # User Query and Generation
    # =========================
    query = st.text_input("π Ask a question about the document:")
    if query and st.button("Generate Answer"):
        with st.spinner("Generating answer..."):
            log.append("π Query & Generation Steps:")

            # PART 4.1: Clone Cache -- generation extends the cache it is given,
            # so work on a copy and keep the document prefill reusable.
            t_start5 = time()
            current_cache = clone_cache(cache)
            t_end5 = time()
            clone_time = t_end5 - t_start5
            print(f"π Clone Cache Time: {clone_time:.2f} s")
            log.append(f"π Clone Cache Time: {clone_time:.2f} s")

            # PART 4.2: Tokenize Prompt
            t_start6 = time()
            full_prompt = (
                "<|user|>\n"
                f"Question: Please provide a clear and concise answer to the question .{query}\n"
                "<|assistant|>"
            )
            # The prompt continues the cached document context, so don't let the
            # tokenizer prepend another BOS/special-token prefix mid-sequence.
            input_ids = tokenizer(
                full_prompt, return_tensors="pt", add_special_tokens=False
            ).input_ids
            st.session_state["input_tokens_count"] += input_ids.shape[-1]
            t_end6 = time()
            print(f"βοΈ Tokenization Time: {t_end6 - t_start6:.2f} s")
            log.append(f"βοΈ Tokenization Time: {t_end6 - t_start6:.2f} s")

            # PART 4.3: Generate Answer
            t_start7 = time()
            output_ids = generate(model, input_ids, current_cache, max_new_tokens=MAX_NEW_TOKENS)
            last_generation_time = time() - t_start7
            print(f"π‘ Generation Time: {last_generation_time:.2f} s")
            log.append(f"π‘ Generation Time: {last_generation_time:.2f} s")

            # Accumulate (the old code overwrote the count and then doubled it).
            st.session_state["generated_tokens_count"] += output_ids.shape[-1]
            st.session_state["output_tokens_count"] = st.session_state["generated_tokens_count"]

            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            st.success("Answer:")
            st.write(response)
            print("***************************************************************************************")

            # Final Info Display
            st.info(
                f"Cache Clone Time: {clone_time:.2f} s | Generation Time: {last_generation_time:.2f} s"
            )

    # =========================
    # Show Log
    # =========================
    st.sidebar.header("π Performance Log")
    for entry in log:
        st.sidebar.write(entry)
# =========================
# Sidebar: Cache Loader
# =========================
# NOTE(review): this section runs after the question/answer flow above has
# already executed for this rerun, so the `cache` assigned here is not used
# for generation. It needs to be wired into the document state to take effect.
st.sidebar.header("π οΈ Advanced Options")
st.sidebar.write("Load a previously saved cache for instant document context reuse.")
if st.sidebar.checkbox("Load saved cache"):
    cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth")
    if cache_file:
        try:
            # Load straight from the upload buffer instead of writing a fixed
            # "temp_cache.pth" on disk that concurrent sessions would overwrite.
            # SECURITY(review): torch.load unpickles its input. On torch versions
            # where weights_only defaults to False (before 2.6), a crafted upload
            # can execute arbitrary code. Restrict this to trusted files, or pass
            # weights_only=True if saved caches are plain tensor tuples.
            loaded_cache = torch.load(io.BytesIO(cache_file.getvalue()))
            cache = loaded_cache
            st.sidebar.success("Cache loaded successfully!")
        except Exception as e:
            # UI boundary: surface any load failure to the user instead of crashing.
            st.sidebar.error(f"Failed to load cache file: {e}")