Essay-Grader committed on
Commit
109611c
·
1 Parent(s): c7b743d

Fix the api

Browse files
Files changed (1) hide show
  1. main.py +63 -56
main.py CHANGED
@@ -1,4 +1,5 @@
1
  # main.py: AI Detection and Plagiarism Check API
 
2
  from fastapi import FastAPI, UploadFile, File, HTTPException
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from transformers import pipeline
@@ -6,6 +7,7 @@ from sentence_transformers import SentenceTransformer, util
6
  import fitz
7
  import logging
8
  import torch
 
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
@@ -25,7 +27,8 @@ AI_MODEL = "Hello-SimpleAI/chatgpt-detector-roberta"
25
  PLAGIARISM_MODEL = "sentence-transformers/all-mpnet-base-v2"
26
  DEVICE = 0 if torch.cuda.is_available() else -1
27
  MAX_SEQ_LENGTH = 512
28
- CHUNK_SIZE = 500 # Characters per chunk
 
29
 
30
  # Initialize models
31
  ai_pipeline = None
@@ -35,110 +38,114 @@ plagiarism_model = None
35
  def initialize_models():
36
  global ai_pipeline, plagiarism_model
37
  try:
38
- # Configure AI pipeline with proper text handling
39
  ai_pipeline = pipeline(
40
  "text-classification",
41
  model=AI_MODEL,
42
  device=DEVICE,
43
- padding="max_length",
44
  truncation=True,
45
  max_length=MAX_SEQ_LENGTH
46
  )
47
- logger.info("AI model loaded successfully")
48
 
49
- # Configure plagiarism detector
50
  plagiarism_model = SentenceTransformer(PLAGIARISM_MODEL)
51
- logger.info("Plagiarism model loaded successfully")
52
 
53
  except Exception as e:
54
- logger.error(f"Model initialization failed: {str(e)}", exc_info=True)
55
- raise RuntimeError(f"Model loading failed: {str(e)}")
56
 
57
  def extract_text(pdf_bytes: bytes) -> str:
58
- """Extract and validate PDF text content"""
59
  try:
60
  with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
61
- text = " ".join(page.get_text() for page in doc).strip()
62
- if not text:
63
- raise ValueError("Empty PDF file")
64
  if len(text) < 100:
65
- raise ValueError("Text too short (min 100 characters)")
66
  return text
67
  except Exception as e:
68
- logger.error(f"PDF extraction error: {str(e)}")
69
- raise HTTPException(400, f"PDF processing failed: {str(e)}")
70
 
71
  def analyze_ai_content(text: str) -> float:
72
- """Analyze text for AI-generated content with chunking"""
73
  try:
