|
import asyncio
import logging
import os
import secrets
from contextlib import asynccontextmanager
from io import BytesIO

import docx
import fitz
import torch
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.security import HTTPBearer
from pydantic import BaseModel
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
|
# Load variables from a local .env file (if present) before reading settings.
load_dotenv()

# Shared secret that clients must send as a Bearer token. os.getenv returns
# None when the variable is not set.
SECRET_TOKEN = os.getenv("SECRET_TOKEN")
bearer_scheme = HTTPBearer()

# Tokenizer/config directory and fine-tuned weights file, given relative to
# the process's working directory.
MODEL_PATH = "./Ai-Text-Detector/model"
WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"

# NOTE: the FastAPI application object is created once, further down, as
# `app = FastAPI(lifespan=lifespan)`. A throwaway `app = FastAPI()` that used
# to live here was dead code (it was overwritten before any route was added).

# Populated by load_model(); both stay None until loading has succeeded.
model, tokenizer = None, None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.basicConfig(level=logging.DEBUG)
|
|
|
|
|
def load_model():
    """Load the GPT-2 tokenizer and model into the module-level globals.

    The tokenizer and config are read from ``MODEL_PATH``; the weights are read
    from ``WEIGHTS_PATH`` and mapped onto ``device``. The model is switched to
    eval mode. The globals ``model`` and ``tokenizer`` are only assigned after
    every step has succeeded, so a failed load never leaves one of them set
    and the other ``None``.

    Raises:
        RuntimeError: if any loading step fails; the original exception is
            attached as ``__cause__``.
    """
    global model, tokenizer

    try:
        loaded_tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
        config = GPT2Config.from_pretrained(MODEL_PATH)
        model_instance = GPT2LMHeadModel(config)

        # NOTE(review): torch.load unpickles the file, so WEIGHTS_PATH must
        # only ever point at a trusted checkpoint. Consider passing
        # weights_only=True if the installed torch version supports it.
        state_dict = torch.load(WEIGHTS_PATH, map_location=device)
        model_instance.load_state_dict(state_dict)
        model_instance.to(device)
        model_instance.eval()
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Error loading model: {str(e)}")
        raise RuntimeError(f"Error loading model: {str(e)}") from e

    # Publish both objects together, only once loading fully succeeded.
    model, tokenizer = model_instance, loaded_tokenizer
    logging.info("Model loaded successfully.")
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context for the application.

    Runs ``load_model()`` before yielding, i.e. before the code inside the
    ``async with`` body executes. Nothing is done after the yield, so there is
    no teardown step.
    """
    # Startup work: any exception raised by load_model() propagates from here.
    load_model()

    # Hand control back to the caller of the context manager.
    yield
|
|
|
|
|
# The application instance the route decorators below attach to; `lifespan`
# is passed so that the model is loaded as part of application startup.
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
class TextInput(BaseModel):
    """Request body accepted by the ``/analyze`` endpoint."""

    # The text to classify.
    text: str
|
|
|
|
|
def classify_text(text: str) -> tuple[str, float]:
    """Classify ``text`` by its perplexity under the loaded language model.

    The text is tokenized (truncated to at most 512 tokens), run through the
    model with the input ids as labels, and the perplexity is computed as the
    exponential of the returned loss.

    Args:
        text: The text to score.

    Returns:
        A ``(label, perplexity)`` pair. ``label`` is ``"AI-generated"`` for a
        perplexity below 60, ``"Probably AI-generated"`` below 80, and
        ``"Human-written"`` otherwise.

    Raises:
        RuntimeError: if the model or tokenizer has not been loaded.
        ValueError: if ``text`` tokenizes to fewer than two tokens, in which
            case no next-token loss can be computed.
    """
    # Compare against None explicitly instead of relying on the truthiness of
    # the model/tokenizer objects.
    if model is None or tokenizer is None:
        raise RuntimeError("Model or tokenizer not loaded.")

    # A single sequence needs no padding, so padding is not requested; this
    # also avoids depending on the tokenizer having a pad token configured.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # With fewer than two tokens there is nothing to predict: the loss comes
    # back as NaN, which would fall through every threshold below and be
    # reported as "Human-written" with a NaN score.
    if input_ids.size(1) < 2:
        raise ValueError("Text must contain at least two tokens")

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss
        perplexity = torch.exp(loss).item()

    # Lower perplexity means the model found the text more predictable.
    if perplexity < 60:
        return "AI-generated", perplexity
    if perplexity < 80:
        return "Probably AI-generated", perplexity
    return "Human-written", perplexity
|
|
|
|
|
@app.post("/analyze")
async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
    """Classify the text in the request body as AI-generated or human-written.

    Args:
        data: Request body carrying the text to classify.
        token: Credentials object produced by ``bearer_scheme``; its
            ``credentials`` attribute holds the bearer token string (the
            ``str`` annotation is kept for interface compatibility).

    Returns:
        ``{"result": <label>, "perplexity": <float rounded to 2 places>}``.

    Raises:
        HTTPException: 401 if the token does not match ``SECRET_TOKEN`` (or no
            secret is configured), 400 if the text is empty or has fewer than
            two words, 500 if classification fails.
    """
    # Constant-time comparison so response timing does not leak how much of
    # the token matched. Both sides are encoded to bytes because
    # compare_digest rejects non-ASCII str arguments. An unset SECRET_TOKEN
    # rejects every request, exactly as the plain `!=` comparison did.
    supplied = token.credentials.encode("utf-8")
    if SECRET_TOKEN is None or not secrets.compare_digest(
        supplied, SECRET_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    text = data.text.strip()

    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    if len(text.split()) < 2:
        raise HTTPException(status_code=400, detail="Text must contain at least two words")

    try:
        # Model inference is blocking; run it in a worker thread so the event
        # loop stays responsive.
        label, perplexity = await asyncio.to_thread(classify_text, text)
    except Exception as e:
        logging.exception(f"Error processing text: {str(e)}")
        raise HTTPException(status_code=500, detail="Model processing error") from e

    return {"result": label, "perplexity": round(perplexity, 2)}
|
|
|
|
|
def parse_docx(file: BytesIO):
    """Return the text of a .docx document, one paragraph per line.

    Args:
        file: In-memory buffer holding the .docx file.

    Returns:
        The text of every paragraph in document order, each followed by a
        newline character.
    """
    document = docx.Document(file)
    # Build the result in one pass instead of growing a string in a loop.
    return "".join(paragraph.text + "\n" for paragraph in document.paragraphs)
|
|
|
|
|
def parse_pdf(file: BytesIO):
    """Return the concatenated text of every page of a PDF.

    Args:
        file: In-memory buffer holding the PDF file.

    Returns:
        The text of all pages, in page order, joined without a separator.

    Raises:
        HTTPException: 500 if the PDF cannot be opened or read.
    """
    try:
        doc = fitz.open(stream=file, filetype="pdf")
        try:
            return "".join(
                doc.load_page(page_num).get_text()
                for page_num in range(doc.page_count)
            )
        finally:
            # Always release the document, even if a page fails to extract.
            doc.close()
    except Exception as e:
        logging.exception(f"Error while processing PDF: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing PDF file") from e
|
|
|
|
|
def parse_txt(file: BytesIO) -> str:
    """Decode a plain-text upload as UTF-8.

    Args:
        file: In-memory buffer holding the file's bytes.

    Returns:
        The decoded text, without a leading byte-order mark if one was present.

    Raises:
        UnicodeDecodeError: if the bytes are not valid UTF-8.
    """
    # "utf-8-sig" decodes ordinary UTF-8 identically but drops a leading BOM,
    # which would otherwise survive as a stray U+FEFF character at the start
    # of the text handed to the classifier.
    return file.read().decode("utf-8-sig")
|
|
|
|
|
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...), token: str = Depends(bearer_scheme)):
    """Extract text from an uploaded .docx, .pdf or .txt file and classify it.

    Args:
        file: The uploaded file; the parser is chosen from its content type.
        token: Credentials object produced by ``bearer_scheme``; its
            ``credentials`` attribute holds the bearer token string (the
            ``str`` annotation is kept for interface compatibility).

    Returns:
        ``{"result": <label>, "perplexity": <float rounded to 2 places>}``, or
        ``{"message": ...}`` when the extracted text is longer than 10,000
        characters.

    Raises:
        HTTPException: 401 if the token does not match ``SECRET_TOKEN`` (or no
            secret is configured), 400 for an unsupported file type or a file
            with fewer than two words of text, 500 for any other failure.
    """
    # Authenticate before touching the upload. Constant-time comparison on
    # bytes (compare_digest rejects non-ASCII str); an unset SECRET_TOKEN
    # rejects every request.
    supplied = token.credentials.encode("utf-8")
    if SECRET_TOKEN is None or not secrets.compare_digest(
        supplied, SECRET_TOKEN.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    parsers = {
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document': parse_docx,
        'application/pdf': parse_pdf,
        'text/plain': parse_txt,
    }

    try:
        # Drop any parameters (e.g. "; charset=utf-8") and tolerate a missing
        # content type before looking up the parser.
        media_type = (file.content_type or "").split(";")[0].strip().lower()
        parser = parsers.get(media_type)
        if parser is None:
            raise HTTPException(status_code=400, detail="Invalid file type. Only .docx, .pdf, and .txt are allowed.")

        file_contents = parser(BytesIO(await file.read()))

        # Log only the size: the document text itself is user data and should
        # not end up in the application log.
        logging.debug(f"Extracted {len(file_contents)} characters from {file.filename}")

        if len(file_contents) > 10000:
            return {"message": "File contains more than 10,000 characters."}

        # Collapse every run of whitespace to a single space. Deleting
        # newlines/tabs outright would glue together the last word of one line
        # and the first word of the next.
        cleaned_text = " ".join(file_contents.split())

        if not cleaned_text:
            raise HTTPException(status_code=400, detail="Text cannot be empty")

        if len(cleaned_text.split()) < 2:
            raise HTTPException(status_code=400, detail="Text must contain at least two words")

        # Model inference is blocking; run it in a worker thread.
        label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
        return {"result": label, "perplexity": round(perplexity, 2)}

    except HTTPException:
        # Deliberate HTTP errors (400s above, parser errors) pass through
        # unchanged instead of being rewritten into a generic 500.
        raise
    except Exception as e:
        logging.exception(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing the file") from e
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health endpoint: always answers ``{"status": "ok"}``.

    The response is a constant; it does not inspect whether the model has
    been loaded.
    """
    status_payload = {"status": "ok"}
    return status_payload
|
|
|
|
|
@app.get("/")
def index():
    """Root endpoint: return a short banner and a usage hint."""
    banner = {
        "message": "FastAPI AI Text Detector is running.",
        "usage": "Use /docs or /analyze to test the API.",
    }
    return banner
|
|