Spaces:
Sleeping
Sleeping
import os | |
import time | |
import pandas as pd | |
import gradio as gr | |
from langchain_groq import ChatGroq | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import Chroma | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from PyPDF2 import PdfReader | |
# Configuration constants
COLLECTION_NAME = "GBVRS"  # name of the Chroma collection built at start-up
DATA_FOLDER = "./"  # scanned for .csv / .xlsx / .xls knowledge-base files
APP_VERSION = "v1.0.0"
APP_NAME = "Ijwi ry'Ubufasha"
# Caps the per-session history (at most twice this many messages are kept)
# so prompts stay within the model's token limits.
MAX_HISTORY_MESSAGES = 8

# Application-wide components; all remain None until initialize_assistant()
# assigns them.
llm = embed_model = vectorstore = retriever = rag_chain = None
# User session management
class UserSession:
    """State for one chat session: user details, history and last activity time."""

    # Longest message (in characters) copied verbatim into the prompt history.
    MAX_PROMPT_MESSAGE_CHARS = 500

    def __init__(self, session_id, llm):
        """Initialize a user session with unique ID and language model.

        Args:
            session_id: Identifier of the session this state belongs to.
            llm: Chat model used to write the welcome message. If it is None
                or the call fails, a fixed fallback greeting is used instead.
        """
        self.session_id = session_id
        self.user_info = {"Nickname": "Guest"}
        self.conversation_history = []
        self.llm = llm
        self.welcome_message = None
        self.last_activity = time.time()

    def set_user(self, user_info):
        """Store ``user_info`` and restart the history with a fresh welcome message."""
        self.user_info = user_info
        self.generate_welcome_message()
        # The welcome message is pinned as the first history entry.
        self.conversation_history = [
            {"role": "assistant", "content": self.get_welcome_message()},
        ]

    def get_user(self):
        """Get current user information."""
        return self.user_info

    def generate_welcome_message(self):
        """Set ``self.welcome_message`` to an HTML-wrapped greeting for the user.

        The greeting is written by the LLM when possible. Any failure, or an
        empty reply, falls back to a fixed message. The text is HTML-escaped
        because it carries the user-supplied nickname and model output.
        """
        import html

        nickname = self.user_info.get("Nickname", "Guest")
        welcome = ""
        try:
            prompt = (
                f"Create a brief and warm welcome message for {nickname} that's about 1-2 sentences. "
                "Emphasize this is a safe space for discussing gender-based violence issues "
                "and that we provide support and resources. Keep it warm and reassuring."
            )
            welcome = self.llm.invoke(prompt).content.strip()
        except Exception as e:
            print(f"WARNING: Falling back to default welcome message: {e}")
        if not welcome:
            welcome = (
                f"Welcome, {nickname}! You're in a safe space. We're here to provide support with "
                "gender-based violence issues and connect you with resources that can help."
            )
        self.welcome_message = (
            "<div style='font-size: 18px; color: #4E6BBF;'>"
            f"{html.escape(welcome, quote=False)}"
            "</div>"
        )

    def get_welcome_message(self):
        """Return the welcome message, generating it on first use."""
        if not self.welcome_message:
            self.generate_welcome_message()
        return self.welcome_message

    def _has_pinned_welcome(self):
        """Return True if the first history entry is this session's welcome message."""
        history = self.conversation_history
        return (
            bool(history)
            and history[0]["role"] == "assistant"
            and history[0]["content"] == self.welcome_message
        )

    def add_to_history(self, role, message):
        """Append a message, refresh the activity time and cap the history length.

        At most ``2 * MAX_HISTORY_MESSAGES`` entries are kept. When the welcome
        message is the first entry it is always retained, and the oldest
        messages after it are dropped.
        """
        self.conversation_history.append({"role": role, "content": message})
        self.last_activity = time.time()
        limit = MAX_HISTORY_MESSAGES * 2
        if len(self.conversation_history) > limit:
            if self._has_pinned_welcome():
                self.conversation_history = (
                    self.conversation_history[:1]
                    + self.conversation_history[-(limit - 1):]
                )
            else:
                # No welcome entry to protect (e.g. the session was recreated
                # after expiry), so keep only the most recent messages.
                self.conversation_history = self.conversation_history[-limit:]

    def get_conversation_history(self):
        """Get the full conversation history."""
        return self.conversation_history

    def get_formatted_history(self):
        """Return recent history as ``"Role: text"`` paragraphs for the LLM prompt.

        The pinned welcome message is omitted. At most the last
        ``2 * MAX_HISTORY_MESSAGES`` entries are included, and any message
        longer than ``MAX_PROMPT_MESSAGE_CHARS`` is truncated with ``"..."``.
        """
        history = self.conversation_history
        recent = history[1:] if self._has_pinned_welcome() else list(history)
        recent = recent[-(MAX_HISTORY_MESSAGES * 2):]
        max_chars = self.MAX_PROMPT_MESSAGE_CHARS
        parts = []
        for entry in recent:
            speaker = "User" if entry["role"] == "user" else "Assistant"
            content = entry["content"]
            if len(content) > max_chars:
                content = content[:max_chars] + "..."
            parts.append(f"{speaker}: {content}\n\n")
        return "".join(parts)

    def is_expired(self, timeout_seconds=3600):
        """Return True if the session has been idle longer than ``timeout_seconds``."""
        return (time.time() - self.last_activity) > timeout_seconds
