# Hugging Face Spaces app (Space status shown when this copy was captured: Sleeping)
import os | |
import time | |
import json | |
import logging | |
import threading | |
import gradio as gr | |
import google.generativeai as genai | |
from googleapiclient.discovery import build | |
from googleapiclient.http import MediaIoBaseDownload | |
from google.oauth2 import service_account | |
from langchain_community.vectorstores import Chroma | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader | |
from langchain.chains import RetrievalQA | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from PyPDF2 import PdfReader | |
from gtts import gTTS | |
# Registry of generated audio files: file name -> time.time() of the last write.
temp_file_map = {}

# ✅ Configure logging
logging.basicConfig(level=logging.INFO)

# ✅ Load API Keys
logging.info("🔑 Loading API keys...")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY_1")
SERVICE_ACCOUNT_JSON = os.getenv("SERVICE_ACCOUNT_JSON")
if not GOOGLE_API_KEY or not SERVICE_ACCOUNT_JSON:
    logging.error("❌ Missing API Key or Service Account JSON.")
    raise ValueError("❌ Missing API Key or Service Account JSON. Please add them as environment variables.")
# Re-exported under the conventional name, presumably for libraries that read
# GOOGLE_API_KEY from the environment.
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

# Parsed service-account info (a dict, despite the *_FILE name).
try:
    SERVICE_ACCOUNT_FILE = json.loads(SERVICE_ACCOUNT_JSON)
except json.JSONDecodeError as err:
    # JSONDecodeError is itself a ValueError; re-raise with an actionable message.
    logging.error("❌ SERVICE_ACCOUNT_JSON is not valid JSON.")
    raise ValueError("❌ SERVICE_ACCOUNT_JSON is not valid JSON.") from err

SCOPES = ["https://www.googleapis.com/auth/drive"]
# NOTE(review): this module only lists and downloads files, so the
# ".../auth/drive.readonly" scope would likely suffice — confirm before narrowing.
FOLDER_ID = "1xqOpwgwUoiJYf9GkeuB4dayme4zJcujf"
# Request exactly SCOPES for the service account.
creds = service_account.Credentials.from_service_account_info(SERVICE_ACCOUNT_FILE, scopes=SCOPES)
drive_service = build("drive", "v3", credentials=creds)

# ✅ Initialize variables
vector_store = None  # Chroma index; set by process_documents(), read by query_document()
file_id_map = {}  # Drive file name -> Drive file id; rebuilt by get_files_from_drive()
temp_dir = "./temp_downloads"  # local download and audio output directory
os.makedirs(temp_dir, exist_ok=True)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# ✅ Get list of files from Google Drive
def get_files_from_drive():
    """Return the names of all non-trashed files directly inside ``FOLDER_ID``.

    Follows ``nextPageToken`` so folders with more files than fit in one
    result page are listed completely.

    Side effect: rebuilds the global ``file_id_map`` (file name -> Drive file
    id). ``process_documents`` uses that map to download selections. If two
    files share a name, the one listed last wins.
    """
    global file_id_map
    logging.info("📂 Fetching files from Google Drive...")
    query = f"'{FOLDER_ID}' in parents and trashed = false"

    files = []
    page_token = None
    while True:
        results = drive_service.files().list(
            q=query,
            fields="nextPageToken, files(id, name)",
            pageSize=1000,
            pageToken=page_token,  # None on the first request
        ).execute()
        files.extend(results.get("files", []))
        page_token = results.get("nextPageToken")
        if not page_token:
            break

    file_id_map = {file["name"]: file["id"] for file in files}
    return list(file_id_map)
