# cag_new_model / app.py
# Author: kouki321; last commit: "Update app.py" (5245a01, verified)
import copy
import io
import os
from time import time

import pandas as pd
import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache
# Hugging Face access token, taken from the NEX_MODEL environment variable.
# Importing this module fails fast when the variable is unset or empty.
HF_TOKEN = os.environ.get("NEX_MODEL")
if HF_TOKEN in (None, ""):
    raise ValueError("Hugging Face token not found. Please set the 'NEX_MODEL' environment variable.")
# ==============================
# Helper: Human-readable bytes
def sizeof_fmt(num, suffix="B"):
# Formats bytes as human-readable (e.g. 1.5 GB)
for unit in ["", "K", "M", "G", "T"]:
if abs(num) < 1024.0:
return f"{num:3.2f} {unit}{suffix}"
num /= 1024.0
return f"{num:.2f} P{suffix}"
# ==============================
# Core Model and Caching Logic
# ==============================
def generate(model, input_ids, past_key_values, max_new_tokens):
    """Greedily decode up to ``max_new_tokens`` tokens on top of a KV cache.

    The prompt ``input_ids`` is fed once on top of ``past_key_values``. Each
    later step feeds only the token chosen in the previous step, so earlier
    positions are never re-processed.

    Args:
        model: causal LM whose forward accepts ``input_ids``,
            ``past_key_values`` and ``use_cache`` and returns an object with
            ``logits`` and ``past_key_values``.
        input_ids: prompt token ids, shape (batch, seq). When an EOS id is
            configured, the stop check uses ``token.item()``, which requires
            batch size 1.
        past_key_values: cache holding the already-processed context. The
            forward pass may extend it in place, so pass a copy if the
            original must be reused.
        max_new_tokens: maximum number of tokens to generate; values <= 0
            generate nothing.

    Returns:
        Tensor of shape (batch, n_generated) containing only the new tokens.
        If an EOS token is produced, it is the last column.
    """
    device = model.model.embed_tokens.weight.device
    input_ids = input_ids.to(device)

    # Configs may store a single EOS id or a list of them.
    eos = model.config.eos_token_id
    if eos is None:
        eos_ids = frozenset()
    elif isinstance(eos, (list, tuple, set, frozenset)):
        eos_ids = frozenset(eos)
    else:
        eos_ids = frozenset((eos,))

    generated = []
    next_token = input_ids
    with torch.no_grad():
        # Honour the caller's limit on the number of new tokens.
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
            # Keep every generated token on the input device. With
            # device_map="auto" the logits may live on another device.
            next_token = token.to(device)
            generated.append(next_token)
            past_key_values = out.past_key_values
            if eos_ids and next_token.item() in eos_ids:
                break

    if not generated:
        return input_ids.new_empty((input_ids.shape[0], 0))
    return torch.cat(generated, dim=-1)
def get_kv_cache(model, tokenizer, prompt):
    """Run ``prompt`` through ``model`` once to populate a fresh DynamicCache.

    The empty cache is handed to the forward call and then returned. The
    model's logits are discarded; only the cache contents matter.

    Returns:
        (cache, prompt_length), where prompt_length is the number of prompt
        tokens (presumably the sequence length the cache now covers).
    """
    target_device = model.model.embed_tokens.weight.device
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    prompt_ids = prompt_ids.to(target_device)
    kv_cache = DynamicCache()
    with torch.no_grad():
        model(input_ids=prompt_ids, past_key_values=kv_cache, use_cache=True)
    prompt_length = prompt_ids.shape[-1]
    return kv_cache, prompt_length
