# Hugging Face Spaces status at capture time: Sleeping
from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering | |
from sentence_transformers import SentenceTransformer, util | |
from typing import Optional | |
import io | |
import fitz # PyMuPDF | |
from PIL import Image | |
import pandas as pd | |
import uvicorn | |
from functools import lru_cache | |
from docx import Document | |
from pptx import Presentation | |
import pytesseract | |
import torch | |
from fastapi.responses import HTMLResponse | |
from fastapi.templating import Jinja2Templates | |
from fastapi import Request | |
from pathlib import Path | |
import os | |
# Log the working directory once at import time. The "static" directory
# mounted below is resolved relative to it, so this line helps diagnose
# startup failures caused by launching the app from the wrong directory.
print(os.getcwd())

# Initialize FastAPI app
app = FastAPI()

# Templates are located next to this file, independent of the working directory.
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    # Entries in allow_origins are compared verbatim against the request's
    # Origin header, so only exact origins belong here.
    allow_origins=[
        "http://localhost",
        "http://localhost:8000",
    ],
    # Wildcard sub-domains must be expressed as a regex: a literal
    # "https://*.hf.space" entry in allow_origins never matches a real origin.
    # The pattern is kept to a single host label because credentials are
    # allowed below.
    allow_origin_regex=r"https://[A-Za-z0-9-]+\.hf\.space",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static files (frontend)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Model loading with caching | |
@lru_cache(maxsize=1)
def get_summarizer():
    """Return the summarization pipeline, building it only on first use.

    The pipeline is created for the "facebook/bart-large-cnn" model. The
    result is memoized so repeated calls reuse the same pipeline object
    instead of constructing a new one for every request.
    """
    return pipeline("summarization", model="facebook/bart-large-cnn")
@lru_cache(maxsize=1)
def get_image_captioning():
    """Return the image-to-text pipeline, building it only on first use.

    The pipeline is created for the "Salesforce/blip-image-captioning-large"
    model. The result is memoized so repeated calls reuse the same pipeline
    object instead of constructing a new one for every request.
    """
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
@lru_cache(maxsize=1)
def get_translator():
    """Return the translation pipeline, building it only on first use.

    The pipeline is created for the "facebook/nllb-200-distilled-600M" model.
    The result is memoized so repeated calls reuse the same pipeline object
    instead of constructing a new one for every request.
    """
    return pipeline("translation", model="facebook/nllb-200-distilled-600M")
@lru_cache(maxsize=1)
def get_qa_model():
    """Return the question-answering ``(tokenizer, model)`` pair.

    Both objects are loaded from "deepset/roberta-base-squad2". The pair is
    memoized so repeated calls reuse the already-loaded objects instead of
    calling ``from_pretrained`` again for every question.

    Returns:
        A 2-tuple ``(tokenizer, model)``.
    """
    model_name = "deepset/roberta-base-squad2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    return tokenizer, model
# Helper Functions | |
def answer_question(question: str, context: str) -> dict:
    """Extract an answer to *question* from *context* with the QA model.

    The question/context pair is tokenized to at most 512 tokens; only the
    context is truncated when the pair is too long, so text beyond that
    window is not considered.

    Args:
        question: The question to answer.
        context: The text to search for the answer.

    Returns:
        A dict with two keys:
        ``"answer"`` - the decoded answer span, or ``"No answer found"`` when
        the span is empty or invalid;
        ``"confidence"`` - the mean of the softmax probabilities of the chosen
        start and end positions, as a float.
        If processing raises, the error text is placed in ``"answer"`` and
        ``"confidence"`` is ``0.0`` (best-effort behaviour; nothing is
        re-raised from the inference step).
    """
    tokenizer, model = get_qa_model()
    try:
        # Only one sequence is encoded, so no padding is requested: padding
        # to max_length would make every call pay for 512 tokens and would
        # let argmax/softmax below range over padding positions.
        inputs = tokenizer(
            question,
            context,
            max_length=512,
            truncation="only_second",
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )

        start_idx = int(torch.argmax(outputs.start_logits))
        end_idx = int(torch.argmax(outputs.end_logits))

        # Start and end are picked independently, so the end can land before
        # the start. That is not a usable span; report it with zero
        # confidence rather than an empty answer carrying a high score.
        if end_idx < start_idx:
            return {"answer": "No answer found", "confidence": 0.0}

        answer = tokenizer.decode(
            inputs["input_ids"][0][start_idx:end_idx + 1],
            skip_special_tokens=True,
        ).strip()

        start_score = torch.nn.functional.softmax(outputs.start_logits, dim=1)[0][start_idx]
        end_score = torch.nn.functional.softmax(outputs.end_logits, dim=1)[0][end_idx]
        confidence = float((start_score + end_score) / 2)

        return {
            "answer": answer if answer else "No answer found",
            "confidence": confidence,
        }
    except Exception as e:
        # Deliberate best-effort: callers receive a result dict, not an error.
        return {
            "answer": f"Error processing answer: {str(e)}",
            "confidence": 0.0,
        }
