Commit 164dd9f (parent: 1c12f42)

Fixed the main.py

main.py CHANGED
@@ -1,3 +1,5 @@
+# main.py: API for Detection and Plagiarism Check
+
 import os
 import re
 import time
@@ -5,7 +7,7 @@ import logging
 from pathlib import Path
 from typing import List, Tuple
 
-from fastapi import FastAPI, UploadFile, File, HTTPException
+from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
 from fastapi.middleware.cors import CORSMiddleware
 
 import fitz  # PyMuPDF
@@ -51,67 +53,79 @@ app.add_middleware(
 )
 
 # Model configs
-MODEL_NAME = "Essay-Grader/roberta-ai-detector-20250401_232702"
-EMBEDDING_MODEL = "sentence-transformers/
+MODEL_NAME = "Essay-Grader/roberta-ai-detector-20250401_232702"
+EMBEDDING_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2"
 DEVICE = 0 if torch.cuda.is_available() else -1
-MAX_TEXT_LENGTH = 10000
+MAX_TEXT_LENGTH = 10000
 AI_CHUNK_SIZE = 512
 PLAGIARISM_THRESHOLD = 0.75
-TIMEOUT =
+TIMEOUT = 30
+MAX_SENTENCES = 20
 
 # Load models
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-ai_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(
-    DEVICE if DEVICE != -1 else "cpu"
-)
-ai_model.eval()
-
-embed_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
-embed_model = AutoModel.from_pretrained(EMBEDDING_MODEL).to(
-    DEVICE if DEVICE != -1 else "cpu"
-)
-embed_model.eval()
-
-# Health check
-# @app.get("/health")
-# def health_check():
-#     return {"status": "healthy"}
+try:
+    logger.info("Loading models...")
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    ai_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(
+        DEVICE if DEVICE != -1 else "cpu"
+    )
+    ai_model.eval()
 
+    embed_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
+    embed_model = AutoModel.from_pretrained(EMBEDDING_MODEL).to(
+        DEVICE if DEVICE != -1 else "cpu"
+    )
+    embed_model.eval()
+    logger.info("Models loaded successfully")
+except Exception as e:
+    logger.error(f"Model loading failed: {str(e)}", exc_info=True)
+    raise RuntimeError(f"Failed to initialize models: {str(e)}")
 
 def extract_text(pdf_bytes: bytes) -> str:
     try:
+        start_time = time.time()
         with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
             text = []
             for page in doc:
+                if time.time() - start_time > TIMEOUT / 3:  # Early timeout for extraction
+                    raise TimeoutError("PDF extraction timed out")
                 page_text = page.get_text().strip()
                 if "reference" in page_text.lower():
-                    break
+                    break
                 text.append(page_text)
 
         full_text = re.sub(r"\s+", " ", "\n".join(text))[:MAX_TEXT_LENGTH]
         if len(full_text) < 150:
             raise ValueError("Text too short")
+        logger.info(f"Extracted text: {len(full_text)} characters")
         return full_text
     except Exception as e:
         logger.error(f"PDF error: {str(e)}")
         raise HTTPException(400, "Invalid PDF")
 
+async def predict_ai(text: str) -> float:
+    try:
+        async with asyncio.timeout(TIMEOUT / 2):  # Per-task timeout
+            inputs = tokenizer(
+                text,
+                truncation=True,
+                max_length=AI_CHUNK_SIZE,
+                return_tensors="pt",
+            ).to(ai_model.device)
+
+            with torch.no_grad():
+                outputs = ai_model(**inputs)
+                probs = torch.softmax(outputs.logits, dim=1)
+            logger.info("AI detection completed")
+            return float(probs[0][1])  # AI-generated probability
+    except asyncio.TimeoutError:
+        logger.error("AI detection timed out")
+        raise HTTPException(500, "AI detection timed out")
+    except Exception as e:
+        logger.error(f"AI detection error: {str(e)}")
+        raise HTTPException(500, f"AI detection failed: {str(e)}")
 
-def predict_ai(text: str) -> float:
-    inputs = tokenizer(
-        text,
-        truncation=True,
-        max_length=AI_CHUNK_SIZE,
-        return_tensors="pt",
-    ).to(ai_model.device)
-
-    with torch.no_grad():
-        outputs = ai_model(**inputs)
-        probs = torch.softmax(outputs.logits, dim=1)
-        return float(probs[0][1])  # AI-generated probability
-
-
-def compute_embeddings(sentences: List[str]) -> np.ndarray:
+async def compute_embeddings(sentences: List[str]) -> np.ndarray:
     inputs = embed_tokenizer(
         sentences,
         padding=True,
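A note on the asyncio.timeout() guard in predict_ai() above: that context manager was added in Python 3.11, and it can only cancel a coroutine at an await point, so the synchronous tokenizer and forward-pass calls inside it are never actually pre-empted. A minimal sketch of one alternative for older interpreters, reusing the module's tokenizer, ai_model, AI_CHUNK_SIZE, and TIMEOUT globals (the name predict_ai_compat and the to_thread/wait_for combination are illustrative, not part of this commit):

import asyncio

import torch


async def predict_ai_compat(text: str) -> float:
    # Hypothetical variant of predict_ai for Python < 3.11; assumes the
    # tokenizer and ai_model globals from this module are already loaded.
    def _infer() -> float:
        inputs = tokenizer(
            text,
            truncation=True,
            max_length=AI_CHUNK_SIZE,
            return_tensors="pt",
        ).to(ai_model.device)
        with torch.no_grad():
            outputs = ai_model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
        return float(probs[0][1])  # AI-generated probability

    # to_thread() keeps the blocking forward pass off the event loop;
    # wait_for() raises asyncio.TimeoutError once the deadline passes.
    return await asyncio.wait_for(asyncio.to_thread(_infer), timeout=TIMEOUT / 2)

Here wait_for() surfaces the timeout to the caller, although the worker thread itself still runs to completion in the background.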
@@ -127,34 +141,38 @@ def compute_embeddings(sentences: List[str]) -> np.ndarray:
     last_hidden = outputs.last_hidden_state
     return (last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(
         1, keepdim=True
-    )
-
+    ).cpu().numpy()
 
-def check_plagiarism(text: str) -> Tuple[float, bool]:
+async def check_plagiarism(text: str) -> Tuple[float, bool]:
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        async with asyncio.timeout(TIMEOUT / 2):  # Per-task timeout
+            sentences = [
+                s for s in sent_tokenize(text) if 5 < len(s.split()) < 100
+            ][:MAX_SENTENCES]
+            if len(sentences) < 2:
+                logger.info("Not enough sentences for plagiarism check")
+                return 0.0, False
+
+            embeddings = await compute_embeddings(sentences)
+            sim_matrix = cosine_similarity(embeddings)
+            np.fill_diagonal(sim_matrix, 0)
+
+            n = len(sim_matrix)
+            top_k = max(1, int(0.1 * n * (n - 1) / 2))
+            top_indices = np.argpartition(sim_matrix.flatten(), -top_k)[-top_k:]
+            avg_similarity = float(np.mean(sim_matrix.flatten()[top_indices]))
+
+            logger.info("Plagiarism check completed")
+            return round(avg_similarity * 100, 2), avg_similarity > PLAGIARISM_THRESHOLD
+    except asyncio.TimeoutError:
+        logger.error("Plagiarism check timed out")
+        return 0.0, False
     except Exception as e:
         logger.error(f"Plagiarism check error: {str(e)}")
         return 0.0, False
 
-
 @app.post("/detect")
-async def detect_ai_and_plagiarism(file: UploadFile = File(...)):
+async def detect_ai_and_plagiarism(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
     start_time = time.time()
 
     try:
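The rebuilt check_plagiarism() scores a document by the mean cosine similarity of its most similar sentence pairs; sent_tokenize presumably comes from NLTK, which needs its punkt data available at runtime. One detail worth noting: sim_matrix is symmetric, so the flattened matrix holds every pair twice while top_k counts unique pairs, meaning the "top 10%" effectively averages roughly the top 5% of distinct pairs. A self-contained toy run of the same top-k arithmetic, with made-up embeddings:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Three toy sentence embeddings (rows); the values are illustrative only.
emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])

sim = cosine_similarity(emb)  # symmetric 3x3 similarity matrix
np.fill_diagonal(sim, 0)      # drop self-similarity

n = len(sim)
top_k = max(1, int(0.1 * n * (n - 1) / 2))            # ~10% of unique pairs
top_idx = np.argpartition(sim.flatten(), -top_k)[-top_k:]
avg = float(np.mean(sim.flatten()[top_idx]))

print(round(avg * 100, 2), avg > 0.75)  # -> 99.39 True for these vectors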
@@ -164,14 +182,8 @@ async def detect_ai_and_plagiarism(file: UploadFile = File(...)):
         pdf_data = await file.read()
         text = extract_text(pdf_data)
 
-        async def run_ai():
-            return predict_ai(text)
-
-        async def run_plagiarism():
-            return check_plagiarism(text)
-
-        ai_future = asyncio.create_task(run_ai())
-        plagiarism_future = asyncio.create_task(run_plagiarism())
+        ai_future = asyncio.create_task(predict_ai(text))
+        plagiarism_future = asyncio.create_task(check_plagiarism(text))
 
         ai_score, (plag_score, plag_risk) = await asyncio.gather(
             ai_future, plagiarism_future
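This hunk replaces the run_ai()/run_plagiarism() wrappers with tasks created directly from predict_ai() and check_plagiarism(), which works because both are now coroutine functions. Since asyncio.gather() wraps bare coroutines in tasks on its own, an equivalent form could skip create_task() entirely; a sketch (run_checks is a hypothetical name, the two coroutines are the ones defined above):

import asyncio

async def run_checks(text: str):
    # gather() schedules both coroutines concurrently and preserves
    # argument order, so no explicit create_task() is needed here.
    ai_score, (plag_score, plag_risk) = await asyncio.gather(
        predict_ai(text), check_plagiarism(text)
    )
    return ai_score, plag_score, plag_risk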
@@ -179,19 +191,23 @@ async def detect_ai_and_plagiarism(file: UploadFile = File(...)):
 
         total_time = time.time() - start_time
         if total_time > TIMEOUT:
+            logger.error("Processing exceeded timeout")
             raise HTTPException(500, "Processing timed out")
 
+        logger.info(f"Processing completed in {total_time:.2f} seconds")
         return {
             "ai_generated_percentage": round(ai_score * 100, 2),
             "plagiarism_percentage": plag_score,
-
+            "plagiarism_risk": plag_risk,
+            "processing_time": round(total_time, 2),
         }
 
+    except HTTPException as he:
+        raise
     except Exception as e:
         logger.error(f"Error: {str(e)}", exc_info=True)
         raise HTTPException(500, f"Processing failed: {str(e)}")
-
-
+
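With these changes, /detect accepts a PDF upload and returns the enlarged response shape from the last hunk. A hypothetical client call with httpx, assuming the service is running at localhost:8000 and essay.pdf is any local PDF:

import httpx

with open("essay.pdf", "rb") as f:
    resp = httpx.post(
        "http://localhost:8000/detect",
        files={"file": ("essay.pdf", f, "application/pdf")},
        timeout=60.0,  # client-side budget; the server enforces its own TIMEOUT
    )
resp.raise_for_status()
print(resp.json())
# Illustrative shape: {"ai_generated_percentage": 12.34,
#   "plagiarism_percentage": 8.9, "plagiarism_risk": False,
#   "processing_time": 3.21}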