File size: 2,530 Bytes
e9f0d54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import asyncio
import math
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config

# Global model and tokenizer variables. Both stay None until the lifespan
# handler assigns them at startup. The FastAPI app is created once, further
# down, after `lifespan` is defined, so no throwaway instance is built here.
model, tokenizer = None, None

# Function to load model and tokenizer
def load_model():
    """Load the GPT-2 language model and its tokenizer from local files.

    The tokenizer and config are read from ``./Ai-Text-Detector/model``; the
    weights are read from ``./Ai-Text-Detector/model_weights.pth`` and mapped
    onto the CPU. The model is put in evaluation mode before it is returned.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: If any loading step fails. The underlying exception is
            attached as ``__cause__`` so the root failure stays visible.
    """
    model_path = "./Ai-Text-Detector/model"
    weights_path = "./Ai-Text-Detector/model_weights.pth"

    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        # NOTE(review): torch.load unpickles the checkpoint; only point
        # weights_path at trusted files (consider weights_only=True).
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval()  # Set model to evaluation mode
    except Exception as e:
        # Chain explicitly so the traceback shows this as the direct cause.
        raise RuntimeError(f"Error loading model: {e}") from e

    return model, tokenizer

# Startup hook: fill in the module-level model and tokenizer
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model and tokenizer once before the app starts serving.

    Nothing follows the ``yield``, so no work is done on shutdown.
    """
    global model, tokenizer
    loaded_model, loaded_tokenizer = load_model()
    model = loaded_model
    tokenizer = loaded_tokenizer
    yield

# Create the application with `lifespan` registered as its startup loader.
# The route decorators below attach their handlers to this instance.
app = FastAPI(lifespan=lifespan)

# Input schema: request body accepted by the POST /analyze route
class TextInput(BaseModel):
    # Raw text to classify; the /analyze handler strips it and rejects
    # the request when nothing is left.
    text: str

# Sync text classification
def classify_text(
    sentence: str,
    *,
    ai_threshold: float = 60,
    probable_ai_threshold: float = 80,
) -> tuple[str, float]:
    """Label *sentence* according to its perplexity under the loaded model.

    The sentence is tokenized (with truncation), scored by the module-level
    ``model`` with gradients disabled, and the exponential of the returned
    loss is used as the perplexity.

    Args:
        sentence: Text to score.
        ai_threshold: Perplexities strictly below this are labelled
            ``"AI-generated"``.
        probable_ai_threshold: Perplexities strictly below this (and not below
            ``ai_threshold``) are labelled ``"Probably AI-generated"``.

    Returns:
        A ``(label, perplexity)`` tuple; anything not below either threshold
        is labelled ``"Human-written"``.

    NOTE(review): NaN compares False against both thresholds, so a non-finite
    perplexity falls through to ``"Human-written"`` -- confirm callers handle
    that case.
    """
    # Only one sequence is encoded per call, so no padding is requested;
    # asking for it would require the tokenizer to define a pad token.
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = encoded["input_ids"]

    with torch.no_grad():
        # The ids double as labels so that `outputs.loss` is populated.
        outputs = model(
            input_ids,
            attention_mask=encoded["attention_mask"],
            labels=input_ids,
        )
    perplexity = torch.exp(outputs.loss).item()

    if perplexity < ai_threshold:
        return "AI-generated", perplexity
    if perplexity < probable_ai_threshold:
        return "Probably AI-generated", perplexity
    return "Human-written", perplexity

# POST route to analyze text
# Accepts a TextInput body and responds with
# {"result": <label>, "perplexity": <value rounded to 2 decimals>}.
# Responds with HTTP 400 when the text is empty after stripping or when no
# finite perplexity could be computed for it.
@app.post("/analyze")
async def analyze_text(data: TextInput):
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Run classification in a worker thread to prevent blocking the event loop
    result, perplexity = await asyncio.to_thread(classify_text, user_input)

    # NaN/inf are not valid JSON numbers, so returning one would fail while
    # the response is serialized. Presumably this happens for inputs too short
    # to yield a loss (TODO confirm); report it as a client error instead.
    if not math.isfinite(perplexity):
        raise HTTPException(
            status_code=400,
            detail="Could not compute a finite perplexity for this text",
        )

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }

# Health check route: always answers with a fixed "ok" status payload
@app.get("/health")
async def health_check():
    return dict(status="ok")

# Simple index route: static landing payload that points callers at /docs
@app.get("/")
def index():
    landing = {}
    landing["message"] = "FastAPI API is up."
    landing["try"] = "/docs to test the API."
    landing["status"] = "OK"
    return landing