# Session manager to handle multiple users
class SessionManager:
    """Registry of ``UserSession`` objects keyed by session id, with idle expiry."""

    def __init__(self):
        """Initialize the session manager."""
        self.sessions = {}
        self.session_timeout = 3600  # 1 hour timeout

    def get_session(self, session_id):
        """Return the session for ``session_id``, creating it if needed.

        Expired sessions are purged first. An id whose session has been idle
        longer than ``session_timeout`` therefore gets a brand-new session.
        """
        self._clean_expired_sessions()
        session = self.sessions.get(session_id)
        if session is None:
            # New sessions use the module-level LLM configured at start-up.
            session = UserSession(session_id, llm)
            self.sessions[session_id] = session
        return session

    def _clean_expired_sessions(self):
        """Drop sessions idle longer than ``session_timeout`` to free memory."""
        # Iterate over a snapshot. Gradio may run event handlers concurrently,
        # and iterating the live dict while another handler inserts a session
        # raises "dictionary changed size during iteration".
        expired = [
            key
            for key, session in list(self.sessions.items())
            if session.is_expired(self.session_timeout)
        ]
        for key in expired:
            # pop() tolerates the key having been removed concurrently.
            self.sessions.pop(key, None)


# Initialize the session manager
session_manager = SessionManager()
def _resolve_groq_api_key(env_names=("GBV", "GROQ_API_KEY")):
    """Return the first non-empty value among the ``env_names`` variables, else None."""
    for name in env_names:
        value = os.environ.get(name)
        if value:
            return value
    return None


def initialize_assistant():
    """Initialize the assistant with necessary components and configurations.

    Assigns the module globals ``llm``, ``embed_model``, ``vectorstore``,
    ``retriever`` and ``rag_chain``. The Groq key is read from the ``GBV``
    environment variable, falling back to ``GROQ_API_KEY``.
    """
    global llm, embed_model, vectorstore, retriever, rag_chain

    groq_api_key = _resolve_groq_api_key()
    if not groq_api_key:
        print("WARNING: No GROQ API key found in the GBV or GROQ_API_KEY environment variables.")

    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        api_key=groq_api_key,
    )

    # Prefer the larger embedding model; fall back to a small one if it cannot be loaded.
    try:
        embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
    except Exception as e:
        print(f"WARNING: Falling back to all-MiniLM-L6-v2 embeddings: {e}")
        embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    print("Processing data files...")
    data = process_data_files()
    print("Creating vector store...")
    vectorstore = create_vectorstore(data)
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    print("Setting up RAG chain...")
    rag_chain = create_rag_chain()
    print(f"✅ {APP_NAME} initialized successfully")
