# ovrelord's picture
# Update app.py
# 0d3044e verified
import gradio as gr
import random
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from ctransformers import AutoModelForCausalLM
import time
import pickle
# Load the LLM (4-bit quantized GGUF weights) and report how long loading took.
start_load = time.time()
llm = AutoModelForCausalLM.from_pretrained(
    "TheBloke/zephyr-7B-alpha-GGUF",
    model_type="mistral",
    model_file="zephyr-7b-alpha.Q4_K_M.gguf",
)
load_seconds = time.time() - start_load
print(f"Model loaded in {load_seconds:.2f} seconds")

# Sentences that the "random" button samples its opening query from.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only ship a pickle that comes from a trusted source.
with open("bc sentences.pkl", mode="rb") as sentence_file:
    bs_sen = pickle.load(sentence_file)

# A single embedding model is shared by both vector stores.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)

# Load the two FAISS indexes with identical options.
# NOTE(review): allow_dangerous_deserialization=True opts in to deserializing
# the stored index from disk - confirm the index files are trusted.
_faiss_load_options = {
    "embeddings": embedding_model,
    "allow_dangerous_deserialization": True,
}
faiss_a = FAISS.load_local("faiss/index_a", **_faiss_load_options)
faiss_b = FAISS.load_local("faiss/index_b", **_faiss_load_options)

# One retriever per index, each asked for the 10 best matches per query.
_retriever_search = {"k": 10}
retriever_a = faiss_a.as_retriever(search_kwargs=dict(_retriever_search))
retriever_b = faiss_b.as_retriever(search_kwargs=dict(_retriever_search))
# System prompts: both speakers receive the same instructions and differ only
# in the decade they represent, so the text is built from a single template.
_GUIDANCE_TEMPLATE = "\n".join(
    [
        "You are a {era} trade union representative.",
        "Use the retrieved sentences as your knowledge base.",
        "Speak persuasively as if you are arguing with a fellow trade unionist.",
        "Do not use your own knowledge.",
        "Vary your sentence structures, do not repeat phrases.",
        "Respond with 2 sentences maximum.",
        "Do not use the names of any trade union members or name any geographic locations.",
    ]
)
guidance_a = _GUIDANCE_TEMPLATE.format(era="1920s")
guidance_b = _GUIDANCE_TEMPLATE.format(era="2020s")
def generate_text(prompt: str) -> str:
    """Generate a completion for ``prompt`` with the module-level ``llm``.

    Args:
        prompt: Complete prompt text handed to the model.

    Returns:
        The model output with leading/trailing whitespace stripped. If the
        model call raises, a bracketed ``[Error generating response: ...]``
        string is returned instead so the caller can still display something.
    """
    try:
        # ctransformers' LLM.__call__ only accepts its own keyword-only
        # generation options. The transformers-style ``do_sample`` flag is not
        # one of them and made every call fail with TypeError, so it is
        # dropped; sampling is governed by ``temperature``.
        output = llm(prompt, max_new_tokens=150, temperature=0.7)
        return output.strip()
    except Exception as e:
        # Best-effort boundary: report the failure as text rather than
        # crashing the UI callback.
        return f"[Error generating response: {e}]"
# Dialogue turn generator
def generate_turn(query, retriever, prompt_text):
    """Produce one speaker's reply to ``query`` using retrieved excerpts.

    Args:
        query: The message the speaker is replying to.
        retriever: Retriever whose ``invoke(query)`` returns documents that
            expose a ``page_content`` attribute.
        prompt_text: System guidance placed at the top of the prompt.

    Returns:
        Whatever ``generate_text`` returns for the assembled prompt.
    """
    # ``invoke`` is the supported retriever entry point;
    # ``get_relevant_documents`` is deprecated in current LangChain releases.
    docs = retriever.invoke(query)
    context = "\n".join(doc.page_content for doc in docs)
    # State the real number of excerpts: the prompt used to hard-code "5"
    # regardless of how many documents the retriever returned.
    prompt = (
        f"{prompt_text}\n\n"
        f'You just heard the following message:\n"{query}"\n\n'
        f"Here are {len(docs)} excerpts from your documents that may help you reply:\n"
        f"{context}\n\n"
        "Respond to the message above based ONLY on this information, "
        "and speak as if you were in a real conversation."
    )
    return generate_text(prompt)
# Conversation function
def simulate_conversation(query, history):
    """Run one exchange: speaker A answers ``query``, speaker B answers A.

    Args:
        query: Opening message; a blank or whitespace-only value is ignored.
        history: List of previous chat entries to extend.

    Returns:
        A ``("", history)`` pair. The empty string is the first output value;
        the second is ``history`` unchanged for a blank query, otherwise a new
        list with one ``(user_line, speakers_block)`` tuple appended.
    """
    if query.strip():
        reply_a = generate_turn(query, retriever_a, guidance_a)
        # Speaker B reacts to what speaker A said, not to the original query.
        reply_b = generate_turn(reply_a, retriever_b, guidance_b)
        user_line = f"User: {query}"
        speakers_block = f"Speaker A: {reply_a}\nSpeaker B: {reply_b}"
        return "", history + [(user_line, speakers_block)]
    return "", history
# Random prompt from bs_sen
def random_query(history):
    """Start an exchange from a randomly chosen entry of ``bs_sen``.

    Args:
        history: List of previous chat entries, forwarded unchanged.

    Returns:
        The result of ``simulate_conversation`` for the sampled query.
    """
    return simulate_conversation(random.choice(bs_sen), history)
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Dialogue Simulator between Two Union Eras")
    chatbot = gr.Chatbot()
    txt = gr.Textbox(label="Ask a question")
    ask_btn = gr.Button("Ask")
    rand_btn = gr.Button("Random Conversation")
    # The handlers return the extended history to `chatbot` only. Feeding the
    # chatbot's own value back in as the history input lets exchanges
    # accumulate; the previous separate gr.State was never written to, so it
    # stayed empty and every click replaced the chat with just the newest
    # exchange.
    ask_btn.click(
        fn=simulate_conversation,
        inputs=[txt, chatbot],
        outputs=[txt, chatbot],
    )
    rand_btn.click(fn=random_query, inputs=chatbot, outputs=[txt, chatbot])
# Startup self-checks. Each probe prints its outcome and catches any failure,
# so a broken resource is reported without preventing the launch below.

# FAISS: run a one-result similarity search against each index.
_faiss_checks = (
    (faiss_a, "FAISS A loaded successfully:", "[ERROR] FAISS A failed to load:"),
    (faiss_b, "FAISS B loaded successfully:", "[ERROR] FAISS B failed to load:"),
)
for _index, _ok_message, _error_message in _faiss_checks:
    try:
        test_doc = _index.similarity_search("test", k=1)
        print(_ok_message, test_doc)
    except Exception as exc:
        print(_error_message, exc)

# Pickled sentences: report how many there are and show one sample.
# NOTE(review): these messages say "bc_sentences.pkl" while the file opened
# earlier in this script is "bc sentences.pkl" - confirm which name is right.
try:
    sentence_count = len(bs_sen)
    print(f"Loaded {sentence_count} sentences from bc_sentences.pkl")
    print("Sample sentence:", random.choice(bs_sen))
except Exception as exc:
    print("[ERROR] bc_sentences.pkl failed to load:", exc)

# LLM: a short generation to confirm the model responds.
try:
    warmup_response = llm("Hello", max_new_tokens=10)
    print("LLM warmup response:", warmup_response)
except Exception as exc:
    print("[ERROR] LLM failed on warmup:", exc)

demo.launch()