def home():
    """Return the contents of ``static/indexAI.html`` as a string.

    The path is relative to the current working directory.

    Raises:
        OSError: If the file cannot be opened.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view of the file - confirm how it is registered with the app.
    # The encoding is given explicitly so the page is not decoded with the
    # platform's locale default; assumes the HTML file is UTF-8 - TODO confirm.
    with open("static/indexAI.html", "r", encoding="utf-8") as file:
        return file.read()
# API Endpoints | |
async def health_check():
    """Report that the service is up.

    Returns:
        A dict with the single key ``"status"`` set to ``"healthy"``.
    """
    payload = {"status": "healthy"}
    return payload
async def summarize_document(file: UploadFile = File(...)):
    """Extract text from an uploaded document and return its summary.

    The format is chosen from the filename extension:
    docx (paragraph text), xls/xlsx (first column of the sheet pandas reads
    by default), pptx (text of every shape that has one), pdf (page text),
    jpg/jpeg/png (a generated caption is used as the text).

    The text is cut into 1000-character pieces, each piece is summarized
    separately, and the partial summaries are joined with spaces.

    Args:
        file: The uploaded document.

    Returns:
        ``{"summary": <str>}``.

    Raises:
        HTTPException: 400 for an unsupported extension or when no text can
            be extracted; 500 for any other failure during processing.
    """
    try:
        content = await file.read()
        # A missing filename yields an empty extension and is rejected below
        # as an unsupported format rather than crashing.
        file_ext = (file.filename or "").split(".")[-1].lower()

        text = ""
        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            text = " ".join(p.text for p in doc.paragraphs if p.text.strip())
        elif file_ext in ["xls", "xlsx"]:
            df = pd.read_excel(io.BytesIO(content))
            # A sheet with no columns has no first column to read; leave the
            # text empty so the request is answered with the 400 below.
            if df.shape[1] > 0:
                text = " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
        elif file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            text = " ".join(
                shape.text
                for slide in ppt.slides
                for shape in slide.shapes
                if hasattr(shape, "text")
            )
        elif file_ext == "pdf":
            # Context manager closes the PDF once its text has been read.
            with fitz.open(stream=content, filetype="pdf") as pdf:
                text = " ".join(page.get_text("text") for page in pdf)
        elif file_ext in ["jpg", "jpeg", "png"]:
            image = Image.open(io.BytesIO(content))
            text = get_image_captioning()(image)[0]['generated_text']
        else:
            raise HTTPException(400, "Unsupported file format")

        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        summarizer = get_summarizer()
        chunks = [text[i:i + 1000] for i in range(0, len(text), 1000)]
        summary = " ".join(
            summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
            for chunk in chunks
        )
        return {"summary": summary}
    except HTTPException:
        # The 400 responses raised above are deliberate; let them through
        # instead of re-wrapping them as a 500.
        raise
    except Exception as e:
        raise HTTPException(500, f"Error processing document: {str(e)}") from e
@lru_cache(maxsize=1)
def _get_sentence_model():
    """Load the sentence-embedding model once and reuse it on later calls."""
    return SentenceTransformer('all-MiniLM-L6-v2')


async def ask_question(
    question: str = Form(...),
    file: Optional[UploadFile] = File(None),
    text: Optional[str] = Form(None)
):
    """Answer *question* from an uploaded file or from raw text.

    When a file is given its text is extracted according to the filename
    extension (pdf, docx, xls/xlsx, pptx, jpg/jpeg/png). Images are OCR'd
    first; if OCR fails or yields nothing, a generated caption is used
    instead. Otherwise *text* is used as the context.

    If the first answer's confidence is below 0.3, the context sentence most
    similar to the question is found (cosine similarity above 0.5 required)
    and the question is asked again against that sentence alone; the second
    result is kept only if its confidence is higher.

    Args:
        question: The question to answer.
        file: Optional uploaded document providing the context.
        text: Optional raw text providing the context when no file is given.

    Returns:
        A dict with ``"answer"``, ``"confidence"`` and ``"context_used"``
        (the first 500 characters of the context followed by ``"..."`` when
        it is longer than that).

    Raises:
        HTTPException: 400 when neither input is supplied, the format is
            unsupported, or no content can be extracted; 500 for any other
            failure during processing.
    """
    context = ""
    try:
        if file:
            content = await file.read()
            # A missing filename yields an empty extension and is rejected
            # below as an unsupported format rather than crashing.
            file_ext = (file.filename or "").split(".")[-1].lower()

            if file_ext == "pdf":
                # Context manager closes the PDF once its text has been read.
                with fitz.open(stream=content, filetype="pdf") as pdf:
                    context = " ".join(page.get_text("text") for page in pdf)
            elif file_ext == "docx":
                doc = Document(io.BytesIO(content))
                context = " ".join(p.text for p in doc.paragraphs if p.text.strip())
            elif file_ext in ["xls", "xlsx"]:
                df = pd.read_excel(io.BytesIO(content))
                # A sheet with no columns has no first column to read; leave
                # the context empty so the 400 below is returned.
                if df.shape[1] > 0:
                    context = " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
            elif file_ext == "pptx":
                ppt = Presentation(io.BytesIO(content))
                context = " ".join(
                    shape.text
                    for slide in ppt.slides
                    for shape in slide.shapes
                    if hasattr(shape, "text")
                )
            elif file_ext in ["jpg", "jpeg", "png"]:
                image = Image.open(io.BytesIO(content))
                try:
                    context = pytesseract.image_to_string(image)
                except Exception:
                    # OCR is best-effort; an empty context triggers the
                    # caption fallback just below.
                    context = ""
                if not context.strip():
                    context = get_image_captioning()(image)[0]['generated_text']
            else:
                raise HTTPException(400, "Unsupported file format")
        elif text:
            context = text
        else:
            raise HTTPException(400, "Either file or text input is required")

        if not context.strip():
            raise HTTPException(400, "No extractable content found")

        # Collapse all runs of whitespace to single spaces.
        context = " ".join(context.split())

        result = answer_question(question, context)

        if result["confidence"] < 0.3:
            sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20]
            if sentences:
                # Cached helper: the embedding model is not rebuilt on every
                # low-confidence request.
                model = _get_sentence_model()
                question_embedding = model.encode(question, convert_to_tensor=True)
                sentence_embeddings = model.encode(sentences, convert_to_tensor=True)
                cos_scores = util.cos_sim(question_embedding, sentence_embeddings)[0]
                best_idx = torch.argmax(cos_scores).item()
                if cos_scores[best_idx] > 0.5:
                    new_result = answer_question(question, sentences[best_idx])
                    if new_result["confidence"] > result["confidence"]:
                        result = new_result

        return {
            "answer": result["answer"],
            "confidence": result["confidence"],
            "context_used": context[:500] + "..." if len(context) > 500 else context
        }
    except HTTPException:
        # The 400 responses raised above are deliberate; let them through
        # instead of re-wrapping them as a 500.
        raise
    except Exception as e:
        raise HTTPException(500, f"Error processing question: {str(e)}") from e
async def caption_image(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    The upload is decoded, converted to RGB and passed to the image
    captioning pipeline; the first generated text is returned.

    Args:
        file: The uploaded image.

    Returns:
        ``{"caption": <str>}``.

    Raises:
        HTTPException: 500 for any failure while reading, decoding or
            captioning the image.
    """
    try:
        raw_bytes = await file.read()
        picture = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        captioner = get_image_captioning()
        generated = captioner(picture)
        return {"caption": generated[0]['generated_text']}
    except Exception as e:
        raise HTTPException(500, f"Error processing image: {str(e)}")
async def translate_text(
    text: str = Form(...),
    target_lang: str = Form(...),
    src_lang: str = Form("eng_Latn")
):
    """Translate *text* from *src_lang* to *target_lang*.

    Args:
        text: The text to translate.
        target_lang: Language code passed to the pipeline as ``tgt_lang``.
        src_lang: Language code passed to the pipeline as ``src_lang``;
            defaults to ``"eng_Latn"``.

    Returns:
        ``{"translated_text": <str>}`` holding the first translation the
        pipeline produced.

    Raises:
        HTTPException: 500 for any failure during translation.
    """
    try:
        translator = get_translator()
        outputs = translator(text, src_lang=src_lang, tgt_lang=target_lang)
        first = outputs[0]
        return {"translated_text": first["translation_text"]}
    except Exception as e:
        raise HTTPException(500, f"Error translating text: {str(e)}")
# Run the application | |
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from the PORT environment
    # variable and falls back to 7860 when it is not set.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        reload=False,
    )