from huggingface_hub import HfApi
import tempfile
import streamlit as st
import pandas as pd
import io, os
import subprocess
import pdfplumber
from lxml import etree
from bs4 import BeautifulSoup
from PyPDF2 import PdfReader
from langchain_community.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from langchain_openai import ChatOpenAI
from langchain.agents import initialize_agent, AgentType
from langchain.agents import Tool
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from dotenv import load_dotenv
import google.generativeai as genai
from typing import List
from langchain_core.language_models import BaseLanguageModel
from langchain_core.runnables import Runnable
from datetime import datetime

load_dotenv()
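# Expected configuration, as a sketch of the assumed .env / Space secrets
# (only HF_TOKEN and GOOGLE_API_KEY are actually read by this module):
#   GOOGLE_API_KEY=...   # Gemini generation + embeddings
#   HF_TOKEN=...         # Hugging Face Hub (model downloads, feedback uploads)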
def load_environment():
    # Expose HF_TOKEN under the name huggingface_hub/LangChain expect
    if "HUGGINGFACEHUB_API_TOKEN" not in os.environ and "HF_TOKEN" in os.environ:
        os.environ["HUGGINGFACEHUB_API_TOKEN"] = os.environ["HF_TOKEN"]
    # Prefer the environment variable; fall back to Streamlit secrets
    api_key = os.environ.get("GOOGLE_API_KEY") or st.secrets.get("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY not found in environment variables or Streamlit secrets.")
    os.environ["GOOGLE_API_KEY"] = api_key
    genai.configure(api_key=api_key)
from keybert import KeyBERT
from sentence_transformers import CrossEncoder
from sentence_transformers import SentenceTransformer
class GeminiLLM(Runnable):
    def __init__(self, model_name="models/gemini-1.5-pro-latest", api_key=None):
        # Use .get() so a missing key raises our ValueError, not a KeyError
        self.api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        if not self.api_key:
            raise ValueError("GOOGLE_API_KEY not found.")
        genai.configure(api_key=self.api_key)
        self.model = genai.GenerativeModel(model_name)

    def _call(self, prompt: str, stop=None) -> str:
        response = self.model.generate_content(prompt)
        return response.text

    def _llm_type(self) -> str:
        return "custom_gemini"

    def invoke(self, input, config=None):
        response = self.model.generate_content(input)
        return response.text.strip()
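# Minimal usage sketch (hypothetical prompt; assumes GOOGLE_API_KEY is set and
# the account has access to models/gemini-1.5-pro-latest):
#   llm = GeminiLLM()
#   print(llm.invoke("Summarize what MONTE is in one sentence."))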
class GeminiEmbeddings(Embeddings):
    def __init__(self, model_name="models/embedding-001", api_key=None):
        # Never hardcode the key; read it from the argument or the environment
        api_key = api_key or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("GOOGLE_API_KEY not found in environment variables.")
        genai.configure(api_key=api_key)
        self.model_name = model_name
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [
            genai.embed_content(
                model=self.model_name,
                content=text,
                task_type="retrieval_document"
            )["embedding"]
            for text in texts
        ]

    def embed_query(self, text: str) -> List[float]:
        return genai.embed_content(
            model=self.model_name,
            content=text,
            task_type="retrieval_query"
        )["embedding"]
vectorstore_global = None

if "feedback_log" not in st.session_state:
    st.session_state["feedback_log"] = []
def preload_modtran_document():
    global vectorstore_global
    embeddings = GeminiEmbeddings()
    st.session_state.vectorstore = FAISS.load_local(
        "monte_vectorstore", embeddings, allow_dangerous_deserialization=True
    )
    set_global_vectorstore(st.session_state.vectorstore)
    st.session_state.chat_ready = True
def convert_pdf_to_xml(pdf_file, xml_path):
    os.makedirs("temp", exist_ok=True)
    pdf_path = os.path.join("temp", pdf_file.name)
    with open(pdf_path, 'wb') as f:
        f.write(pdf_file.getbuffer())
    subprocess.run(["pdftohtml", "-xml", pdf_path, xml_path], check=True)
    return xml_path
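# Note: pdftohtml is an external binary from poppler-utils, so it must be on
# PATH (e.g. listing poppler-utils in the Space's packages.txt).
# Equivalent shell invocation for reference: pdftohtml -xml input.pdf out.xml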
def extract_text_from_xml(xml_path, document_name):
    tree = etree.parse(xml_path)
    text_chunks = []
    for page in tree.xpath("//page"):
        page_num = int(page.get("number", 0))
        texts = [text.text for text in page.xpath('.//text') if text.text]
        combined_text = '\n'.join(texts)
        text_chunks.append({"text": combined_text, "page": page_num, "document": document_name})
    return text_chunks
def extract_text_from_pdf(pdf_file, document_name):
    text_chunks = []
    with pdfplumber.open(pdf_file) as pdf:
        for i, page in enumerate(pdf.pages):
            text = page.extract_text()
            if text:
                text_chunks.append({"text": text, "page": i + 1, "document": document_name})
    return text_chunks
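# Both extractors yield the same record shape, e.g. (illustrative values):
#   {"text": "Chapter 1 ...", "page": 1, "document": "refguide.pdf"}
# so downstream chunking and metadata handling stay format-agnostic.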
def get_uploaded_text(uploaded_files):
    raw_text = []
    for uploaded_file in uploaded_files:
        document_name = uploaded_file.name
        if document_name.endswith(".pdf"):
            text_chunks = extract_text_from_pdf(uploaded_file, document_name)
            raw_text.extend(text_chunks)
        elif uploaded_file.name.endswith((".html", ".htm")):
            soup = BeautifulSoup(uploaded_file.getvalue(), 'lxml')
            raw_text.append({"text": soup.get_text(), "page": None, "document": document_name})
        elif uploaded_file.name.endswith(".txt"):
            content = uploaded_file.getvalue().decode("utf-8")
            raw_text.append({"text": content, "page": None, "document": document_name})
    return raw_text
def get_text_chunks(raw_text):
    splitter = CharacterTextSplitter(separator='\n', chunk_size=500, chunk_overlap=100)
    final_chunks = []
    for chunk in raw_text:
        for split_text in splitter.split_text(chunk["text"]):
            final_chunks.append({"text": split_text, "page": chunk["page"], "document": chunk["document"]})
    return final_chunks
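# CharacterTextSplitter splits on the separator first and only then merges
# pieces up to chunk_size, so 500/100 is a target, not a hard cap: a single
# newline-free paragraph longer than 500 characters stays intact.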
def get_vectorstore(text_chunks):
    if not text_chunks:
        raise ValueError("text_chunks is empty. Cannot initialize FAISS vectorstore.")
    embeddings = GeminiEmbeddings()
    texts = [chunk["text"] for chunk in text_chunks]
    metadatas = [{"page": chunk["page"], "document": chunk["document"]} for chunk in text_chunks]
    return FAISS.from_texts(texts, embedding=embeddings, metadatas=metadatas)
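# Quick sanity check (hypothetical data; exercises the same FAISS calls the
# app relies on):
#   vs = get_vectorstore([{"text": "Monte is an astrodynamics toolkit.",
#                          "page": 1, "document": "intro.txt"}])
#   print(vs.similarity_search("what is Monte", k=1)[0].metadata)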
def set_global_vectorstore(vectorstore):
    global vectorstore_global
    vectorstore_global = vectorstore

kw_model = None
reranker = None
def get_kw_model():
    global kw_model
    if kw_model is None:
        # Load the sentence transformer with the HF token passed explicitly
        model = SentenceTransformer(
            'sentence-transformers/all-MiniLM-L6-v2',
            use_auth_token=os.environ.get("HF_TOKEN")
        )
        kw_model = KeyBERT(model=model)
    return kw_model
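# Example of what KeyBERT returns here (illustrative scores):
#   get_kw_model().extract_keywords("How do I set up a trajectory in Monte?",
#                                   keyphrase_ngram_range=(1, 2), top_n=3)
#   -> [("trajectory monte", 0.71), ("set trajectory", 0.55), ("monte", 0.43)]
# i.e. a list of (phrase, similarity-score) pairs, best first.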
def self_reasoning(query, context):
    print("🧪 self_reasoning received context of length:", len(context))
    llm = GeminiLLM()
    reasoning_prompt = f"""
    You are an AI assistant that analyzes the context provided to answer the user's query comprehensively and clearly.
    Answer in a concise, factual way using the terminology from the context. Avoid extra explanation unless explicitly asked.
    YOU MUST mention the document file name (e.g., tools.html, refguide.html) in your answer.
    ### Example 1:
    **Question:** What is the purpose of the Monte GUI?
    **Context:**
    [From `tools.html`] The Monte GUI provides interfaces for setting up trajectory parameters and viewing output results.
    **Answer:** The Monte GUI helps users configure trajectory parameters and visualize results. (From `tools.html`)
    ### Example 2:
    **Question:** How do you perform covariance analysis in Monte?
    **Context:**
    [From `designEdition.html`] The Monte Design Edition includes support for statistical maneuver and covariance analysis during the design phase.
    **Answer:** Monte supports covariance analysis through the Design Edition. (From `designEdition.html`)
    ### Now answer:
    **Question:** {query}
    **Context:**
    {context}
    **Answer:**
    """
    try:
        result = llm._call(reasoning_prompt)
        print("✅ Gemini returned a result.")
        return result
    except Exception as e:
        print("❌ Error in self_reasoning:", e)
        return f"⚠️ Gemini failed: {e}"
def faiss_search_with_keywords(query):
    global vectorstore_global
    if vectorstore_global is None:
        raise ValueError("FAISS vectorstore is not initialized.")
    kw_model = get_kw_model()
    keywords = kw_model.extract_keywords(query, keyphrase_ngram_range=(1, 2), stop_words='english', top_n=5)
    refined_query = " ".join([keyword[0] for keyword in keywords])
    retriever = vectorstore_global.as_retriever(search_kwargs={"k": 13})
    docs = retriever.get_relevant_documents(refined_query)
    context = '\n\n'.join([f"[From `{doc.metadata.get('document', 'unknown.html')}`] {doc.page_content}" for doc in docs])
    return self_reasoning(query, context)
def get_reranker():
    global reranker
    if reranker is None:
        reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    return reranker
def faiss_search_with_reasoning(query):
    global vectorstore_global
    if vectorstore_global is None:
        raise ValueError("FAISS vectorstore is not initialized.")
    reranker = get_reranker()
    retriever = vectorstore_global.as_retriever(search_kwargs={"k": 13})
    docs = retriever.get_relevant_documents(query)
    pairs = [(query, doc.page_content) for doc in docs]
    scores = reranker.predict(pairs)
    reranked_docs = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
    top_docs = [doc for _, doc in reranked_docs[:5]]
    context = '\n\n'.join([f"[From `{doc.metadata.get('document', 'unknown.html')}`] {doc.page_content.strip()}" for doc in top_docs])
    return self_reasoning(query, context)
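# This is the classic retrieve-then-rerank pattern: FAISS over-fetches 13
# candidates cheaply (bi-encoder), then the cross-encoder rescores each
# (query, passage) pair jointly and only the top 5 reach the LLM. For
# intuition (illustrative numbers):
#   scores = get_reranker().predict([("how to run monte", "Run monte via..."),
#                                    ("how to run monte", "Unrelated text")])
#   # -> e.g. array([ 6.6, -10.2 ]) -- raw logits, higher = more relevant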
faiss_keyword_tool = Tool(
    name="FAISS Keyword Search",
    func=faiss_search_with_keywords,
    description="Searches FAISS with a keyword-based approach to retrieve context."
)

faiss_reasoning_tool = Tool(
    name="FAISS Reasoning Search",
    func=faiss_search_with_reasoning,
    description="Searches FAISS with detailed reasoning to retrieve context."
)
def initialize_chatbot_agent():
    llm = GeminiLLM()
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    tools = [faiss_keyword_tool, faiss_reasoning_tool]
    agent = initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        memory=memory,
        verbose=False,
        handle_parsing_errors=True
    )
    return agent
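# Caveat worth knowing: the ZERO_SHOT_REACT_DESCRIPTION prompt has no
# {chat_history} slot, so the ConversationBufferMemory above is stored but
# effectively unused; CONVERSATIONAL_REACT_DESCRIPTION would be the
# memory-aware variant if follow-up context is ever needed.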
def handle_user_query(query):
    try:
        global vectorstore_global
        if vectorstore_global is None:
            raise ValueError("Vectorstore is not initialized.")
        print("🚀 Starting handle_user_query with:", query)
        if "how" in query.lower():
            print("🧠 Routing to: faiss_search_with_reasoning")
            context = faiss_search_with_reasoning(query)
        else:
            print("🧠 Routing to: faiss_search_with_keywords")
            context = faiss_search_with_keywords(query)
        print("📏 Context length:", len(context))
        # Note: both search helpers already call self_reasoning, so the call
        # below reasons a second time over an already-reasoned answer; kept
        # to preserve the original flow.
        print("⚙️ Calling self_reasoning...")
        answer = self_reasoning(query, context)
        print("✅ Answer generated.")
        return answer
    except Exception as e:
        print("❌ Error in handle_user_query:", e)
        return f"⚠️ Error: {e}"
def save_feedback_to_huggingface():
    try:
        if not st.session_state.feedback_log:
            print("⚠️ No feedback collected yet.")
            return
        feedback_df = pd.DataFrame(st.session_state.feedback_log)
        now = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"feedback_{now}.csv"
        with tempfile.TemporaryDirectory() as tmpdir:
            filepath = os.path.join(tmpdir, filename)
            feedback_df.to_csv(filepath, index=False)
            token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
            if not token:
                raise ValueError("❌ Hugging Face token not found!")
            print(f"📤 Attempting upload to repo: ZarinT/chatbot-feedback as {filename}")
            print("📊 Feedback data:", feedback_df)
            api = HfApi(token=token)
            api.upload_file(
                path_or_fileobj=filepath,
                path_in_repo=filename,
                repo_id="ZarinT/chatbot-feedback",
                repo_type="dataset"
            )
        print("✅ Feedback uploaded successfully.")
        st.session_state.feedback_log.clear()
    except Exception as e:
        print("❌ Feedback upload failed:", e)
def clear_user_input():
    st.session_state["user_input"] = ""
def main():
    load_environment()

    # Initialize session state
    if "chat_ready" not in st.session_state:
        st.session_state.chat_ready = False
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    if "vectorstore" not in st.session_state:
        st.session_state.vectorstore = None
    if "feedback_log" not in st.session_state:
        st.session_state.feedback_log = []
    if "feedback_submitted" not in st.session_state:
        st.session_state.feedback_submitted = False
    if "last_answered_question" not in st.session_state:
        st.session_state.last_answered_question = ""
st.markdown(""" | |
<h2>Chat with MONTE Documents π</h2> | |
""", unsafe_allow_html=True) | |
    # Inject custom CSS for chat bubbles
    st.markdown("""
    <style>
    .chat-container {
        display: flex;
        flex-direction: column;
        gap: 0.75rem;
    }
    .bot-bubble {
        align-self: flex-start;
        background-color: #e8f4fd;
        color: #111;
        padding: 1rem;
        border-radius: 0.5rem 1rem 1rem 1rem;
        max-width: 80%;
    }
    .user-bubble {
        align-self: flex-end;
        background-color: rgba(0, 0, 0, 0.85);
        color: white;
        padding: 1rem;
        border-radius: 1rem 0.5rem 1rem 1rem;
        max-width: 80%;
        text-align: right;
    }
    .rating-line {
        font-weight: bold;
        color: #ffaa00;
        font-size: 0.95rem;
        max-width: 80%;
        align-self: center;
        text-align: center;
    }
    .feedback-counter {
        font-style: italic;
        color: #666;
        font-size: 0.9rem;
        margin: 1rem 0;
    }
    </style>
    """, unsafe_allow_html=True)
    # Load vectorstore and chatbot agent
    if not st.session_state.chat_ready:
        with st.spinner("Loading Monte documents..."):
            preload_modtran_document()
            st.session_state.agent = initialize_chatbot_agent()
            st.session_state.chat_ready = True
            st.success("Monte documents loaded successfully!")
    # Render all previous Q&A in chat format
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    for i, exchange in enumerate(st.session_state.chat_history):
        # User's question (right side)
        st.markdown(f'<div class="user-bubble"><strong>You:</strong> {exchange["user"]}</div>', unsafe_allow_html=True)
        # MODTRAN Bot's answer (left side)
        st.markdown(f'<div class="bot-bubble"><strong>MODTRAN Bot:</strong> {exchange["bot"]}</div>', unsafe_allow_html=True)
        # If already rated
        if "rating" in exchange:
            st.markdown(f'<div class="rating-line">✔️ You rated this: {exchange["rating"]}</div>', unsafe_allow_html=True)
        # If not rated yet and it's the last message, show form
        elif i == len(st.session_state.chat_history) - 1:
            with st.form(key=f"feedback_form_{i}"):
                rating = st.radio(
                    "Rate this response:",
                    options=["Not helpful", "Somewhat helpful", "Neutral", "Helpful", "Very helpful"],
                    key=f"rating_{i}",
                    horizontal=True
                )
                submitted = st.form_submit_button("Submit Rating")
                if submitted:
                    st.session_state.chat_history[i]["rating"] = rating
                    st.session_state.feedback_log.append({
                        "question": exchange["user"],
                        "response": exchange["bot"],
                        "rating": rating,
                        "timestamp": datetime.now().isoformat()
                    })
                    if len(st.session_state.feedback_log) >= 2:
                        print("📦 Upload threshold reached; saving feedback to Hugging Face.")
                        save_feedback_to_huggingface()
                    st.session_state.feedback_submitted = True
                    st.rerun()
    st.markdown('</div>', unsafe_allow_html=True)

    # Show current feedback progress
    # st.markdown(f'<div class="feedback-counter">📊 Feedback collected: <strong>{len(st.session_state.feedback_log)} / 5</strong></div>', unsafe_allow_html=True)
    # Input for next question
    user_question = st.chat_input("Ask your next question:")
    if user_question and user_question != st.session_state.last_answered_question:
        with st.spinner("Generating answer..."):
            try:
                set_global_vectorstore(st.session_state.vectorstore)
                response = handle_user_query(user_question)
            except Exception as e:
                response = f"⚠️ Something went wrong: {e}"
        # Save new Q&A
        st.session_state.chat_history.append({
            "user": user_question,
            "bot": response
        })
        st.session_state.last_answered_question = user_question
        st.rerun()
if __name__ == "__main__":
    # main() already calls load_environment(), so it is not repeated here
    main()