def clean_up(cache, origin_len):
    """Trim a KV cache in place so it only covers the original document/context.

    Args:
        cache: the KV cache to trim. Objects exposing ``crop`` (such as
            transformers' DynamicCache) are trimmed through that method.
            Otherwise the cache must expose per-layer ``key_cache`` and
            ``value_cache`` lists of 4-D tensors whose dim 2 is the token axis.
        origin_len: number of leading token positions to keep.

    Returns:
        The same ``cache`` object, after trimming.
    """
    crop = getattr(cache, "crop", None)
    if callable(crop):
        # crop() is DynamicCache's public trimming API. It avoids writing into
        # the key_cache/value_cache lists, whose layout is internal and has
        # changed across transformers versions.
        crop(origin_len)
        return cache
    for layer_idx, (keys, values) in enumerate(zip(cache.key_cache, cache.value_cache)):
        cache.key_cache[layer_idx] = keys[:, :, :origin_len, :]
        cache.value_cache[layer_idx] = values[:, :, :origin_len, :]
    return cache
def calculate_cache_size(cache):
    """Return the memory held by a KV cache's tensors, in MiB (bytes / 1024**2).

    Sums ``element_size() * nelement()`` over every tensor in
    ``cache.key_cache`` and ``cache.value_cache``.

    NOTE: this module defines ``calculate_cache_size`` again further down.
    That later definition replaces this one at import time, so the app's
    calls go to that version.
    """
    tensors = list(cache.key_cache) + list(cache.value_cache)
    total_bytes = sum(t.element_size() * t.nelement() for t in tensors)
    return total_bytes / (1024 * 1024)
@st.cache_resource
def load_model_and_tokenizer():
    """Load the MiniChat-1.5-3B model and its slow tokenizer from the Hub.

    Wrapped in st.cache_resource, so repeated calls reuse the same objects.
    Both downloads authenticate with HF_TOKEN and trust remote code. Weights
    are float16 when CUDA is available, float32 otherwise, and are placed
    with device_map="auto".

    Returns:
        (model, tokenizer)
    """
    repo_id = "GeneZC/MiniChat-1.5-3B"
    hub_kwargs = {"trust_remote_code": True, "token": HF_TOKEN}
    tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=False, **hub_kwargs)
    weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=weight_dtype,
        device_map="auto",
        **hub_kwargs,
    )
    return model, tokenizer
def calculate_cache_size(cache):
    """Return the memory used by a KV cache's key and value tensors, in MiB.

    ``cache`` must iterate as one (key_tensor, value_tensor) pair per layer.
    That covers a legacy tuple of pairs; the app also passes the cache built
    by get_kv_cache here.

    Because this definition comes later in the module, it replaces the
    earlier ``calculate_cache_size`` and is the one the app calls.
    """
    size_bytes = sum(
        tensor.element_size() * tensor.nelement()
        for keys, values in cache
        for tensor in (keys, values)
    )
    return size_bytes / (1024 * 1024)
def clone_cache(cache):
    """Return an independent deep copy of a KV cache.

    The copy owns its own tensors, so running the model on it (which may
    extend it) leaves the original document cache untouched.

    copy.deepcopy is used, as in the transformers prompt-reuse docs. The
    alternative is rebuilding a DynamicCache by appending to its
    key_cache/value_cache lists, but their layout is internal and has
    changed across transformers versions, which could silently yield an
    incomplete clone. deepcopy also carries over any other cache state.
    """
    return copy.deepcopy(cache)
def _read_document(file_path):
    """Return the full text of ``file_path`` decoded as UTF-8.

    Undecodable bytes are replaced rather than raising, because the file is a
    user upload of unknown encoding.
    """
    with open(file_path, "r", encoding="utf-8", errors="replace") as file:
        return file.read()