def process_data_files():
    """Load knowledge-base chunks from the CSV/Excel files in ``DATA_FOLDER``.

    The third column of every file is the source text. Each non-empty cell
    becomes one chunk ``{"page_content": str, "metadata": {"source": file
    name, "row": 1-based data-row number}}``. Files that cannot be read, or
    that have fewer than three columns, are reported and skipped.

    Returns:
        list[dict]: The chunks, ordered by file name then row (may be empty).
    """
    context_data = []
    try:
        if not os.path.exists(DATA_FOLDER):
            print(f"WARNING: Data folder does not exist: {DATA_FOLDER}")
            return context_data

        # Sorted so chunk order does not depend on filesystem listing order.
        data_files = sorted(
            f for f in os.listdir(DATA_FOLDER)
            if f.lower().endswith(('.csv', '.xlsx', '.xls'))
        )
        if not data_files:
            print(f"WARNING: No data files found in: {DATA_FOLDER}")
            return context_data

        for index, file_name in enumerate(data_files, 1):
            print(f"Processing file {index}/{len(data_files)}: {file_name}")
            file_path = os.path.join(DATA_FOLDER, file_name)
            try:
                if file_name.lower().endswith('.csv'):
                    df = pd.read_csv(file_path)
                else:
                    df = pd.read_excel(file_path)

                if df.shape[1] <= 2:
                    print(f"WARNING: File {file_name} has fewer than 3 columns.")
                    continue

                column = df.iloc[:, 2]
                present = column.notna().tolist()
                texts = column.astype(str).tolist()
                # Number rows before skipping blanks so "row" stays aligned
                # with the row's position in the file's data.
                for row_number, (has_value, text) in enumerate(zip(present, texts), start=1):
                    if has_value and text.strip():
                        context_data.append({
                            "page_content": text,
                            "metadata": {"source": file_name, "row": row_number},
                        })
            except Exception as e:
                print(f"ERROR processing file {file_name}: {e}")

        print(f"✅ Created {len(context_data)} chunks from {len(data_files)} files.")
    except Exception as e:
        print(f"ERROR accessing data folder: {e}")
    return context_data
def _stable_document_ids(data):
    """Return one deterministic id per document, or None if that is not possible.

    Ids are ``"<source>:<row>"`` taken from each document's metadata. None is
    returned when any document lacks ``source``/``row`` or when ids collide,
    in which case the vector store generates its own ids.
    """
    ids = []
    for doc in data:
        metadata = doc.get("metadata") or {}
        if "source" not in metadata or "row" not in metadata:
            return None
        ids.append(f"{metadata['source']}:{metadata['row']}")
    return ids if len(set(ids)) == len(ids) else None


def create_vectorstore(data):
    """
    Creates and returns a Chroma vector store populated with the provided data.

    Parameters:
        data (list): A list of dictionaries, each containing 'page_content' and 'metadata'.

    Returns:
        Chroma: The vector store instance. It is returned even when ``data`` is
        empty or adding documents fails (the failure is printed).
    """
    vectorstore = Chroma(
        collection_name=COLLECTION_NAME,
        embedding_function=embed_model,
        persist_directory="./",
    )
    if not data:
        print("⚠️ No data provided. Returning an empty vector store.")
        return vectorstore
    try:
        texts = [doc["page_content"] for doc in data]
        metadatas = [doc.get("metadata") or {} for doc in data]
        # The collection is persisted to disk and rebuilt on every start-up.
        # Stable ids make re-adding the same rows update existing records
        # rather than appending duplicates under fresh random ids.
        vectorstore.add_texts(texts, metadatas=metadatas, ids=_stable_document_ids(data))
    except Exception as e:
        print(f"❌ Failed to add documents to vector store: {e}")
    return vectorstore
# Prompt for the RAG chain. Each placeholder appears once in the data
# sections; the guidelines refer to those sections by name instead of
# repeating the placeholders, which would re-insert the full context/history.
RAG_PROMPT_TEMPLATE = """
You are a compassionate and supportive AI assistant specializing in helping individuals affected by Gender-Based Violence (GBV). Your responses must be based EXCLUSIVELY on the information provided in the context. Your primary goal is to provide emotionally intelligent support while maintaining appropriate boundaries.

**Previous conversation:** {conversation_history}

**Context information:** {context}

**User's Question:** {question}

When responding follow these guidelines:
1. **Strict Context Adherence**
   - Only use information that appears in the context information above
   - If the answer is not found in the context, state "I don't have that information in my available resources" rather than generating a response
2. **Personalized Communication**
   - Avoid contractions (e.g., use I am instead of I'm)
   - Incorporate thoughtful pauses or reflective questions when the conversation involves difficult topics
   - Use selective emojis (😊, 🤗, ❤️) only when tone-appropriate and not during crisis discussions
   - Balance warmth with professionalism
3. **Emotional Intelligence**
   - Validate feelings without judgment
   - Offer reassurance when appropriate, always centered on empowerment
   - Adjust your tone based on the emotional state conveyed
4. **Conversation Management**
   - Refer to the previous conversation to maintain continuity and avoid repetition
   - Use clear paragraph breaks for readability
5. **Information Delivery**
   - Extract only relevant information from the context information that directly addresses the question
   - Present information in accessible, non-technical language
   - When information is unavailable, respond with: "I don't have that specific information right now, {first_name}. Would it be helpful if I focus on [alternative support option]?"
6. **Safety and Ethics**
   - Do not generate any speculative content or advice not supported by the context
   - If the context contains safety information, prioritize sharing that information

Your response must come entirely from the provided context, maintaining the supportive tone while never introducing information from outside the provided materials.

**User's Question:** {question}

**Your Response:**
"""


def create_rag_chain():
    """Create the RAG chain for processing user queries.

    The returned runnable is invoked with ``{"query": str, "session_id": str}``
    and returns the answer text. If the chain cannot be composed, a runnable
    returning a fixed apology is returned instead, so callers can always use
    ``.invoke()``.
    """
    from langchain_core.runnables import RunnableLambda

    rag_prompt = PromptTemplate.from_template(RAG_PROMPT_TEMPLATE)

    def get_context_and_question(query_with_session):
        """Build the prompt variables for one query from its session and retrieved documents."""
        query = query_with_session["query"]
        session = session_manager.get_session(query_with_session["session_id"])
        first_name = session.get_user().get("Nickname", "User")
        conversation_hist = session.get_formatted_history()
        try:
            context_str = format_context(retriever.invoke(query))
        except Exception as e:
            print(f"ERROR retrieving documents: {e}")
            context_str = "No relevant information found."
        return {
            "context": context_str,
            "question": query,
            "first_name": first_name,
            "conversation_history": conversation_hist,
        }

    try:
        return (
            RunnablePassthrough()
            | get_context_and_question
            | rag_prompt
            | llm
            | StrOutputParser()
        )
    except Exception as e:
        print(f"ERROR creating RAG chain: {e}")

        def fallback_chain(query_with_session):
            """Return a fixed apology addressed to the session's nickname."""
            session = session_manager.get_session(query_with_session["session_id"])
            nickname = session.get_user().get("Nickname", "there")
            return f"I'm here to help you, {nickname}, but I'm experiencing some technical difficulties right now. Please try again shortly."

        # Wrapped so the fallback exposes .invoke() like the real chain does.
        return RunnableLambda(fallback_chain)
def format_context(retrieved_docs):
    """Join the page text of ``retrieved_docs``, separated by blank lines.

    A fixed notice is returned when nothing was retrieved (empty or None).
    """
    if retrieved_docs:
        return "\n\n".join(document.page_content for document in retrieved_docs)
    return "No relevant information available."
