File size: 5,995 Bytes
bc07cfe b59d3a6 e9f0d54 b59d3a6 bc07cfe e9f0d54 b59d3a6 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe b59d3a6 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 b59d3a6 e9f0d54 b59d3a6 bc07cfe b59d3a6 bc07cfe e9f0d54 b59d3a6 bc07cfe b59d3a6 bc07cfe e9f0d54 bc07cfe e9f0d54 bc07cfe e9f0d54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import asyncio
import logging
import os
import secrets
from contextlib import asynccontextmanager
from io import BytesIO

import docx
import fitz  # PyMuPDF
import torch
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
# Load environment variables from a local .env file (if one exists) so that
# SECRET_TOKEN can be configured without exporting it in the shell.
from dotenv import load_dotenv

load_dotenv()

# Shared secret that clients must present as a bearer token; None when the
# SECRET_TOKEN variable is not set.
SECRET_TOKEN = os.getenv("SECRET_TOKEN")
bearer_scheme = HTTPBearer()

# Ai-Text-Detector model artifacts, resolved relative to the working directory.
MODEL_PATH = "./Ai-Text-Detector/model"
WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"

# The FastAPI app is created further down, after the lifespan handler exists,
# so that it can be constructed with `lifespan=` in a single place.

# Global model and tokenizer; populated by load_model() at startup.
model, tokenizer = None, None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging setup
logging.basicConfig(level=logging.DEBUG)
# Load model and tokenizer function
def load_model():
global model, tokenizer
try:
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
config = GPT2Config.from_pretrained(MODEL_PATH)
model_instance = GPT2LMHeadModel(config)
model_instance.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
model_instance.to(device)
model_instance.eval()
model, tokenizer = model_instance, tokenizer
logging.info("Model loaded successfully.")
except Exception as e:
logging.error(f"Error loading model: {str(e)}")
raise RuntimeError(f"Error loading model: {str(e)}")
# Load model on app startup
@asynccontextmanager
async def lifespan(app: FastAPI):
load_model() # Load model when FastAPI app starts
yield
# Attach the lifespan to the app instance
app = FastAPI(lifespan=lifespan)
# Input schema for text analysis
class TextInput(BaseModel):
text: str
# Function to classify text using the model
def classify_text(text: str):
if not model or not tokenizer:
raise RuntimeError("Model or tokenizer not loaded.")
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
with torch.no_grad():
outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
loss = outputs.loss
perplexity = torch.exp(loss).item()
if perplexity < 60:
return "AI-generated", perplexity
elif perplexity < 80:
return "Probably AI-generated", perplexity
else:
return "Human-written", perplexity
# POST route to analyze text with Bearer token
@app.post("/analyze")
async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
# Verify token
if token.credentials != SECRET_TOKEN:
raise HTTPException(status_code=401, detail="Invalid token")
text = data.text.strip()
# Input validation
if not text:
raise HTTPException(status_code=400, detail="Text cannot be empty")
if len(text.split()) < 2:
raise HTTPException(status_code=400, detail="Text must contain at least two words")
try:
# Classify text
label, perplexity = await asyncio.to_thread(classify_text, text)
return {"result": label, "perplexity": round(perplexity, 2)}
except Exception as e:
logging.error(f"Error processing text: {str(e)}")
raise HTTPException(status_code=500, detail="Model processing error")
# Function to parse .docx files
def parse_docx(file: BytesIO):
doc = docx.Document(file)
text = ""
for para in doc.paragraphs:
text += para.text + "\n"
return text
# Function to parse .pdf files
def parse_pdf(file: BytesIO):
try:
doc = fitz.open(stream=file, filetype="pdf")
text = ""
for page_num in range(doc.page_count):
page = doc.load_page(page_num)
text += page.get_text()
return text
except Exception as e:
logging.error(f"Error while processing PDF: {str(e)}")
raise HTTPException(status_code=500, detail="Error processing PDF file")
# Function to parse .txt files
def parse_txt(file: BytesIO):
return file.read().decode("utf-8")
# POST route to upload files and analyze content
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...), token: str = Depends(bearer_scheme)):
file_contents = None
try:
if file.content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document':
file_contents = parse_docx(BytesIO(await file.read()))
elif file.content_type == 'application/pdf':
file_contents = parse_pdf(BytesIO(await file.read()))
elif file.content_type == 'text/plain':
file_contents = parse_txt(BytesIO(await file.read()))
else:
raise HTTPException(status_code=400, detail="Invalid file type. Only .docx, .pdf, and .txt are allowed.")
logging.debug(f"Extracted Text from {file.filename}:\n{file_contents}")
# Check if the text length exceeds 10,000 characters
if len(file_contents) > 10000:
return {"message": "File contains more than 10,000 characters."}
# Clean the text by removing newline and tab characters
cleaned_text = file_contents.replace("\n", "").replace("\t", "")
# Analyze the cleaned text
label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
return {"result": label, "perplexity": round(perplexity, 2)}
except Exception as e:
logging.error(f"Error processing file: {str(e)}")
raise HTTPException(status_code=500, detail="Error processing the file")
# Health check route
@app.get("/health")
async def health_check():
return {"status": "ok"}
# Simple index route
@app.get("/")
def index():
return {
"message": "FastAPI AI Text Detector is running.",
"usage": "Use /docs or /analyze to test the API."
}
|