# ✅ Download file from Google Drive
def download_file(file_id, file_name):
    """Download Drive file ``file_id`` into ``temp_dir`` and return the local path.

    Args:
        file_id: Drive file id to fetch via ``files().get_media``.
        file_name: Drive display name used for the local file. Any "/" or "\\"
            is replaced with "_" so the file always lands directly in
            ``temp_dir`` and cannot create subpaths or escape it. Names that
            reduce to "", "." or ".." fall back to ``file_id``.

    Returns:
        The local file path.

    Raises:
        Whatever the download raises; a partially written file is removed first.
    """
    safe_name = file_name.replace("/", "_").replace("\\", "_")
    if safe_name in ("", ".", ".."):
        safe_name = file_id
    file_path = os.path.join(temp_dir, safe_name)

    request = drive_service.files().get_media(fileId=file_id)
    try:
        with open(file_path, "wb") as f:
            downloader = MediaIoBaseDownload(f, request)
            done = False
            while not done:
                _, done = downloader.next_chunk()
    except Exception:
        # Do not leave a truncated file behind. Cleanup is best-effort;
        # the original error is what matters, so it is re-raised.
        try:
            os.remove(file_path)
        except OSError:
            pass
        raise
    return file_path
# ✅ Process documents
# File extensions that can be indexed (matched case-insensitively against the name).
_SUPPORTED_EXTENSIONS = (".pdf", ".txt", ".docx")


def process_documents(selected_files):
    """Download, load, chunk and index the selected Drive files.

    Args:
        selected_files: file names chosen in the UI (keys of ``file_id_map``).
            May be ``None`` or empty when nothing is selected.

    Returns:
        A status string for the UI.

    Side effects:
        On success, replaces the global ``vector_store`` with a Chroma index of
        1000-character chunks (100 overlap) of the loaded documents. The
        previous index's collection is deleted first.
    """
    global vector_store
    if not selected_files:
        return "❌ No files selected."

    docs = []
    for file_name in selected_files:
        lowered = file_name.lower()
        ext = next((e for e in _SUPPORTED_EXTENSIONS if lowered.endswith(e)), None)
        if ext is None:
            # Checked before downloading so unsupported files are never fetched.
            logging.warning(f"⚠️ Unsupported file type: {file_name}")
            continue
        file_id = file_id_map.get(file_name)
        if file_id is None:
            # A selection can outlive the file, e.g. after a refresh removed it.
            logging.warning(f"⚠️ File not found in Drive listing: {file_name}")
            continue
        file_path = download_file(file_id, file_name)
        if ext == ".pdf":
            loader = PyPDFLoader(file_path)
        elif ext == ".txt":
            loader = TextLoader(file_path)
        else:  # ".docx"
            loader = Docx2txtLoader(file_path)
        docs.extend(loader.load())

    if not docs:
        return "❌ No supported documents were loaded."

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    split_docs = text_splitter.split_documents(docs)
    if not split_docs:
        return "❌ The selected documents contain no extractable text."

    if vector_store is not None:
        # NOTE(review): Chroma.from_documents below uses the default collection
        # name on what is presumably chromadb's shared in-memory client, so
        # without this the new chunks would be appended to the previous
        # selection's. Verify against the installed chromadb/langchain versions.
        vector_store.delete_collection()
        vector_store = None
    vector_store = Chroma.from_documents(split_docs, embeddings)
    return "✅ Documents processed successfully!"
# ✅ Query document | |
import os | |
import time | |
import logging | |
from gtts import gTTS | |
from langchain.chains import RetrievalQA | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
# ✅ Reset the registry of generated audio files (file name -> time.time() of
# its last write). This rebinds the name to a new, empty dict.
temp_file_map = dict()
def _corpus_profile(total_words):
    """Return ``(size_category, k, prompt_prefix, model_name)`` for a corpus size in words."""
    if total_words < 500:
        return (
            "small",
            3,
            "Provide a **concise** response focusing on key points.",
            "gemini-2.0-pro-exp-02-05",
        )
    if total_words < 2000:
        return (
            "medium",
            5,
            "Provide a **detailed response** with examples and key insights.",
            "gemini-2.0-pro-exp-02-05",
        )
    return (
        "large",
        10,
        "Provide a **comprehensive and structured response**, including step-by-step analysis and explanations.",
        "gemini-2.0-flash",
    )


def _count_words(documents):
    """Count whitespace-separated words across stored document entries.

    Entries may be strings or objects with ``page_content``. ``None`` or any
    other entry counts as zero words.
    """
    total = 0
    for doc in documents or []:
        if isinstance(doc, str):
            text = doc
        else:
            text = getattr(doc, "page_content", None) or ""
        total += len(text.split())
    return total


