# Essay-Grader's picture
# Done fixing
# f4e8889
# main.py: API for Detection and Plagiarism Check
import os
import re
import time
import logging
from pathlib import Path
from typing import List, Tuple
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
import fitz # PyMuPDF
import torch
import numpy as np
import nltk
import asyncio
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
AutoModel,
)
from nltk.tokenize import sent_tokenize
from sklearn.metrics.pairwise import cosine_similarity
# Setup cache: point the Hugging Face and NLTK caches at /tmp/.cache and make
# sure both directories exist.
# NOTE(review): `transformers` and `nltk` are imported above this point, and
# those libraries may read these variables at import time — confirm that the
# settings actually take effect where they are relied on.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/.cache/huggingface",
        "HF_HOME": "/tmp/.cache/huggingface",
        "NLTK_DATA": "/tmp/.cache/nltk",
    }
)
for _cache_dir in ("/tmp/.cache/huggingface", "/tmp/.cache/nltk"):
    Path(_cache_dir).mkdir(parents=True, exist_ok=True)
del _cache_dir
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# NLTK init: make the sentence-tokenizer data used by `sent_tokenize` available
# under /tmp/.cache/nltk, downloading it on first start if it is missing.
_NLTK_DATA_DIR = "/tmp/.cache/nltk"
if _NLTK_DATA_DIR not in nltk.data.path:
    # Register the directory exactly once (previously it was appended twice
    # whenever a download happened).
    nltk.data.path.append(_NLTK_DATA_DIR)
# Older NLTK releases tokenize sentences with the pickled "punkt" model; newer
# releases load "punkt_tab" instead. Fetch whichever is missing so that
# `sent_tokenize` works on either. If it fails, `check_plagiarism` swallows the
# LookupError and silently reports 0.
for _package in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_package}")
    except LookupError:
        nltk.download(_package, download_dir=_NLTK_DATA_DIR)
del _package
# App init: the FastAPI application, with CORS that accepts requests from any
# origin but only whitelists the POST method.
app = FastAPI()
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["POST"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# Model configs
# Hub ids: the sequence classifier used by predict_ai and the sentence encoder
# used by compute_embeddings.
MODEL_NAME: str = "Essay-Grader/roberta-ai-detector-20250401_232702"
EMBEDDING_MODEL: str = "sentence-transformers/paraphrase-MiniLM-L3-v2"

# 0 selects the first CUDA device; -1 is this file's sentinel for "CPU".
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

MAX_TEXT_LENGTH: int = 10000  # characters kept from the PDF (reduced for faster processing)
AI_CHUNK_SIZE: int = 512  # token limit for the AI-detection classifier input
PLAGIARISM_THRESHOLD: float = 0.75  # mean top-pair cosine similarity that flags risk

# Time budgets, in seconds.
TIMEOUT: int = 30  # whole request
TASK_TIMEOUT: int = 15  # each of the AI / plagiarism sub-tasks

MAX_SENTENCES: int = 20  # upper bound on sentences compared in the plagiarism check
# Load models: the AI-detection classifier and the sentence-embedding encoder,
# both moved to DEVICE (CUDA device index) or the CPU and put in eval mode.
# Any failure aborts start-up with a RuntimeError chained to the cause.
#
# The cache directory is passed explicitly. The TRANSFORMERS_CACHE / HF_HOME
# environment variables are assigned after `transformers` is imported in this
# file, and the library may already have resolved its default cache path by
# then. None falls back to the library default.
_HF_CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE")
_MODEL_DEVICE = DEVICE if DEVICE != -1 else "cpu"
try:
    logger.info("Loading models...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=_HF_CACHE_DIR)
    ai_model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME, cache_dir=_HF_CACHE_DIR
    ).to(_MODEL_DEVICE)
    ai_model.eval()
    embed_tokenizer = AutoTokenizer.from_pretrained(
        EMBEDDING_MODEL, cache_dir=_HF_CACHE_DIR
    )
    embed_model = AutoModel.from_pretrained(
        EMBEDDING_MODEL, cache_dir=_HF_CACHE_DIR
    ).to(_MODEL_DEVICE)
    embed_model.eval()
    logger.info("Models loaded successfully")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}", exc_info=True)
    raise RuntimeError(f"Failed to initialize models: {str(e)}") from e