def rag_memory_stream(message, history, session_id):
    """Answer ``message`` for ``session_id`` and record the exchange in its history.

    Yields the complete answer once; it is a generator so callers can iterate
    it like a stream. ``history`` is accepted for interface compatibility and
    is not used, because the session's own history is the memory source. Any
    error from the chain is logged and replaced by an apology message.
    """
    session = session_manager.get_session(session_id)
    # Refresh activity before invoking the chain: the chain looks the session
    # up again via session_manager.get_session(), which purges idle sessions.
    session.last_activity = time.time()
    try:
        # Log only sizes, to avoid writing sensitive user content to the logs.
        print(f"Processing message for session {session_id} ({len(message)} chars)")
        response = rag_chain.invoke({
            "query": message,
            "session_id": session_id,
        })
        print(f"Generated response ({len(response)} chars)")
    except Exception as e:
        import traceback
        print(f"ERROR in rag_memory_stream: {e}")
        print(f"Detailed error: {traceback.format_exc()}")
        nickname = session.get_user().get("Nickname", "there")
        response = f"I'm sorry, {nickname}. I encountered an error processing your request. Let's try a different question."
    # Record the exchange only after the chain ran, so the prompt's
    # "Previous conversation" does not also contain the current question.
    session.add_to_history("user", message)
    session.add_to_history("assistant", response)
    yield response
def collect_user_info(nickname, session_id):
    """Register ``nickname`` for the session and switch the UI to the chat view.

    Returns a 4-tuple: (status message, chat container update, registration
    container update, chatbot history). A blank nickname keeps the
    registration form visible and returns an error message instead.
    """
    cleaned = nickname.strip() if nickname else ""
    if not cleaned:
        return "Nickname is required to proceed.", gr.update(visible=False), gr.update(visible=True), []

    session = session_manager.get_session(session_id)
    session.set_user({
        "Nickname": cleaned,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    })
    greeting = session.get_welcome_message()
    # The greeting also seeds the chatbot as an assistant-only first turn.
    return greeting, gr.update(visible=True), gr.update(visible=False), [(None, greeting)]
def get_css():
    """Return the custom CSS for the Gradio UI.

    The chat bubble rules are scoped to ``#chatbot_container``, the
    ``elem_id`` given to the chat column. The previous ``.chatbot-container``
    class selector matched nothing.
    """
    return """
:root {
    --primary: #4E6BBF;
    --primary-light: #697BBF;
    --text-primary: #333333;
    --text-secondary: #666666;
    --background: #F9FAFC;
    --card-bg: #FFFFFF;
    --border: #E1E5F0;
    --shadow: rgba(0, 0, 0, 0.05);
}
body, .gradio-container {
    margin: 0;
    padding: 0;
    width: 100vw;
    height: 100vh;
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    background: var(--background);
    color: var(--text-primary);
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.gradio-container {
    max-width: 100%;
    max-height: 100%;
}
.gr-box {
    background: var(--card-bg);
    color: var(--text-primary);
    border-radius: 12px;
    padding: 2rem;
    border: 1px solid var(--border);
    box-shadow: 0 4px 12px var(--shadow);
}
.gr-button-primary {
    background: var(--primary);
    color: white;
    padding: 12px 24px;
    border-radius: 8px;
    transition: all 0.3s ease;
    border: none;
    font-weight: bold;
}
.gr-button-primary:hover {
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
    background: var(--primary-light);
}
footer {
    text-align: center;
    color: var(--text-secondary);
    padding: 1rem;
    font-size: 0.9em;
}
.gr-markdown h2 {
    color: var(--primary);
    margin-bottom: 0.5rem;
    font-size: 1.8em;
}
.gr-markdown h3 {
    color: var(--text-secondary);
    margin-bottom: 1.5rem;
    font-weight: normal;
}
#chatbot_container .chat-title h1,
#chatbot_container .empty-chatbot {
    color: var(--primary);
}
#input_nickname {
    padding: 12px;
    border-radius: 8px;
    border: 1px solid var(--border);
    background: var(--card-bg);
    transition: all 0.3s ease;
}
#input_nickname:focus {
    border-color: var(--primary);
    box-shadow: 0 0 0 2px rgba(78, 107, 191, 0.2);
    outline: none;
}
#chatbot_container .message.user {
    background: #E8F0FE;
    border-radius: 12px 12px 0 12px;
}
#chatbot_container .message.bot {
    background: #F5F7FF;
    border-radius: 12px 12px 12px 0;
}
"""
def _new_session_id():
    """Return a new random session identifier of the form ``session_<epoch>_<hex>``."""
    return f"session_{int(time.time())}_{os.urandom(4).hex()}"