@st.cache_resource
def _build_document_cache(doc_text):
    """Prefill the KV cache for ``doc_text``; memoised per distinct document text.

    Returns:
        (cache, doc_text, doc_text_count, model, tokenizer)
    """
    t2 = time()
    doc_text_count = len(doc_text)
    # Length hint derived from the character count (presumably about 0.3
    # tokens per character plus 30% headroom), capped at 16824.
    # NOTE(review): 16824 may have been meant as 16384 -- confirm.
    max_length = min(int(1.3 * (doc_text_count * 0.3 + 1)), 16824)
    print(f" model_max_length set to: {max_length}")
    model, tokenizer = load_model_and_tokenizer()
    # This mutates the tokenizer object shared through load_model_and_tokenizer's cache.
    tokenizer.model_max_length = max_length
    system_prompt = "\n".join((
        "<|system|>",
        "You are a helpful assistant. Provide concise, factual answers based only on the provided context.",
        "If the information is not available, respond with: \"I'm sorry, I don't have enough information to answer that.\"",
        "<|user|>",
        "Context:",
        doc_text,
        "Question:",
    ))
    cache, _origin_len = get_kv_cache(model, tokenizer, system_prompt)
    t3 = time()
    print(f"{t3-t2}")
    return cache, doc_text, doc_text_count, model, tokenizer


def load_document_and_cache(file_path):
    """Read ``file_path`` and return its prefilled KV cache plus context.

    The expensive prefill is cached on the document *text*, not on
    ``file_path``. The app always saves uploads to the same path, so keying
    on the path would return the first document's cache for every later
    upload.

    Returns:
        (cache, doc_text, doc_text_count, model, tokenizer). If the file does
        not exist, an error is shown in the UI and five ``None`` values are
        returned, so callers can always unpack five items.
    """
    try:
        doc_text = _read_document(file_path)
    except FileNotFoundError:
        st.error(f"Document file not found at {file_path}")
        return None, None, None, None, None
    return _build_document_cache(doc_text)
# ==============================
# Streamlit UI
# ==============================
# Maximum number of tokens generated for each answer.
MAX_NEW_TOKENS = 50

# Streamlit re-runs this whole script on every interaction, so plain
# module-level counters would go back to 0 on each rerun and never
# accumulate. Session state survives reruns within a browser session.
TOKEN_COUNTER_KEYS = ("input_tokens_count", "generated_tokens_count", "output_tokens_count")
for counter_key in TOKEN_COUNTER_KEYS:
    if counter_key not in st.session_state:
        st.session_state[counter_key] = 0

# Reset counters with a button
if st.button("πŸ”„ Reset Token Counters"):
    for counter_key in TOKEN_COUNTER_KEYS:
        st.session_state[counter_key] = 0
    st.success("Token counters have been reset.")

st.title("πŸš€ DeepSeek QA: Supercharged Caching & Memory Dashboard")
uploaded_file = st.file_uploader("πŸ“ Upload your document (.txt)", type="txt")

# Per-run state; filled in only once a document has loaded successfully.
doc_text = None
cache = None
model = None
tokenizer = None
log = []

if uploaded_file:
    # PART 1: File Upload & Save
    # NOTE(review): this fixed path is shared by every session of the app,
    # so concurrent uploads can overwrite each other's file.
    t_start1 = time()
    temp_file_path = "temp_document.txt"
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getvalue())
    t_end1 = time()
    log.append(f"πŸ“‚ File Upload & Save Time: {t_end1 - t_start1:.2f} s")
    print(f"πŸ“‚ File Upload & Save Time: {t_end1 - t_start1:.2f} s")

    # PART 2: Document and Cache Load
    t_start2 = time()
    loaded = load_document_and_cache(temp_file_path)
    t_end2 = time()
    log.append(f"πŸ“„ Document & Cache Load Time: {t_end2 - t_start2:.2f} s")
    print(f"πŸ“„ Document & Cache Load Time: {t_end2 - t_start2:.2f} s")
    # On failure the loader has already shown an st.error and returned a
    # tuple of Nones, so check it before unpacking the five result values.
    if loaded and loaded[0] is not None:
        cache, doc_text, doc_text_count, model, tokenizer = loaded

