import torch
import json
import os
import faiss
import numpy as np
import requests
import gradio as gr
from pptx import Presentation
from fastapi import FastAPI, UploadFile, File
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from io import BytesIO
# ---------------------------- #
# CONFIGURATION
# ---------------------------- #
MODEL_NAME = "./models/facebook-opt-1.3b"
SUMMARIZATION_MODEL = "./models/bart-large-cnn"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
DATA_DIRECTORY = "./dataset/"
# ---------------------------- #
# FUNCTION TO LOAD JSON FILES
# ---------------------------- #
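# The loader below assumes each dataset JSON roughly follows this shape
# (a sketch inferred from the .get() calls below, not a published schema):
# {
#   "data": [
#     {"question": "What is X?", "answer": "X is ..."}
#   ]
# }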
def load_text_from_json(directory):
    """Collect unique 'Q: ... A: ...' strings from every .json file in `directory`."""
    text_data = set()  # Use a set to remove duplicates
    for filename in os.listdir(directory):
        if filename.endswith(".json"):
            with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
                data = json.load(file)
            for entry in data.get("data", []):
                question = entry.get("question", "").strip()
                answer = entry.get("answer", "").strip()
                if question and answer:
                    text_data.add(f"Q: {question} A: {answer}")
    return list(text_data)
# ---------------------------- #
# FUNCTION TO LOAD POWERPOINT FILES
# ---------------------------- #
def extract_text_from_pptx(file_path):
    """Concatenate the text of every shape on every slide of a .pptx file."""
    prs = Presentation(file_path)
    text_data = []
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text_data.append(shape.text.strip())
    return " ".join(text_data)


def load_text_from_pptx(directory):
    """Collect unique text blobs from every .pptx file in `directory`."""
    text_data = set()
    for filename in os.listdir(directory):
        if filename.endswith(".pptx"):
            pptx_text = extract_text_from_pptx(os.path.join(directory, filename))
            text_data.add(pptx_text)
    return list(text_data)
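# Note: python-pptx only exposes text held in shapes that carry a text frame;
# words baked into images or charts are not extracted.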
# ---------------------------- #
# LOAD ALL TEXT DATA
# ---------------------------- #
all_text = load_text_from_json(DATA_DIRECTORY) + load_text_from_pptx(DATA_DIRECTORY)
# ---------------------------- #
# CHUNK DATA PROPERLY
# ---------------------------- #
CHUNK_SIZE = 500
chunks = set()
for text in all_text:
    sentences = text.split(". ")
    temp_chunk = ""
    for sentence in sentences:
        if len(temp_chunk) + len(sentence) < CHUNK_SIZE:
            temp_chunk += sentence + ". "
        else:
            chunks.add(temp_chunk.strip())  # Store the full chunk
            temp_chunk = sentence + ". "
    if temp_chunk:
        chunks.add(temp_chunk.strip())  # Store the last partial chunk
chunks = list(chunks)  # Convert to list after deduplication
# ---------------------------- #
# EMBEDDING MODEL & FAISS VECTOR SEARCH
# ---------------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL, local_files_only=True)
chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True)
# FAISS index
index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
index.add(chunk_embeddings)
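# Optional sanity check for the freshly built index (a sketch; assumes at
# least one chunk was loaded from ./dataset/):
#   q = embedder.encode(["example question"], convert_to_numpy=True)
#   dists, ids = index.search(q, 1)
#   print(chunks[ids[0][0]])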
# ---------------------------- #
# LOAD LLM MODEL
# ---------------------------- #
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
)
# Summarization pipeline
summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL)
# ---------------------------- #
# FASTAPI SETUP
# ---------------------------- #
app = FastAPI()
def retrieve_relevant_text(question, top_k=3):
    """Return the top-k FAISS matches for `question`, summarized if too long."""
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    _, idxs = index.search(question_embedding, top_k)
    retrieved_texts = [chunks[idx] for idx in idxs[0]]
    # Filter out chunks that merely echo the question itself
    filtered_chunks = [text for text in retrieved_texts if question.lower() not in text.lower()]
    unique_texts = list(set(filtered_chunks))
    context_text = " ".join(unique_texts)
    if len(context_text) > 1000:
        context_text = summarizer(context_text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
    return context_text
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Ingest a .json or .pptx file, chunk it, and add new chunks to the index."""
    global chunks, index, chunk_embeddings
    filename = file.filename
    content = await file.read()
    new_texts = []
    try:
        # -------------------- #
        # Process .json files
        # -------------------- #
        if filename.endswith(".json"):
            data = json.loads(content)
            for entry in data.get("data", []):
                question = entry.get("question", "").strip()
                answer = entry.get("answer", "").strip()
                if question and answer:
                    new_texts.append(f"Q: {question} A: {answer}")
        # -------------------- #
        # Process .pptx files
        # -------------------- #
        elif filename.endswith(".pptx"):
            prs = Presentation(BytesIO(content))
            ppt_text = []
            for slide in prs.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        ppt_text.append(shape.text.strip())
            new_texts.append(" ".join(ppt_text))
        else:
            return {"error": "Unsupported file type. Use .json or .pptx"}
        # -------------------- #
        # Chunk and embed
        # -------------------- #
        new_chunks = set()
        for text in new_texts:
            sentences = text.split(". ")
            temp = ""
            for s in sentences:
                if len(temp) + len(s) < CHUNK_SIZE:
                    temp += s + ". "
                else:
                    new_chunks.add(temp.strip())
                    temp = s + ". "
            if temp:
                new_chunks.add(temp.strip())
        # Drop chunks that are already indexed (dedup)
        new_chunks = list(new_chunks - set(chunks))
        if not new_chunks:
            return {"message": "No new unique chunks to add."}
        # Encode and update FAISS
        new_embeddings = embedder.encode(new_chunks, convert_to_numpy=True)
        index.add(new_embeddings)
        chunks.extend(new_chunks)
        return {
            "status": "success",
            "new_chunks_added": len(new_chunks),
            "total_chunks": len(chunks),
        }
    except Exception as e:
        return {"error": str(e)}
@app.get("/faq/")
def faq(question: str):
    """Answer user queries using retrieved knowledge."""
    retrieved_text = retrieve_relevant_text(question)
    prompt = (
        f"{retrieved_text.strip()}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question.strip()}\n\n"
        f"Answer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,  # budget the answer alone, so a long prompt cannot eat it
            repetition_penalty=1.3,
            no_repeat_ngram_size=4,
            do_sample=False,  # greedy decoding; a temperature would be ignored here
            pad_token_id=tokenizer.eos_token_id,
        )
    raw_answer = tokenizer.decode(output[0], skip_special_tokens=True)
    # ---------------------------- #
    # POST-PROCESSING CLEANUP
    # ---------------------------- #
    cleaned_answer = raw_answer
    # Remove the prompt (everything before the final 'Answer:' keyword)
    if "Answer:" in cleaned_answer:
        cleaned_answer = cleaned_answer.split("Answer:")[-1]
    # Remove a repeated question (case-insensitive)
    question_stripped = question.strip()
    cleaned_answer = cleaned_answer.strip()
    if cleaned_answer.lower().startswith(question_stripped.lower()):
        cleaned_answer = cleaned_answer[len(question_stripped):].strip()
    # Final touch: remove context/prompt tokens if they leaked
    for token in ["Context:", "Question:", "Answer:"]:
        cleaned_answer = cleaned_answer.replace(token, "").strip()
    return {"answer": cleaned_answer}
# --------- Gradio UI --------- #
def gradio_upload(file):
    """Forward a file picked in the Gradio UI to the /upload/ endpoint."""
    if file is None:
        return "No file selected."
    try:
        # gr.File returns a path string or a tempfile-like object depending on the Gradio version
        path = file if isinstance(file, str) else file.name
        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        # The endpoint expects multipart form data, not JSON
        with open(path, "rb") as f:
            response = requests.post(
                f"{base_url}/upload/",
                files={"file": (os.path.basename(path), f)},
            )
        if response.status_code == 200:
            return "✅ Data successfully uploaded and indexed!"
        return f"❌ Failed: {response.text}"
    except Exception as e:
        return f"❌ Error: {str(e)}"


gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload .json or .pptx file"),
    outputs="text",
    title="Upload Knowledge",
)

# Mount Gradio at /ui
app = gr.mount_gradio_app(app, gr_app, path="/ui")
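
# Minimal local entry point (an assumption for local testing; on HF Spaces the
# platform usually launches the ASGI app itself):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)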