# A line consisting only of a bibliography heading, optionally numbered and/or
# followed by a colon: "References", "REFERENCE LIST", "7. References:",
# "Bibliography", "Works Cited". Anchored per line so ordinary prose such as
# "preference" or "in reference to" does not end extraction.
_REFERENCES_HEADING_RE = re.compile(
    r"^\s*(?:\d+\.?\s*)?(?:references?(?:\s+list)?|bibliography|works\s+cited)\s*:?\s*$",
    re.IGNORECASE | re.MULTILINE,
)


def extract_text(pdf_bytes: bytes) -> str:
    """Extract an essay's body text from a PDF, stopping at its references.

    Pages are read in order. On the first page containing a standalone
    references heading, the text above the heading is kept and everything from
    the heading onward (including all later pages) is dropped. Whitespace runs
    are collapsed to single spaces and the result is truncated to
    MAX_TEXT_LENGTH characters.

    Args:
        pdf_bytes: Raw bytes of the uploaded PDF.

    Returns:
        The normalized text, at least 150 characters long.

    Raises:
        HTTPException: 400 "Invalid PDF" if the PDF cannot be opened/read,
            extraction takes longer than TIMEOUT / 3 seconds (checked between
            pages), or fewer than 150 characters remain.
    """
    try:
        start_time = time.time()
        pages = []
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            for page in doc:
                # Early timeout for extraction; only checked between pages.
                if time.time() - start_time > TIMEOUT / 3:
                    raise TimeoutError("PDF extraction timed out")
                page_text = page.get_text().strip()
                heading = _REFERENCES_HEADING_RE.search(page_text)
                if heading:
                    # Keep the body text that precedes the heading on this page.
                    pages.append(page_text[: heading.start()])
                    break
                pages.append(page_text)
        full_text = re.sub(r"\s+", " ", "\n".join(pages))[:MAX_TEXT_LENGTH]
        if len(full_text) < 150:
            raise ValueError("Text too short")
        logger.info(f"Extracted text: {len(full_text)} characters")
        return full_text
    except Exception as e:
        logger.error(f"PDF error: {str(e)}")
        raise HTTPException(400, "Invalid PDF") from e
async def predict_ai(text: str) -> float:
    """Return the classifier's probability for class index 1 on *text*.

    The original code labels class 1 as "AI-generated".
    NOTE(review): confirm against the model's id2label config.

    Only the first AI_CHUNK_SIZE tokens are classified. The forward pass runs
    in the default thread-pool executor so it neither blocks the event loop
    nor defeats the TASK_TIMEOUT bound. On timeout the awaiting stops, but the
    worker thread is not interrupted and finishes in the background.

    Raises:
        HTTPException: 500 on timeout or on any tokenizer/model error.
    """
    try:
        # Tokenize on the event-loop thread. This is cheap, and it avoids sharing
        # one tokenizer object between concurrently running worker threads.
        inputs = tokenizer(
            text,
            truncation=True,
            max_length=AI_CHUNK_SIZE,
            return_tensors="pt",
        ).to(ai_model.device)

        def run_inference() -> float:
            # Grad mode is thread-local in PyTorch, so no_grad must be entered
            # inside the worker thread that executes the forward pass.
            with torch.no_grad():
                outputs = ai_model(**inputs)
                probs = torch.softmax(outputs.logits, dim=1)
            return float(probs[0][1])  # AI-generated probability

        loop = asyncio.get_running_loop()
        ai_probability = await asyncio.wait_for(
            loop.run_in_executor(None, run_inference), timeout=TASK_TIMEOUT
        )
        logger.info("AI detection completed")
        return ai_probability
    except asyncio.TimeoutError:
        logger.error("AI detection timed out")
        raise HTTPException(500, "AI detection timed out")
    except Exception as e:
        logger.error(f"AI detection error: {str(e)}")
        raise HTTPException(500, f"AI detection failed: {str(e)}") from e