def create_ui():
    """Create and configure the Gradio UI.

    Returns:
        gr.Blocks: The app, with a nickname registration form and a chat view
        that is revealed after registration.
    """
    with gr.Blocks(css=get_css(), theme=gr.themes.Soft()) as demo:
        # Per-visitor session id, assigned by demo.load() below. A literal
        # default would be computed once when the UI is built and so shared
        # by every browser tab, mixing users' nicknames and chat histories.
        session_id = gr.State(value=None)

        # Registration section
        with gr.Column(visible=True, elem_id="registration_container") as registration_container:
            gr.Markdown(f"## Welcome to {APP_NAME}")
            gr.Markdown("### Your privacy is important to us. Please provide a nickname to continue.")
            with gr.Row():
                first_name = gr.Textbox(
                    label="Nickname",
                    placeholder="Enter your nickname",
                    scale=1,
                    elem_id="input_nickname",
                )
            with gr.Row():
                submit_btn = gr.Button("Start Chatting", variant="primary", scale=2)
            response_message = gr.Markdown()

        # Chatbot section (initially hidden)
        with gr.Column(visible=False, elem_id="chatbot_container") as chatbot_container:
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                height=500,
                show_label=False,
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here...",
                    show_label=False,
                    container=False,
                    scale=9,
                )
                submit = gr.Button("Send", scale=1, variant="primary")
            gr.Examples(
                examples=[
                    "What resources are available for GBV victims?",
                    "How can I report an incident?",
                    "What are my legal rights?",
                    "I need help, what should I do first?",
                ],
                inputs=msg,
            )

        # Footer with version info
        gr.Markdown(f"{APP_NAME} {APP_VERSION} © 2025")

        def respond(message, chat_history, current_session_id):
            """Answer one chat message; returns (cleared textbox, updated history)."""
            chat_history = list(chat_history or [])
            if not message or not message.strip():
                # Nothing to send; just clear the textbox.
                return "", chat_history
            bot_message = "".join(rag_memory_stream(message, chat_history, current_session_id))
            chat_history.append((message, bot_message))
            return "", chat_history

        msg.submit(respond, [msg, chatbot, session_id], [msg, chatbot])
        submit.click(respond, [msg, chatbot, session_id], [msg, chatbot])

        # Handle user registration
        submit_btn.click(
            collect_user_info,
            inputs=[first_name, session_id],
            outputs=[response_message, chatbot_container, registration_container, chatbot],
        )

        # Give every page load its own session id.
        demo.load(_new_session_id, inputs=None, outputs=session_id)
    return demo
def launch_app():
    """Build the Gradio interface and serve it, requesting a public share link."""
    create_ui().launch(share=True)
# Main execution
def _show_startup_error(error):
    """Print ``error`` with its traceback and serve a minimal page describing it.

    Must be called from inside an ``except`` block so the traceback is available.
    """
    import traceback

    print(f"❌ Fatal error initializing GBV Assistant: {error}")
    print(traceback.format_exc())
    with gr.Blocks() as error_demo:
        gr.Markdown("## System Error")
        gr.Markdown(f"An error occurred while initializing the application: {str(error)}")
        gr.Markdown("Please check your configuration and try again.")
    error_demo.launch(share=True, inbrowser=True, debug=True)


if __name__ == "__main__":
    try:
        initialize_assistant()
        launch_app()
    except Exception as e:
        _show_startup_error(e)