# asm-app / main.py
# Author: chenguittiMaroua
# Last commit: 704f972 "Update main.py" (verified)
# Size: 9.08 kB
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer, util
from typing import Optional
import io
import fitz # PyMuPDF
from PIL import Image
import pandas as pd
import uvicorn
from functools import lru_cache
from docx import Document
from pptx import Presentation
import pytesseract
import torch
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi import Request
from pathlib import Path
import os
# Initialize FastAPI app
app = FastAPI()

# Resolve asset directories relative to this module rather than the process
# working directory, so the app behaves the same wherever it is launched from.
BASE_DIR = Path(__file__).parent
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))

# Configure CORS.
# Entries in ``allow_origins`` are compared literally, so a wildcard such as
# "https://*.hf.space" would never match anything; sub-domain patterns must go
# through ``allow_origin_regex`` instead (anchored so look-alike hosts such as
# "https://x.hf.space.evil.com" are rejected).
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost",
        "http://localhost:8000",
    ],
    allow_origin_regex=r"^https://[A-Za-z0-9-]+\.hf\.space$",
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve static files (frontend)
app.mount("/static", StaticFiles(directory=str(BASE_DIR / "static")), name="static")
# Model loading with caching
@lru_cache(maxsize=None)
def get_summarizer():
    """Return the shared ``facebook/bart-large-cnn`` summarization pipeline.

    The pipeline is built on the first call and the same object is returned
    on every later call (memoised by ``lru_cache``).
    """
    task, checkpoint = "summarization", "facebook/bart-large-cnn"
    return pipeline(task, model=checkpoint)
@lru_cache(maxsize=None)
def get_image_captioning():
    """Return the shared BLIP (``Salesforce/blip-image-captioning-large``) image-to-text pipeline.

    Built lazily on first use and cached for the lifetime of the process.
    """
    task, checkpoint = "image-to-text", "Salesforce/blip-image-captioning-large"
    return pipeline(task, model=checkpoint)
@lru_cache(maxsize=None)
def get_translator():
    """Return the shared ``facebook/nllb-200-distilled-600M`` translation pipeline.

    No language pair is fixed here; callers pass ``src_lang``/``tgt_lang``
    on each invocation. Built once, then served from the cache.
    """
    task, checkpoint = "translation", "facebook/nllb-200-distilled-600M"
    return pipeline(task, model=checkpoint)
@lru_cache(maxsize=None)
def get_qa_model():
    """Load the extractive QA checkpoint ``deepset/roberta-base-squad2``.

    Returns:
        A ``(tokenizer, model)`` tuple; loaded on the first call and cached.
    """
    checkpoint = "deepset/roberta-base-squad2"
    qa_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    qa_model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    return qa_tokenizer, qa_model