async def compute_embeddings(sentences: List[str]) -> np.ndarray:
    """Embed *sentences* with attention-masked mean pooling of the encoder output.

    Each sentence is truncated to 128 tokens. The batch is padded, and padding
    positions are excluded from the mean via the attention mask.

    Returns:
        A NumPy array with one row per sentence, always on the CPU.
    """
    # Tokenize on the event-loop thread (cheap). The encoder forward pass runs
    # in the default executor so it does not block the loop.
    inputs = embed_tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    ).to(embed_model.device)

    def encode() -> np.ndarray:
        # Grad mode is thread-local, so enter no_grad in the worker thread.
        with torch.no_grad():
            last_hidden = embed_model(**inputs).last_hidden_state
            mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
            summed = (last_hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
            pooled = summed / counts
        # Move the whole pooled result to host memory. Previously only the
        # denominator was converted, which mixed torch and NumPy operands and
        # failed on CUDA.
        return pooled.cpu().numpy()

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, encode)
async def check_plagiarism(text: str) -> Tuple[float, bool]:
    """Score how similar the text's sentences are to one another.

    Sentences with 6-99 whitespace-separated words are kept, capped at
    MAX_SENTENCES. They are embedded, and every unordered sentence pair is
    scored by cosine similarity. The reported value is the mean of the top
    10% of pair scores (at least one pair). Only the submitted text is
    compared with itself; no external corpus is consulted.

    Returns:
        (score, risk). *score* is that mean as a percentage rounded to two
        decimals and floored at 0. *risk* is True when the unrounded mean
        exceeds PLAGIARISM_THRESHOLD. This is best-effort: with fewer than two
        usable sentences, after TASK_TIMEOUT seconds, or on any error, it logs
        and returns (0.0, False).
    """
    try:
        async def run_plagiarism():
            sentences = [
                s for s in sent_tokenize(text) if 5 < len(s.split()) < 100
            ][:MAX_SENTENCES]
            if len(sentences) < 2:
                logger.info("Not enough sentences for plagiarism check")
                return 0.0, False
            embeddings = await compute_embeddings(sentences)
            sim_matrix = cosine_similarity(embeddings)
            # Score each unordered pair once. The full matrix is symmetric and
            # has self-similarity on the diagonal, so taking top-k over it
            # counted every pair twice.
            pair_scores = sim_matrix[np.triu_indices(len(sentences), k=1)]
            top_k = max(1, int(0.1 * pair_scores.size))
            top_scores = np.partition(pair_scores, -top_k)[-top_k:]
            # Floor at 0: a negative "percentage" is meaningless to clients.
            avg_similarity = max(float(np.mean(top_scores)), 0.0)
            logger.info("Plagiarism check completed")
            return round(avg_similarity * 100, 2), avg_similarity > PLAGIARISM_THRESHOLD

        return await asyncio.wait_for(run_plagiarism(), timeout=TASK_TIMEOUT)
    except asyncio.TimeoutError:
        logger.error("Plagiarism check timed out")
        return 0.0, False
    except Exception as e:
        # Deliberately best-effort, but keep the traceback for debugging.
        logger.error(f"Plagiarism check error: {str(e)}", exc_info=True)
        return 0.0, False
@app.post("/detect")
async def detect_ai_and_plagiarism(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
    """Score an uploaded PDF for AI generation and intra-document similarity.

    The AI-detection and plagiarism sub-tasks run concurrently, and the whole
    request is bounded by TIMEOUT seconds. ``background_tasks`` is accepted
    but not used.

    Returns:
        {"ai_generated_percentage": float, "plagiarism_percentage": float}

    Raises:
        HTTPException: 400 for a non-.pdf filename or an unreadable/too-short
            PDF; 500 if processing exceeds TIMEOUT or any step fails.
    """
    start_time = time.time()
    try:
        # UploadFile.filename may be None; treat that as "not a PDF" rather
        # than crashing with AttributeError.
        if not (file.filename or "").lower().endswith(".pdf"):
            raise HTTPException(400, "Only PDF files allowed")
        pdf_data = await file.read()
        # Kept on the event-loop thread rather than a thread pool: PyMuPDF's
        # documentation says it does not support multi-threaded use.
        text = extract_text(pdf_data)

        # Bound the concurrent sub-tasks by whatever remains of the request budget.
        remaining = max(TIMEOUT - (time.time() - start_time), 0.0)
        try:
            ai_score, (plag_score, plag_risk) = await asyncio.wait_for(
                asyncio.gather(predict_ai(text), check_plagiarism(text)),
                timeout=remaining,
            )
        except asyncio.TimeoutError:
            logger.error("Processing exceeded timeout")
            raise HTTPException(500, "Processing timed out")

        total_time = time.time() - start_time
        # Post-hoc guard for work that overran without yielding to the loop.
        if total_time > TIMEOUT:
            logger.error("Processing exceeded timeout")
            raise HTTPException(500, "Processing timed out")
        logger.info(f"Processing completed in {total_time:.2f} seconds")
        # plag_risk is computed but, as before, not part of the response.
        return {
            "ai_generated_percentage": round(ai_score * 100, 2),
            "plagiarism_percentage": plag_score,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error: {str(e)}", exc_info=True)
        raise HTTPException(500, f"Processing failed: {str(e)}")
# Health check endpoint for debugging
# @app.get("/health")
# async def health_check():
# return {"status": "healthy", "python_version": "3.11"}