def _speak(text):
    """Synthesize ``text`` to ``temp_dir/response.mp3`` with gTTS.

    Returns the MP3 path, or ``None`` when there is nothing to speak or
    synthesis fails. This way the text answer is still shown without audio.
    """
    from gtts import gTTSError

    if not text or not text.strip():
        return None
    audio_path = os.path.join(temp_dir, "response.mp3")
    try:
        gTTS(text=text, lang="en").save(audio_path)
    except (gTTSError, AssertionError, OSError):
        # gTTS raises gTTSError for API/network failures and AssertionError when
        # no speakable text remains; OSError covers writing the file.
        logging.exception("🔇 Text-to-speech failed; returning text only.")
        return None
    temp_file_map["response.mp3"] = time.time()
    return audio_path


def query_document(question):
    """Answer ``question`` from the indexed documents and voice the answer.

    Retrieval depth (k), answer style and Gemini model are chosen from the total
    word count of the indexed chunks. Only the raw question is used for the
    similarity search; the style instructions go into the QA prompt.

    Args:
        question: the user's question. ``None`` or blank text is rejected.

    Returns:
        ``(answer_text, audio_path_or_None)`` for the text and audio outputs.
    """
    if not question or not question.strip():
        return "❌ Please enter a question.", None
    if vector_store is None:
        return "❌ No documents processed.", None

    # Only the texts are needed to size the corpus.
    stored_docs = vector_store.get(include=["documents"])["documents"]
    file_size_category, k_value, prompt_prefix, model_name = _corpus_profile(_count_words(stored_docs))
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k_value})

    # RetrievalQA embeds its "query" for the similarity search, so the style
    # instructions are injected through the prompt instead of the query.
    from langchain.prompts import PromptTemplate

    qa_prompt = PromptTemplate(
        template=(
            "Use the following pieces of context to answer the question at the end. "
            "If you don't know the answer, just say that you don't know, "
            "don't try to make up an answer.\n\n"
            "{context}\n\n"
            "{style}\n"
            "- Ensure clarity and completeness.\n"
            "- Highlight the most relevant information.\n"
            "**Question:** {question}\n"
        ),
        input_variables=["context", "question"],
        partial_variables={"style": prompt_prefix},
    )

    logging.info(f"🧠 Using Model: {model_name} for {file_size_category} file.")
    model = ChatGoogleGenerativeAI(model=model_name, google_api_key=GOOGLE_API_KEY)
    qa_chain = RetrievalQA.from_chain_type(
        llm=model,
        retriever=retriever,
        chain_type_kwargs={"prompt": qa_prompt},
    )
    response = qa_chain.invoke({"query": question})["result"]

    # ✅ Convert response to speech (None if TTS is unavailable)
    return response, _speak(response)
# ✅ Gradio UI: choose Drive files, index them, then ask questions answered as text and audio.
with gr.Blocks() as demo:
    gr.Markdown("# 📄 AI-Powered Multi-Document Chatbot with Voice Output")

    # Components appear top-to-bottom in the order they are created.
    file_dropdown = gr.Dropdown(
        choices=get_files_from_drive(),
        label="📂 Select Files",
        multiselect=True,
    )
    refresh_button = gr.Button("🔄 Refresh Files")
    process_button = gr.Button("🚀 Process Documents")
    user_input = gr.Textbox(label="🔎 Ask a Question")
    submit_button = gr.Button("💬 Get Answer")
    response_output = gr.Textbox(label="📝 Response")
    audio_output = gr.Audio(label="🔊 Audio Response")

    def refresh_files():
        """Re-list the Drive folder and swap in the new dropdown choices."""
        latest_names = get_files_from_drive()
        return gr.update(choices=latest_names)

    # Wire each button to its handler.
    refresh_button.click(fn=refresh_files, outputs=file_dropdown)
    process_button.click(fn=process_documents, inputs=file_dropdown, outputs=response_output)
    submit_button.click(fn=query_document, inputs=user_input, outputs=[response_output, audio_output])

demo.launch()