COM ADMIN committed · commit e4267ec · parent ccf4e3f
Try another model for accuracy

main.py CHANGED
@@ -16,7 +16,7 @@ os.environ["NLTK_DATA"] = "/tmp/.cache/nltk"
 Path("/tmp/.cache/huggingface").mkdir(parents=True, exist_ok=True)
 Path("/tmp/.cache/nltk").mkdir(parents=True, exist_ok=True)
 
-#
+# Imports
 from fastapi import FastAPI, UploadFile, File, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModel
@@ -27,11 +27,11 @@ import nltk
 from nltk.tokenize import sent_tokenize
 from sklearn.metrics.pairwise import cosine_similarity
 
-#
+# Logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-#
+# NLTK setup
 try:
     nltk.data.path.append("/tmp/.cache/nltk")
     nltk.data.find('tokenizers/punkt')
@@ -40,8 +40,8 @@ except LookupError:
     nltk.download('punkt', download_dir="/tmp/.cache/nltk")
     nltk.data.path.append("/tmp/.cache/nltk")
 
+# FastAPI init
 app = FastAPI()
-
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -49,104 +49,81 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-#
-MODEL_NAME = "Essay-Grader/roberta-ai-detector-20250401_232702"
-EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+# Configs
+MODEL_NAME = "Essay-Grader/roberta-ai-detector-20250401_232702"
+EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
 DEVICE = 0 if torch.cuda.is_available() else -1
-MAX_TEXT_LENGTH = 6000  # Optimal balance between accuracy and speed
 PLAGIARISM_THRESHOLD = 0.75
-TIMEOUT = 25
-AI_CHUNK_SIZE = 718
-
-# Health check endpoint
-@app.get("/health")
-def health_check():
-    return {"status": "healthy"}
+TIMEOUT = 25
+AI_CHUNK_SIZE = 718
 
 # Load models
 try:
-    logger.info("Loading fine-tuned AI detection model...")
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     ai_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE if DEVICE != -1 else "cpu")
     ai_model.eval()
-
-    logger.info("Loading embedding model...")
+
     embed_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
     embed_model = AutoModel.from_pretrained(EMBEDDING_MODEL).to(DEVICE if DEVICE != -1 else "cpu")
     embed_model.eval()
-
-    logger.info("All models loaded successfully")
+    logger.info("Models loaded")
 except Exception as e:
     logger.error(f"Model loading failed: {str(e)}", exc_info=True)
     raise RuntimeError(f"Failed to initialize models: {str(e)}")
 
 def extract_text(pdf_bytes: bytes) -> str:
-    """Efficient text extraction with length control"""
     try:
         with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
             text = []
             for page in doc:
-                if len('\n'.join(text)) > MAX_TEXT_LENGTH:
-                    break
                 text.append(page.get_text().strip())
-            full_text = re.sub(r'\s+', ' ', '\n'.join(text))
-            ...
+            full_text = re.sub(r'\s+', ' ', '\n'.join(text))
+
+        # Cut off after "References" or similar
+        match = re.search(r'(references|bibliography|works cited)', full_text, re.IGNORECASE)
+        if match:
+            full_text = full_text[:match.start()]
+
+        if len(full_text) < 150:
             raise ValueError("Text too short")
+
         return full_text
     except Exception as e:
         logger.error(f"PDF processing error: {str(e)}")
         raise HTTPException(400, "Invalid PDF content")
 
 def predict_ai(text: str) -> float:
-    ...
-        truncation=True,
-    ...
-    probs = torch.softmax(outputs.logits, dim=1)
-    return float(probs[0][1])  # Assuming label '1' is AI-generated
+    chunks = [text[i:i+AI_CHUNK_SIZE] for i in range(0, len(text), AI_CHUNK_SIZE)]
+    total_score = 0.0
+    for chunk in chunks:
+        inputs = tokenizer(chunk, truncation=True, max_length=AI_CHUNK_SIZE, return_tensors="pt").to(ai_model.device)
+        with torch.no_grad():
+            outputs = ai_model(**inputs)
+        probs = torch.softmax(outputs.logits, dim=1)
+        total_score += float(probs[0][1])
+    avg_score = total_score / len(chunks)
+    return avg_score
 
 def compute_embeddings(sentences: List[str]) -> np.ndarray:
-
-    inputs = embed_tokenizer(
-        sentences,
-        padding=True,
-        truncation=True,
-        max_length=128,
-        return_tensors="pt"
-    ).to(embed_model.device)
-
+    inputs = embed_tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors="pt").to(embed_model.device)
     with torch.no_grad():
         outputs = embed_model(**inputs)
-
-    # Mean pooling
     attention_mask = inputs['attention_mask']
     last_hidden = outputs.last_hidden_state
     return (last_hidden * attention_mask.unsqueeze(-1)).sum(1) / attention_mask.sum(1, keepdim=True)
 
 def check_plagiarism(text: str) -> Tuple[float, bool]:
-    """Optimized plagiarism check"""
     try:
-        sentences = [s for s in sent_tokenize(text) if 5 < len(s.split()) < 100][:40]
+        sentences = [s for s in sent_tokenize(text) if 5 < len(s.split()) < 100][:40]
         if len(sentences) < 2:
             return 0.0, False
-
         embeddings = compute_embeddings(sentences).cpu().numpy()
         sim_matrix = cosine_similarity(embeddings)
         np.fill_diagonal(sim_matrix, 0)
-
-        # Check top 10% most similar pairs
         n = len(sim_matrix)
         top_k = max(1, int(0.1 * n * (n - 1) / 2))
        top_indices = np.argpartition(sim_matrix.flatten(), -top_k)[-top_k:]
         avg_similarity = float(np.mean(sim_matrix.flatten()[top_indices]))
-
         return round(avg_similarity * 100, 2), avg_similarity > PLAGIARISM_THRESHOLD
     except Exception as e:
         logger.error(f"Plagiarism check error: {str(e)}")
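The substantive change in this hunk is predict_ai: instead of classifying a single slice of the essay, it now scores every AI_CHUNK_SIZE-character chunk and averages the per-chunk softmax probabilities. A minimal standalone sketch of that chunk-and-average pattern, with a stub scorer standing in for the RoBERTa classifier (average_chunk_score and the stub are illustrative names, not part of the commit):

AI_CHUNK_SIZE = 718  # value from the diff

def average_chunk_score(text: str, score_fn) -> float:
    # Split into fixed-size character chunks and average the scores,
    # mirroring the loop in the updated predict_ai.
    chunks = [text[i:i + AI_CHUNK_SIZE] for i in range(0, len(text), AI_CHUNK_SIZE)]
    return sum(score_fn(c) for c in chunks) / len(chunks)

# Stub scorer for illustration only: longer chunks score higher.
print(average_chunk_score("x" * 1500, lambda c: len(c) / AI_CHUNK_SIZE))
# Two full chunks (1.0 each) plus a 64-char tail (~0.09) -> ~0.70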
@@ -154,37 +131,25 @@ def check_plagiarism(text: str) -> Tuple[float, bool]:
 
 @app.post("/detect")
 async def detect_ai_content(file: UploadFile = File(...)):
-    """Optimized endpoint using your fine-tuned model"""
     start_time = time.time()
-
     try:
-        # Fast validation
         if not file.filename.lower().endswith('.pdf'):
             raise HTTPException(400, "Only PDF files are accepted")
 
-        # Extract text
         text = extract_text(await file.read())
-        logger.info(f"...")
+        logger.info(f"Text length: {len(text)}")
 
-
-        ai_score = predict_ai(text[:AI_CHUNK_SIZE])
+        ai_score = predict_ai(text)
         ai_percentage = round(ai_score * 100, 2)
 
-
-
-        if time.time() - start_time < TIMEOUT - 5:  # 5 second buffer
-            plagiarism_score, plagiarism_risk = check_plagiarism(text)
+        if time.time() - start_time < TIMEOUT - 5:
+            check_plagiarism(text)  # Run, but don't return
 
-        # Final timeout check
         if time.time() - start_time > TIMEOUT:
             raise HTTPException(500, "Processing timed out")
 
         return {
-            "ai_generated_percentage": ai_percentage,
-            "plagiarism_risk": plagiarism_risk,
-            "plagiarism_score": plagiarism_score,
-            "processing_time": round(time.time() - start_time, 2),
-            "model_used": MODEL_NAME  # Show which model was used
+            "ai_generated_percentage": ai_percentage
         }
 
     except HTTPException:
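With this commit the /detect response is trimmed to a single field. A minimal client sketch for the updated endpoint; the host, port, and essay.pdf filename are assumptions, not part of the commit:

import requests

with open("essay.pdf", "rb") as f:  # hypothetical input PDF
    resp = requests.post(
        "http://localhost:8000/detect",  # assumed local deployment
        files={"file": ("essay.pdf", f, "application/pdf")},
    )
resp.raise_for_status()
print(resp.json())  # e.g. {"ai_generated_percentage": 42.13}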
@@ -196,7 +161,6 @@ async def detect_ai_content(file: UploadFile = File(...)):
 
 
 
-
 # # main.py: Optimized AI Detection and Plagiarism Check API
 
 # import os