74
- # Split text into manageable chunks
75
- chunks = [text[i:i+CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
76
- if not chunks:
77
- return 0.0
 
 
78
 
79
- ai_scores = []
 
 
 
80
  for chunk in chunks:
81
  result = ai_pipeline(chunk)
82
- score = next(
83
- (r['score'] for r in result if r['label'] in ['AI', 'Fake']),
84
- 0.0
85
- )
86
- ai_scores.append(score)
87
 
88
- return round((sum(ai_scores) / len(ai_scores)) * 100, 2)
89
 
90
  except Exception as e:
91
- logger.error(f"AI analysis failed: {str(e)}", exc_info=True)
92
- raise HTTPException(500, "AI analysis error")
93
 
94
  def analyze_plagiarism(text: str) -> float:
95
- """Check for potential plagiarism"""
96
  try:
97
- # Sample reference texts - replace with your database
98
  reference_texts = [
99
- "Academic integrity is fundamental to learning.",
100
- "Plagiarism undermines educational values.",
101
- "Original thought is essential for innovation."
 
 
102
  ]
103
 
104
- # Encode and compare
105
- doc_emb = plagiarism_model.encode(text, convert_to_tensor=True)
 
 
 
 
 
106
  ref_embs = plagiarism_model.encode(reference_texts, convert_to_tensor=True)
107
- similarities = util.cos_sim(doc_emb, ref_embs)[0]
108
 
109
- # Calculate similarity percentage
110
- match_count = sum(s > 0.75 for s in similarities)
111
- return round((match_count / len(reference_texts)) * 100, 2)
 
 
 
 
112
 
113
  except Exception as e:
114
- logger.error(f"Plagiarism check failed: {str(e)}", exc_info=True)
115
  return 0.0
116
 
117
  @app.post("/analyze")
118
  async def analyze_essay(file: UploadFile = File(...)):
119
- """Main analysis endpoint"""
120
  try:
121
- if not file.filename.lower().endswith(".pdf"):
122
- raise HTTPException(400, "Only PDF files accepted")
123
 
124
- # Process PDF
125
- pdf_bytes = await file.read()
126
- text = extract_text(pdf_bytes)
127
-
128
- # Perform analyses
129
- ai_score = analyze_ai_content(text)
130
- plagiarism_score = analyze_plagiarism(text)
131
 
132
  return {
133
- "ai_generated_percentage": ai_score,
134
- "plagiarism_risk": plagiarism_score
135
  }
136
 
137
- except HTTPException as he:
138
  raise
139
  except Exception as e:
140
- logger.error(f"Unexpected error: {str(e)}", exc_info=True)
141
- raise HTTPException(500, "Processing failed")
142
 
143
 
144
  # from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks
 
1
  # main.py: AI Detection and Plagiarism Check API
2
+
3
  from fastapi import FastAPI, UploadFile, File, HTTPException
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from transformers import pipeline
 
7
  import fitz
8
  import logging
9
  import torch
10
+ import numpy as np
11
 
12
  # Configure logging
13
  logging.basicConfig(level=logging.INFO)
 
27
  PLAGIARISM_MODEL = "sentence-transformers/all-mpnet-base-v2"
28
  DEVICE = 0 if torch.cuda.is_available() else -1
29
  MAX_SEQ_LENGTH = 512
30
+ CHUNK_SIZE = 400 # Reduced chunk size for token safety
31
+ SIMILARITY_THRESHOLD = 0.65 # Adjusted threshold
32
 
33
  # Initialize models
34
  ai_pipeline = None
 
38
  def initialize_models():
39
  global ai_pipeline, plagiarism_model
40
  try:
41
+ # Verify model labels
42
  ai_pipeline = pipeline(
43
  "text-classification",
44
  model=AI_MODEL,
45
  device=DEVICE,
46
+ padding=True,
47
  truncation=True,
48
  max_length=MAX_SEQ_LENGTH
49
  )
50
+ logger.info(f"AI model labels: {ai_pipeline.model.config.label2id}")
51
 
52
+ # Initialize plagiarism model
53
  plagiarism_model = SentenceTransformer(PLAGIARISM_MODEL)
54
+ logger.info("Models loaded successfully")
55
 
56
  except Exception as e:
57
+ logger.error(f"Initialization failed: {str(e)}", exc_info=True)
58
+ raise RuntimeError(f"Model loading error: {str(e)}")
59
 
60
  def extract_text(pdf_bytes: bytes) -> str:
61
+ """Improved PDF text extraction"""
62
  try:
63
  with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
64
+ text = "\n".join([page.get_text() for page in doc]).strip()
 
 
65
  if len(text) < 100:
66
+ raise ValueError("Minimum 100 characters required")
67
  return text
68
  except Exception as e:
69
+ logger.error(f"PDF Error: {str(e)}")
70
+ raise HTTPException(400, "Invalid PDF content")
71
 
72
  def analyze_ai_content(text: str) -> float:
73
+ """Robust AI detection with label verification"""
74
  try:
75
+ # Verify model labels
76
+ label_mapping = ai_pipeline.model.config.label2id
77
+ ai_labels = [k for k in label_mapping if k.lower() in ['ai', 'fake']]
78
+
79
+ if not ai_labels:
80
+ raise ValueError("No valid AI labels found in model")
81
 
82
+ # Process in token-aware chunks
83
+ chunks = [text[i:i+CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE)]
84
+ scores = []
85
+
86
  for chunk in chunks:
87
  result = ai_pipeline(chunk)
88
+ for item in result:
89
+ if item['label'] in ai_labels:
90
+ scores.append(item['score'])
 
 
91
 
92
+ return round((sum(scores)/len(scores)) * 100, 2) if scores else 0.0
93
 
94
  except Exception as e:
95
+ logger.error(f"AI Analysis Error: {str(e)}")
96
+ raise HTTPException(500, "AI analysis failed")
97
 
98
  def analyze_plagiarism(text: str) -> float:
99
+ """Enhanced plagiarism detection"""
100
  try:
101
+ # Use meaningful reference texts
102
  reference_texts = [
103
+ "The importance of academic integrity cannot be overstated.",
104
+ "Plagiarism detection systems help maintain educational standards.",
105
+ "Original work demonstrates true learning and understanding.",
106
+ "Proper citation is essential for avoiding plagiarism.",
107
+ "Educational institutions take academic honesty very seriously."
108
  ]
109
 
110
+ # Sentence-level comparison
111
+ sentences = [s.strip() for s in text.split('.') if len(s.strip()) > 20]
112
+ if not sentences:
113
+ return 0.0
114
+
115
+ # Batch processing
116
+ sentence_embs = plagiarism_model.encode(sentences, convert_to_tensor=True)
117
  ref_embs = plagiarism_model.encode(reference_texts, convert_to_tensor=True)
 
118
 
119
+ # Calculate similarities
120
+ similarities = util.cos_sim(sentence_embs, ref_embs)
121
+ max_similarities = np.max(similarities.cpu().numpy(), axis=1)
122
+
123
+ # Calculate percentage above threshold
124
+ match_count = sum(s > SIMILARITY_THRESHOLD for s in max_similarities)
125
+ return round((match_count / len(sentences)) * 100, 2)
126
 
127
  except Exception as e:
128
+ logger.error(f"Plagiarism Error: {str(e)}")
129
  return 0.0
130
 
131
  @app.post("/analyze")
132
  async def analyze_essay(file: UploadFile = File(...)):
 
133
  try:
134
+ if not file.filename.lower().endswith('.pdf'):
135
+ raise HTTPException(400, "PDF files only")
136
 
137
+ text = extract_text(await file.read())
 
 
 
 
 
 
138
 
139
  return {
140
+ "ai_generated_percentage": analyze_ai_content(text),
141
+ "plagiarism_risk": analyze_plagiarism(text)
142
  }
143
 
144
+ except HTTPException:
145
  raise
146
  except Exception as e:
147
+ logger.error(f"Critical Error: {str(e)}")
148
+ raise HTTPException(500, "Analysis failed")
149
 
150
 
151
  # from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks