Essay-Grader committed on
Commit
164dd9f
·
1 Parent(s): 1c12f42

Fixed main.py

Browse files
Files changed (1) hide show
  1. main.py +85 -69
main.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import re
3
  import time
@@ -5,7 +7,7 @@ import logging
5
  from pathlib import Path
6
  from typing import List, Tuple
7
 
8
- from fastapi import FastAPI, UploadFile, File, HTTPException
9
  from fastapi.middleware.cors import CORSMiddleware
10
 
11
  import fitz # PyMuPDF
@@ -51,67 +53,79 @@ app.add_middleware(
51
  )
52
 
53
  # Model configs
54
- MODEL_NAME = "Essay-Grader/roberta-ai-detector-20250401_232702"
55
- EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
56
  DEVICE = 0 if torch.cuda.is_available() else -1
57
- MAX_TEXT_LENGTH = 10000
58
  AI_CHUNK_SIZE = 512
59
  PLAGIARISM_THRESHOLD = 0.75
60
- TIMEOUT = 25 # total timeout buffer
 
61
 
62
  # Load models
63
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
64
- ai_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(
65
- DEVICE if DEVICE != -1 else "cpu"
66
- )
67
- ai_model.eval()
68
-
69
- embed_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
70
- embed_model = AutoModel.from_pretrained(EMBEDDING_MODEL).to(
71
- DEVICE if DEVICE != -1 else "cpu"
72
- )
73
- embed_model.eval()
74
-
75
- # Health check
76
- # @app.get("/health")
77
- # def health_check():
78
- # return {"status": "healthy"}
79
 
 
 
 
 
 
 
 
 
 
80
 
81
  def extract_text(pdf_bytes: bytes) -> str:
82
  try:
 
83
  with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
84
  text = []
85
  for page in doc:
 
 
86
  page_text = page.get_text().strip()
87
  if "reference" in page_text.lower():
88
- break # Exclude reference section
89
  text.append(page_text)
90
 
91
  full_text = re.sub(r"\s+", " ", "\n".join(text))[:MAX_TEXT_LENGTH]
92
  if len(full_text) < 150:
93
  raise ValueError("Text too short")
 
94
  return full_text
95
  except Exception as e:
96
  logger.error(f"PDF error: {str(e)}")