# Helper Functions
def answer_question(question: str, context: str, max_answer_len: int = 30) -> dict:
    """Extract the answer to *question* from *context* with the QA model.

    The question/context pair is encoded as a single window of at most 512
    tokens. Only the context is truncated, so text beyond that window is never
    seen. The answer is the highest-scoring span whose start and end both lie
    inside the context, with ``start <= end``, and which is at most
    *max_answer_len* tokens long.

    Args:
        question: Natural-language question.
        context: Text to search for the answer.
        max_answer_len: Longest answer span, in tokens, that may be returned.

    Returns:
        A dict with two keys:

        - ``"answer"``: the decoded span, or ``"No answer found"`` when the
          no-answer position outscores every span or the span decodes to
          nothing.
        - ``"confidence"``: the mean of the start- and end-token
          probabilities, in [0, 1].

        Errors are not raised. They are reported in ``"answer"`` with a
        confidence of 0.0.
    """
    tokenizer, model = get_qa_model()
    try:
        # No padding: with a single sequence it only adds positions that would
        # otherwise leak into the softmax below.
        inputs = tokenizer(
            question,
            context,
            max_length=512,
            truncation="only_second",
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
            )
        start_probs = torch.softmax(outputs.start_logits[0], dim=-1)
        end_probs = torch.softmax(outputs.end_logits[0], dim=-1)

        # Only context tokens (sequence id 1) may form an answer. This excludes
        # the question and the special tokens. sequence_ids() requires a fast
        # tokenizer; if one is not available the error is reported below.
        in_context = torch.tensor(
            [seq_id == 1 for seq_id in inputs.sequence_ids(0)], dtype=torch.bool
        )

        # Joint span scores P(start=i) * P(end=j), keeping only valid spans:
        # i <= j (upper triangle) and j - i < max_answer_len (band).
        span_scores = torch.outer(start_probs, end_probs)
        span_scores = torch.triu(span_scores)
        span_scores = torch.tril(span_scores, diagonal=max_answer_len - 1)
        span_scores = span_scores * (in_context[:, None] & in_context[None, :])

        best = int(torch.argmax(span_scores))
        start_idx, end_idx = divmod(best, span_scores.shape[1])

        # The model is a SQuAD 2.0 checkpoint. By that convention, pointing
        # start/end at the first token means "no answer".
        null_score = start_probs[0] * end_probs[0]
        if null_score >= span_scores[start_idx, end_idx]:
            start_idx = end_idx = 0
            answer = ""
        else:
            answer = tokenizer.decode(
                inputs["input_ids"][0][start_idx:end_idx + 1],
                skip_special_tokens=True,
            ).strip()

        confidence = float((start_probs[start_idx] + end_probs[end_idx]) / 2)
        return {
            "answer": answer if answer else "No answer found",
            "confidence": confidence,
        }
    except Exception as e:
        # Best-effort helper: report the failure to the caller instead of raising.
        return {
            "answer": f"Error processing answer: {e}",
            "confidence": 0.0,
        }
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page frontend ``static/indexAI.html`` as HTML.

    The path is resolved next to this module rather than the current working
    directory. The file is decoded as UTF-8 explicitly instead of with the
    platform's default encoding.
    """
    index_path = Path(__file__).parent / "static" / "indexAI.html"
    return index_path.read_text(encoding="utf-8")
# API Endpoints
@app.get("/health")
async def health_check():
    """Liveness probe reporting that the service is up.

    Touches no model loader, so it responds without triggering model loading.
    """
    payload = {"status": "healthy"}
    return payload
def _chunk_words(text: str, max_chars: int = 1000) -> list:
    """Split *text* into space-joined chunks of at most *max_chars* characters.

    Chunks break only between words, so no word is cut in half. A single word
    longer than *max_chars* becomes its own chunk.
    """
    chunks = []
    current = []
    length = 0
    for word in text.split():
        added = len(word) + (1 if current else 0)  # +1 for the joining space
        if current and length + added > max_chars:
            chunks.append(" ".join(current))
            current, length = [], 0
            added = len(word)
        current.append(word)
        length += added
    if current:
        chunks.append(" ".join(current))
    return chunks


@app.post("/summarize")
async def summarize_document(file: UploadFile = File(...)):
    """Extract text from an uploaded document and return its summary.

    Supported extensions:

    - docx: non-empty paragraphs.
    - xls, xlsx: first column only.
    - pptx: shape text.
    - pdf: page text.
    - jpg, jpeg, png: a generated image caption.

    The text is summarized in chunks of about 1000 characters, and the chunk
    summaries are joined with spaces.

    Raises:
        HTTPException(400): unsupported extension or no extractable text.
        HTTPException(500): any other failure while extracting or summarizing.
    """
    try:
        content = await file.read()
        # ``filename`` may be None; treat that as "no extension" (-> 400).
        file_ext = (file.filename or "").split(".")[-1].lower()
        text = ""

        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            text = " ".join([p.text for p in doc.paragraphs if p.text.strip()])
        elif file_ext in ["xls", "xlsx"]:
            df = pd.read_excel(io.BytesIO(content))
            text = " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
        elif file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            text = " ".join([shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text")])
        elif file_ext == "pdf":
            # Close the PyMuPDF document once its text has been read.
            with fitz.open(stream=content, filetype="pdf") as pdf:
                text = " ".join([page.get_text("text") for page in pdf])
        elif file_ext in ["jpg", "jpeg", "png"]:
            image = Image.open(io.BytesIO(content)).convert("RGB")
            text = get_image_captioning()(image)[0]['generated_text']
        else:
            raise HTTPException(400, "Unsupported file format")

        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        summarizer = get_summarizer()
        summaries = []
        for chunk in _chunk_words(text):
            # Scale the generation bounds to the chunk size. A full chunk gets
            # 50..150 tokens, but a short chunk (e.g. the document's tail or an
            # image caption) is not forced out to 50 tokens of invented text.
            word_count = len(chunk.split())
            max_len = min(150, max(16, word_count))
            min_len = min(50, max_len // 2)
            summaries.append(
                summarizer(chunk, max_length=max_len, min_length=min_len, do_sample=False)[0]["summary_text"]
            )
        return {"summary": " ".join(summaries)}
    except HTTPException:
        # Keep deliberate 4xx responses instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(500, f"Error processing document: {str(e)}") from e
@lru_cache()
def _get_sentence_model():
    """Load the sentence-embedding model used by ``/ask``'s fallback, once per process."""
    return SentenceTransformer('all-MiniLM-L6-v2')


@app.post("/ask")
async def ask_question(
    question: str = Form(...),
    file: Optional[UploadFile] = File(None),
    text: Optional[str] = Form(None)
):
    """Answer *question* from an uploaded document or from raw *text*.

    The context comes from *file* when one is given, otherwise from *text*.
    Supported file types:

    - PDF and DOCX.
    - XLS/XLSX (first column only).
    - PPTX.
    - JPG/JPEG/PNG: OCR first, falling back to an image caption.

    If the QA confidence on the whole context is below 0.3, the sentence most
    similar to the question is tried next, provided its cosine similarity is
    above 0.5. Its answer is kept if it scores higher.

    Returns:
        A dict with ``"answer"``, ``"confidence"`` and ``"context_used"``. The
        last is the whitespace-normalised context, truncated to 500 characters
        plus "..." when longer.

    Raises:
        HTTPException(400): unsupported file type, no input, or empty content.
        HTTPException(500): any other failure while extracting or answering.
    """
    context = ""
    try:
        if file:
            content = await file.read()
            # ``filename`` may be None; treat that as "no extension" (-> 400).
            file_ext = (file.filename or "").split(".")[-1].lower()
            if file_ext == "pdf":
                # Close the PyMuPDF document once its text has been read.
                with fitz.open(stream=content, filetype="pdf") as pdf:
                    context = " ".join([page.get_text("text") for page in pdf])
            elif file_ext == "docx":
                doc = Document(io.BytesIO(content))
                context = " ".join([p.text for p in doc.paragraphs if p.text.strip()])
            elif file_ext in ["xls", "xlsx"]:
                df = pd.read_excel(io.BytesIO(content))
                context = " ".join(df.iloc[:, 0].dropna().astype(str).tolist())
            elif file_ext == "pptx":
                ppt = Presentation(io.BytesIO(content))
                context = " ".join([shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text")])
            elif file_ext in ["jpg", "jpeg", "png"]:
                image = Image.open(io.BytesIO(content))
                try:
                    context = pytesseract.image_to_string(image)
                except Exception:
                    # OCR is best-effort (e.g. the tesseract binary may be
                    # missing); fall through to captioning below.
                    context = ""
                if not context.strip():
                    context = get_image_captioning()(image)[0]['generated_text']
            else:
                raise HTTPException(400, "Unsupported file format")
        elif text:
            context = text
        else:
            raise HTTPException(400, "Either file or text input is required")

        if not context.strip():
            raise HTTPException(400, "No extractable content found")

        context = " ".join(context.split())
        result = answer_question(question, context)

        if result["confidence"] < 0.3:
            # Low confidence on the full context: retry on the one sentence
            # that is semantically closest to the question.
            sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20]
            if sentences:
                embedder = _get_sentence_model()
                question_embedding = embedder.encode(question, convert_to_tensor=True)
                sentence_embeddings = embedder.encode(sentences, convert_to_tensor=True)
                cos_scores = util.cos_sim(question_embedding, sentence_embeddings)[0]
                best_idx = torch.argmax(cos_scores).item()
                if cos_scores[best_idx] > 0.5:
                    new_result = answer_question(question, sentences[best_idx])
                    if new_result["confidence"] > result["confidence"]:
                        result = new_result

        return {
            "answer": result["answer"],
            "confidence": result["confidence"],
            "context_used": context[:500] + "..." if len(context) > 500 else context
        }
    except HTTPException:
        # Keep deliberate 4xx responses instead of re-wrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(500, f"Error processing question: {str(e)}") from e
@app.post("/api/caption")
async def caption_image(file: UploadFile = File(...)):
    """Caption an uploaded image with the BLIP image-to-text pipeline.

    The image is converted to RGB before captioning.

    Returns:
        ``{"caption": <generated text>}``.

    Raises:
        HTTPException(400): the upload cannot be decoded as an image.
        HTTPException(500): any other failure while captioning.
    """
    from PIL import UnidentifiedImageError

    try:
        image = Image.open(io.BytesIO(await file.read())).convert("RGB")
        caption = get_image_captioning()(image)[0]['generated_text']
    except UnidentifiedImageError as e:
        # Client sent something that is not a readable image: a 400, not a 500.
        raise HTTPException(400, f"Invalid image file: {str(e)}") from e
    except Exception as e:
        raise HTTPException(500, f"Error processing image: {str(e)}") from e
    return {"caption": caption}
@app.post("/translate")
async def translate_text(
    text: str = Form(...),
    target_lang: str = Form(...),
    src_lang: str = Form("eng_Latn")
):
    """Translate *text* from *src_lang* (default ``eng_Latn``) into *target_lang*.

    The language codes are passed straight to the NLLB pipeline as
    ``src_lang``/``tgt_lang``. They presumably follow NLLB's code scheme, like
    the default ``eng_Latn``; confirm against the model card.

    Returns:
        ``{"translated_text": <first translation>}``.

    Raises:
        HTTPException(500): any failure while loading the model or translating.
    """
    try:
        outputs = get_translator()(text, src_lang=src_lang, tgt_lang=target_lang)
        translation = outputs[0]["translation_text"]
    except Exception as exc:
        raise HTTPException(500, f"Error translating text: {str(exc)}")
    return {"translated_text": translation}
# Run the application
if __name__ == "__main__":
    # Listen on $PORT when set, otherwise on 7860.
    port = int(os.environ.get("PORT", 7860))
    # Pass the application object itself. The import string "app:app" only
    # resolves when this module is named app.py, and an import string is only
    # required when reload/workers are enabled.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)