# main.py: API for Detection and Plagiarism Check

import asyncio
import logging
import os
import re
import time
from pathlib import Path
from typing import List, Tuple

import fitz  # PyMuPDF
import nltk
import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from nltk.tokenize import sent_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModel,
)

# Setup cache: point Hugging Face and NLTK at writable /tmp directories
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["NLTK_DATA"] = "/tmp/.cache/nltk"
Path("/tmp/.cache/huggingface").mkdir(parents=True, exist_ok=True)
Path("/tmp/.cache/nltk").mkdir(parents=True, exist_ok=True)

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NLTK init: download the sentence tokenizer once if it is not cached yet
nltk.data.path.append("/tmp/.cache/nltk")
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", download_dir="/tmp/.cache/nltk")

# App init
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST"],
    allow_headers=["*"],
)

# Model configs
MODEL_NAME = "Essay-Grader/roberta-ai-detector-20250401_232702"
EMBEDDING_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_TEXT_LENGTH = 10000  # Reduced for faster processing
AI_CHUNK_SIZE = 512  # Token budget for the AI-detector input
PLAGIARISM_THRESHOLD = 0.75
TIMEOUT = 30  # Global timeout in seconds
TASK_TIMEOUT = 15  # Per-task timeout in seconds
MAX_SENTENCES = 20  # Limit sentences for plagiarism check

# Load models once at startup
try:
    logger.info("Loading models...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    ai_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
    ai_model.eval()
    embed_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
    embed_model = AutoModel.from_pretrained(EMBEDDING_MODEL).to(DEVICE)
    embed_model.eval()
    logger.info("Models loaded successfully")
except Exception as e:
    logger.error(f"Model loading failed: {e}", exc_info=True)
    raise RuntimeError(f"Failed to initialize models: {e}")


def extract_text(pdf_bytes: bytes) -> str:
    """Extract plain text from a PDF, stopping at the references section."""
    try:
        start_time = time.time()
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            text = []
            for page in doc:
                if time.time() - start_time > TIMEOUT / 3:  # Early timeout for extraction
                    raise TimeoutError("PDF extraction timed out")
                page_text = page.get_text().strip()
                if "reference" in page_text.lower():
                    break  # Skip the bibliography and everything after it
                text.append(page_text)
        full_text = re.sub(r"\s+", " ", "\n".join(text))[:MAX_TEXT_LENGTH]
        if len(full_text) < 150:
            raise ValueError("Text too short")
        logger.info(f"Extracted text: {len(full_text)} characters")
        return full_text
    except Exception as e:
        logger.error(f"PDF error: {e}")
        raise HTTPException(400, "Invalid PDF")


async def predict_ai(text: str) -> float:
    """Return the model's probability that the text is AI-generated."""
    try:
        async def run_inference():
            inputs = tokenizer(
                text,
                truncation=True,
                max_length=AI_CHUNK_SIZE,
                return_tensors="pt",
            ).to(ai_model.device)
            with torch.no_grad():
                outputs = ai_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            logger.info("AI detection completed")
            return float(probs[0][1])  # AI-generated probability

        return await asyncio.wait_for(run_inference(), timeout=TASK_TIMEOUT)
    except asyncio.TimeoutError:
        logger.error("AI detection timed out")
        raise HTTPException(500, "AI detection timed out")
    except Exception as e:
        logger.error(f"AI detection error: {e}")
        raise HTTPException(500, f"AI detection failed: {e}")
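

# NOTE: asyncio.wait_for cannot interrupt run_inference above once the
# synchronous torch call has started, because the coroutine never yields to the
# event loop. Below is a minimal sketch of an alternative (predict_ai_threaded
# is a hypothetical helper, not part of the original API) that keeps the
# timeout observable by offloading inference to a worker thread; on timeout the
# thread still runs to completion, but the request stops waiting for it.
async def predict_ai_threaded(text: str) -> float:
    def _infer() -> float:
        inputs = tokenizer(
            text, truncation=True, max_length=AI_CHUNK_SIZE, return_tensors="pt"
        ).to(ai_model.device)
        with torch.no_grad():
            logits = ai_model(**inputs).logits
        return float(torch.softmax(logits, dim=1)[0][1])

    # asyncio.to_thread (Python 3.9+) runs the blocking call off the event loop,
    # so wait_for can cancel the await when TASK_TIMEOUT elapses.
    return await asyncio.wait_for(asyncio.to_thread(_infer), timeout=TASK_TIMEOUT)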


async def compute_embeddings(sentences: List[str]) -> np.ndarray:
    """Mean-pool token embeddings into one fixed-size vector per sentence."""
    inputs = embed_tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    ).to(embed_model.device)
    with torch.no_grad():
        outputs = embed_model(**inputs)
    attention_mask = inputs["attention_mask"]
    last_hidden = outputs.last_hidden_state
    # Zero out padding positions, then average over the sequence dimension
    pooled = (last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(
        1, keepdim=True
    )
    return pooled.cpu().numpy()


async def check_plagiarism(text: str) -> Tuple[float, bool]:
    """Score internal sentence-to-sentence similarity as a plagiarism proxy."""
    try:
        async def run_plagiarism():
            sentences = [
                s for s in sent_tokenize(text) if 5 < len(s.split()) < 100
            ][:MAX_SENTENCES]
            if len(sentences) < 2:
                logger.info("Not enough sentences for plagiarism check")
                return 0.0, False
            embeddings = await compute_embeddings(sentences)
            sim_matrix = cosine_similarity(embeddings)
            np.fill_diagonal(sim_matrix, 0)
            n = len(sim_matrix)
            # Average the top ~10% highest-similarity sentence pairs
            top_k = max(1, int(0.1 * n * (n - 1) / 2))
            top_indices = np.argpartition(sim_matrix.flatten(), -top_k)[-top_k:]
            avg_similarity = float(np.mean(sim_matrix.flatten()[top_indices]))
            logger.info("Plagiarism check completed")
            return round(avg_similarity * 100, 2), avg_similarity > PLAGIARISM_THRESHOLD

        return await asyncio.wait_for(run_plagiarism(), timeout=TASK_TIMEOUT)
    except asyncio.TimeoutError:
        logger.error("Plagiarism check timed out")
        return 0.0, False
    except Exception as e:
        logger.error(f"Plagiarism check error: {e}")
        return 0.0, False


@app.post("/detect")
async def detect_ai_and_plagiarism(file: UploadFile = File(...)):
    start_time = time.time()
    try:
        if not file.filename.lower().endswith(".pdf"):
            raise HTTPException(400, "Only PDF files allowed")

        pdf_data = await file.read()
        text = extract_text(pdf_data)

        # Run both checks concurrently; plag_risk is computed but not returned
        ai_score, (plag_score, plag_risk) = await asyncio.gather(
            predict_ai(text),
            check_plagiarism(text),
        )

        total_time = time.time() - start_time
        if total_time > TIMEOUT:
            logger.error("Processing exceeded timeout")
            raise HTTPException(500, "Processing timed out")

        logger.info(f"Processing completed in {total_time:.2f} seconds")
        return {
            "ai_generated_percentage": round(ai_score * 100, 2),
            "plagiarism_percentage": plag_score,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error: {e}", exc_info=True)
        raise HTTPException(500, f"Processing failed: {e}")


# Health check endpoint for debugging
# @app.get("/health")
# async def health_check():
#     return {"status": "healthy", "python_version": "3.11"}
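
# Example request (a sketch; assumes the app is served with uvicorn on port 8000
# and that a local "essay.pdf" exists -- both are illustrative, not part of the
# original file):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl -X POST -F "file=@essay.pdf" http://localhost:8000/detect
#
# Response shape: {"ai_generated_percentage": <float>, "plagiarism_percentage": <float>}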