97
  raise HTTPException(400, "Invalid PDF")
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- def predict_ai(text: str) -> float:
101
- inputs = tokenizer(
102
- text,
103
- truncation=True,
104
- max_length=AI_CHUNK_SIZE,
105
- return_tensors="pt",
106
- ).to(ai_model.device)
107
-
108
- with torch.no_grad():
109
- outputs = ai_model(**inputs)
110
- probs = torch.softmax(outputs.logits, dim=1)
111
- return float(probs[0][1]) # AI-generated probability
112
-
113
-
114
- def compute_embeddings(sentences: List[str]) -> np.ndarray:
115
  inputs = embed_tokenizer(
116
  sentences,
117
  padding=True,
@@ -127,34 +141,38 @@ def compute_embeddings(sentences: List[str]) -> np.ndarray:
127
  last_hidden = outputs.last_hidden_state
128
  return (last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(
129
  1, keepdim=True
130
- )
131
-
132
 
133
- def check_plagiarism(text: str) -> Tuple[float, bool]:
134
  try:
135
- sentences = [
136
- s for s in sent_tokenize(text) if 5 < len(s.split()) < 100
137
- ][:40] # limit
138
- if len(sentences) < 2:
139
- return 0.0, False
140
-
141
- embeddings = compute_embeddings(sentences).cpu().numpy()
142
- sim_matrix = cosine_similarity(embeddings)
143
- np.fill_diagonal(sim_matrix, 0)
144
-
145
- n = len(sim_matrix)
146
- top_k = max(1, int(0.1 * n * (n - 1) / 2))
147
- top_indices = np.argpartition(sim_matrix.flatten(), -top_k)[-top_k:]
148
- avg_similarity = float(np.mean(sim_matrix.flatten()[top_indices]))
149
-
150
- return round(avg_similarity * 100, 2), avg_similarity > PLAGIARISM_THRESHOLD
 
 
 
 
 
 
151
  except Exception as e:
152
  logger.error(f"Plagiarism check error: {str(e)}")
153
  return 0.0, False
154
 
155
-
156
  @app.post("/detect")
157
- async def detect_ai_and_plagiarism(file: UploadFile = File(...)):
158
  start_time = time.time()
159
 
160
  try:
@@ -164,14 +182,8 @@ async def detect_ai_and_plagiarism(file: UploadFile = File(...)):
164
  pdf_data = await file.read()
165
  text = extract_text(pdf_data)
166
 
167
- async def run_ai():
168
- return predict_ai(text)
169
-
170
- async def run_plagiarism():
171
- return check_plagiarism(text)
172
-
173
- ai_future = asyncio.create_task(run_ai())
174
- plagiarism_future = asyncio.create_task(run_plagiarism())
175
 
176
  ai_score, (plag_score, plag_risk) = await asyncio.gather(
177
  ai_future, plagiarism_future
@@ -179,19 +191,23 @@ async def detect_ai_and_plagiarism(file: UploadFile = File(...)):
179
 
180
  total_time = time.time() - start_time
181
  if total_time > TIMEOUT:
 
182
  raise HTTPException(500, "Processing timed out")
183
 
 
184
  return {
185
  "ai_generated_percentage": round(ai_score * 100, 2),
186
  "plagiarism_percentage": plag_score,
187
- # "plagiarism_risk": plag_risk
 
188
  }
189
 
 
 
190
  except Exception as e:
191
  logger.error(f"Error: {str(e)}", exc_info=True)
192
  raise HTTPException(500, f"Processing failed: {str(e)}")
193
-
194
-
195
 
196
 
197
 
 
1
+ # main.py: API for Detection and Plagiarism Check
2
+
3
  import os
4
  import re
5
  import time
 
7
  from pathlib import Path
8
  from typing import List, Tuple
9
 
10
+ from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
11
  from fastapi.middleware.cors import CORSMiddleware
12
 
13
  import fitz # PyMuPDF
 
53
  )
54
 
55
  # Model configs
56
+ MODEL_NAME = "Essay-Grader/roberta-ai-detector-20250401_232702"
57
+ EMBEDDING_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2"
58
  DEVICE = 0 if torch.cuda.is_available() else -1
59
+ MAX_TEXT_LENGTH = 10000
60
  AI_CHUNK_SIZE = 512
61
  PLAGIARISM_THRESHOLD = 0.75
62
+ TIMEOUT = 30
63
+ MAX_SENTENCES = 20
64
 
65
  # Load models
66
+ try:
67
+ logger.info("Loading models...")
68
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
69
+ ai_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(
70
+ DEVICE if DEVICE != -1 else "cpu"
71
+ )
72
+ ai_model.eval()
 
 
 
 
 
 
 
 
 
73
 
74
+ embed_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
75
+ embed_model = AutoModel.from_pretrained(EMBEDDING_MODEL).to(
76
+ DEVICE if DEVICE != -1 else "cpu"
77
+ )
78
+ embed_model.eval()
79
+ logger.info("Models loaded successfully")
80
+ except Exception as e:
81
+ logger.error(f"Model loading failed: {str(e)}", exc_info=True)
82
+ raise RuntimeError(f"Failed to initialize models: {str(e)}")
83
 
84
  def extract_text(pdf_bytes: bytes) -> str:
85
  try:
86
+ start_time = time.time()
87
  with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
88
  text = []
89
  for page in doc:
90
+ if time.time() - start_time > TIMEOUT / 3: # Early timeout for extraction
91
+ raise TimeoutError("PDF extraction timed out")
92
  page_text = page.get_text().strip()
93
  if "reference" in page_text.lower():
94
+ break
95
  text.append(page_text)
96
 
97
  full_text = re.sub(r"\s+", " ", "\n".join(text))[:MAX_TEXT_LENGTH]
98
  if len(full_text) < 150:
99
  raise ValueError("Text too short")
100
+ logger.info(f"Extracted text: {len(full_text)} characters")
101
  return full_text
102
  except Exception as e:
103
  logger.error(f"PDF error: {str(e)}")
104
  raise HTTPException(400, "Invalid PDF")
105
 
106
+ async def predict_ai(text: str) -> float:
107
+ try:
108
+ async with asyncio.timeout(TIMEOUT / 2): # Per-task timeout
109
+ inputs = tokenizer(
110
+ text,
111
+ truncation=True,
112
+ max_length=AI_CHUNK_SIZE,
113
+ return_tensors="pt",
114
+ ).to(ai_model.device)
115
+
116
+ with torch.no_grad():
117
+ outputs = ai_model(**inputs)
118
+ probs = torch.softmax(outputs.logits, dim=1)
119
+ logger.info("AI detection completed")
120
+ return float(probs[0][1]) # AI-generated probability
121
+ except asyncio.TimeoutError:
122
+ logger.error("AI detection timed out")
123
+ raise HTTPException(500, "AI detection timed out")
124
+ except Exception as e:
125
+ logger.error(f"AI detection error: {str(e)}")
126
+ raise HTTPException(500, f"AI detection failed: {str(e)}")
127
 
128
+ async def compute_embeddings(sentences: List[str]) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  inputs = embed_tokenizer(
130
  sentences,
131
  padding=True,
 
141
  last_hidden = outputs.last_hidden_state
142
  return (last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(
143
  1, keepdim=True
144
+ ).cpu().numpy()
 
145
 
146
+ async def check_plagiarism(text: str) -> Tuple[float, bool]:
147
  try:
148
+ async with asyncio.timeout(TIMEOUT / 2): # Per-task timeout
149
+ sentences = [
150
+ s for s in sent_tokenize(text) if 5 < len(s.split()) < 100
151
+ ][:MAX_SENTENCES]
152
+ if len(sentences) < 2:
153
+ logger.info("Not enough sentences for plagiarism check")
154
+ return 0.0, False
155
+
156
+ embeddings = await compute_embeddings(sentences)
157
+ sim_matrix = cosine_similarity(embeddings)
158
+ np.fill_diagonal(sim_matrix, 0)
159
+
160
+ n = len(sim_matrix)
161
+ top_k = max(1, int(0.1 * n * (n - 1) / 2))
162
+ top_indices = np.argpartition(sim_matrix.flatten(), -top_k)[-top_k:]
163
+ avg_similarity = float(np.mean(sim_matrix.flatten()[top_indices]))
164
+
165
+ logger.info("Plagiarism check completed")
166
+ return round(avg_similarity * 100, 2), avg_similarity > PLAGIARISM_THRESHOLD
167
+ except asyncio.TimeoutError:
168
+ logger.error("Plagiarism check timed out")
169
+ return 0.0, False
170
  except Exception as e:
171
  logger.error(f"Plagiarism check error: {str(e)}")
172
  return 0.0, False
173
 
 
174
  @app.post("/detect")
175
+ async def detect_ai_and_plagiarism(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
176
  start_time = time.time()
177
 
178
  try:
 
182
  pdf_data = await file.read()
183
  text = extract_text(pdf_data)
184
 
185
+ ai_future = asyncio.create_task(predict_ai(text))
186
+ plagiarism_future = asyncio.create_task(check_plagiarism(text))
 
 
 
 
 
 
187
 
188
  ai_score, (plag_score, plag_risk) = await asyncio.gather(
189
  ai_future, plagiarism_future
 
191
 
192
  total_time = time.time() - start_time
193
  if total_time > TIMEOUT:
194
+ logger.error("Processing exceeded timeout")
195
  raise HTTPException(500, "Processing timed out")
196
 
197
+ logger.info(f"Processing completed in {total_time:.2f} seconds")
198
  return {
199
  "ai_generated_percentage": round(ai_score * 100, 2),
200
  "plagiarism_percentage": plag_score,
201
+ "plagiarism_risk": plag_risk,
202
+ "processing_time": round(total_time, 2),
203
  }
204
 
205
+ except HTTPException as he:
206
+ raise
207
  except Exception as e:
208
  logger.error(f"Error: {str(e)}", exc_info=True)
209
  raise HTTPException(500, f"Processing failed: {str(e)}")
210
+
 
211
 
212
 
213