if cache is not None:
    # PART 3: Document Preview Display
    t_start3 = time()
    with st.expander("πŸ“„ Document Preview"):
        preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        st.text(preview)
    t_end3 = time()
    log.append(f"πŸ‘€ Document Preview Display Time: {t_end3 - t_start3:.2f} s")
    print(f"πŸ‘€ Document Preview Display Time: {t_end3 - t_start3:.2f} s")

    # PART 4: Cache size
    t_start4 = time()
    s_cache = calculate_cache_size(cache)
    t_end4 = time()
    log.append(f"πŸ‘€ Cache Size Calculation Time: {t_end4 - t_start4:.2f} s")
    print(f"πŸ‘€ Cache Size Calculation Time: {t_end4 - t_start4:.2f} s||||||| size of the cache : {s_cache} MB")

    # =========================
    # User Query and Generation
    # =========================
    query = st.text_input("πŸ”Ž Ask a question about the document:")
    if query and st.button("Generate Answer"):
        with st.spinner("Generating answer..."):
            log.append("πŸš€ Query & Generation Steps:")

            # PART 4.1: Clone Cache
            # The forward pass extends whatever cache it is given, so answer
            # on a copy and keep the document cache reusable.
            t_start5 = time()
            current_cache = clone_cache(cache)
            clone_time = time() - t_start5
            print(f"πŸ” Clone Cache Time: {clone_time:.2f} s")
            log.append(f"πŸ” Clone Cache Time: {clone_time:.2f} s")

            # PART 4.2: Tokenize Prompt
            t_start6 = time()
            full_prompt = "\n".join((
                "<|user|>",
                f"Question: Please provide a clear and concise answer to the question .{query}",
                "<|assistant|>",
            ))
            # This prompt continues the sequence already stored in the cache,
            # so the tokenizer must not add special tokens (e.g. a second BOS).
            input_ids = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=False).input_ids
            st.session_state["input_tokens_count"] += input_ids.shape[-1]
            t_end6 = time()
            print(f"✍️ Tokenization Time: {t_end6 - t_start6:.2f} s")
            log.append(f"✍️ Tokenization Time: {t_end6 - t_start6:.2f} s")

            # PART 4.3: Generate Answer
            t_start7 = time()
            output_ids = generate(model, input_ids, current_cache, max_new_tokens=MAX_NEW_TOKENS)
            last_generation_time = time() - t_start7
            print(f"πŸ’‘ Generation Time: {last_generation_time:.2f} s")
            log.append(f"πŸ’‘ Generation Time: {last_generation_time:.2f} s")

            # Add this answer's tokens to the running per-session totals.
            new_token_count = output_ids.shape[-1]
            st.session_state["generated_tokens_count"] += new_token_count
            st.session_state["output_tokens_count"] += new_token_count

            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            st.success("Answer:")
            st.write(response)
            print("***************************************************************************************")

        # Final Info Display
        st.info(f"Cache Clone Time: {clone_time:.2f} s | Generation Time: {last_generation_time:.2f} s")

# =========================
# Show Log
# =========================
if uploaded_file:
    st.sidebar.header("πŸ•’ Performance Log")
    for entry in log:
        st.sidebar.write(entry)
# =========================
# Sidebar: Cache Loader
# =========================
st.sidebar.header("πŸ› οΈ Advanced Options")
st.sidebar.write("Load a previously saved cache for instant document context reuse.")
if st.sidebar.checkbox("Load saved cache"):
    cache_file = st.sidebar.file_uploader("Upload saved cache file", type="pth")
    if cache_file:
        try:
            # SECURITY(review): torch.load unpickles the uploaded bytes.
            # - With weights_only=False (torch's default before 2.6), a
            #   crafted .pth can execute arbitrary code.
            # - With weights_only=True (the default from 2.6), a pickled cache
            #   object is likely rejected.
            # Only accept files from trusted sources.
            # The upload is read from memory, so no shared on-disk temp file
            # is needed.
            loaded_cache = torch.load(io.BytesIO(cache_file.getvalue()))
            # NOTE(review): this runs after the question-answering part of the
            # script, so the loaded cache does not affect answers in this run
            # and is lost on the next rerun.
            cache = loaded_cache
            st.sidebar.success("Cache loaded successfully!")
        except Exception as e:
            st.sidebar.error(f"Failed to load cache file: {e}")