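"""FastAPI service that scores an uploaded PDF for AI-generated text and
for internal sentence-level plagiarism."""
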
import os
import re
import time
import logging
import asyncio
from pathlib import Path
from typing import List, Tuple

# Set cache locations before transformers and nltk are imported: both
# libraries resolve their cache paths at import time, so exporting these
# variables afterwards has no effect. /tmp keeps the app usable when the
# home directory is not writable.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/.cache/huggingface"
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
os.environ["NLTK_DATA"] = "/tmp/.cache/nltk"
Path("/tmp/.cache/huggingface").mkdir(parents=True, exist_ok=True)
Path("/tmp/.cache/nltk").mkdir(parents=True, exist_ok=True)

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware

import fitz  # PyMuPDF
import torch
import numpy as np
import nltk

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModel,
)
from nltk.tokenize import sent_tokenize
from sklearn.metrics.pairwise import cosine_similarity

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Make sure the punkt sentence tokenizer is available before first use.
nltk.data.path.append("/tmp/.cache/nltk")
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", download_dir="/tmp/.cache/nltk")

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST"],
    allow_headers=["*"],
)

MODEL_NAME = "Essay-Grader/roberta-ai-detector-20250401_232702"
EMBEDDING_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_TEXT_LENGTH = 10000  # characters kept after PDF extraction
AI_CHUNK_SIZE = 512  # token budget for the AI detector
PLAGIARISM_THRESHOLD = 0.75  # cosine similarity that flags plagiarism risk
TIMEOUT = 30  # seconds allowed for the whole request
TASK_TIMEOUT = 15  # seconds allowed per sub-task
MAX_SENTENCES = 20  # sentences compared in the plagiarism check

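# Load the detector and the embedding model once at import time so that
# requests do not pay the loading cost.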
try:
    logger.info("Loading models...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    ai_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
    ai_model.eval()

    embed_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
    embed_model = AutoModel.from_pretrained(EMBEDDING_MODEL).to(DEVICE)
    embed_model.eval()
    logger.info("Models loaded successfully")
except Exception as e:
    logger.error(f"Model loading failed: {e}", exc_info=True)
    raise RuntimeError(f"Failed to initialize models: {e}") from e


def extract_text(pdf_bytes: bytes) -> str: |
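    """Extract plain text from a PDF, stopping at the references section."""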
|
    try:
        start_time = time.time()
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            text = []
            for page in doc:
                if time.time() - start_time > TIMEOUT / 3:
                    raise TimeoutError("PDF extraction timed out")
                page_text = page.get_text().strip()
                # Heuristic: stop once the references section starts so
                # quoted titles and citations don't skew the scores.
                if "reference" in page_text.lower():
                    break
                text.append(page_text)

        full_text = re.sub(r"\s+", " ", "\n".join(text))[:MAX_TEXT_LENGTH]
        if len(full_text) < 150:
            raise ValueError("Text too short")
        logger.info(f"Extracted text: {len(full_text)} characters")
        return full_text
    except Exception as e:
        logger.error(f"PDF error: {e}")
        raise HTTPException(400, "Invalid PDF")


async def predict_ai(text: str) -> float: |
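    """Return the probability that the text is AI-generated."""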
|
    try:
        def run_inference() -> float:
            inputs = tokenizer(
                text,
                truncation=True,
                max_length=AI_CHUNK_SIZE,
                return_tensors="pt",
            ).to(ai_model.device)

            with torch.no_grad():
                outputs = ai_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
            logger.info("AI detection completed")
            # Index 1 is taken to be the "AI-generated" class.
            return float(probs[0][1])

        # Run the blocking torch call in a worker thread via asyncio.to_thread
        # (Python 3.9+): a coroutine that never awaits would hold the event
        # loop, so asyncio.wait_for could never actually enforce this timeout.
        return await asyncio.wait_for(
            asyncio.to_thread(run_inference), timeout=TASK_TIMEOUT
        )
    except asyncio.TimeoutError:
        logger.error("AI detection timed out")
        raise HTTPException(500, "AI detection timed out")
    except Exception as e:
        logger.error(f"AI detection error: {e}")
        raise HTTPException(500, f"AI detection failed: {e}")


def compute_embeddings(sentences: List[str]) -> np.ndarray:
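    """Mean-pool the encoder's last hidden state into one vector per sentence."""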
|
    inputs = embed_tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt",
    ).to(embed_model.device)

    with torch.no_grad():
        outputs = embed_model(**inputs)

    # Zero out padded positions, then average over the sequence dimension.
    attention_mask = inputs["attention_mask"]
    last_hidden = outputs.last_hidden_state
    mean_pooled = (last_hidden * attention_mask.unsqueeze(-1)).sum(
        1
    ) / attention_mask.sum(1, keepdim=True)
    return mean_pooled.cpu().numpy()


async def check_plagiarism(text: str) -> Tuple[float, bool]: |
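    """Estimate plagiarism risk from within-document sentence similarity."""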
|
    try:
        def run_plagiarism() -> Tuple[float, bool]:
            # Keep mid-length sentences; very short or very long ones
            # produce noisy embeddings.
            sentences = [
                s for s in sent_tokenize(text) if 5 < len(s.split()) < 100
            ][:MAX_SENTENCES]
            if len(sentences) < 2:
                logger.info("Not enough sentences for plagiarism check")
                return 0.0, False

            embeddings = compute_embeddings(sentences)
            sim_matrix = cosine_similarity(embeddings)

            # The matrix is symmetric, so take the upper triangle to count
            # each sentence pair once, then average the top 10% of pairs.
            n = len(sim_matrix)
            pair_sims = sim_matrix[np.triu_indices(n, k=1)]
            top_k = max(1, int(0.1 * len(pair_sims)))
            top_indices = np.argpartition(pair_sims, -top_k)[-top_k:]
            avg_similarity = float(np.mean(pair_sims[top_indices]))

            logger.info("Plagiarism check completed")
            return round(avg_similarity * 100, 2), avg_similarity > PLAGIARISM_THRESHOLD

        # As in predict_ai, push the blocking work onto a thread so the
        # timeout can actually fire.
        return await asyncio.wait_for(
            asyncio.to_thread(run_plagiarism), timeout=TASK_TIMEOUT
        )
    except asyncio.TimeoutError:
        logger.error("Plagiarism check timed out")
        return 0.0, False
    except Exception as e:
        logger.error(f"Plagiarism check error: {e}")
        return 0.0, False


@app.post("/detect")
async def detect_ai_and_plagiarism(file: UploadFile = File(...)):
|
    start_time = time.time()

    try:
        if not file.filename.lower().endswith(".pdf"):
            raise HTTPException(400, "Only PDF files allowed")

        pdf_data = await file.read()
        text = extract_text(pdf_data)

        # Run both checks concurrently; each enforces its own TASK_TIMEOUT.
        ai_score, (plag_score, plag_risk) = await asyncio.gather(
            predict_ai(text), check_plagiarism(text)
        )

        total_time = time.time() - start_time
        if total_time > TIMEOUT:
            logger.error("Processing exceeded timeout")
            raise HTTPException(500, "Processing timed out")

        logger.info(f"Processing completed in {total_time:.2f} seconds")
        return {
            "ai_generated_percentage": round(ai_score * 100, 2),
            "plagiarism_percentage": plag_score,
            "plagiarism_risk": plag_risk,
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error: {e}", exc_info=True)
        raise HTTPException(500, f"Processing failed: {e}")
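

# Example request (assumes the default uvicorn host/port and a local file
# named essay.pdf):
#   curl -X POST -F "file=@essay.pdf" http://localhost:8000/detect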