# canspace / app.py
# Author: Pujan Neupane
# Commit e9f0d54 - "Project : pushing all the files to hugging face"
# (Hugging Face file-viewer residue: raw / history / blame; file size 2.53 kB)
import asyncio
import math
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
# Global model and tokenizer variables.
# Both start as None and are assigned by the lifespan handler when the
# application starts. The FastAPI app itself is created once, further down,
# with that lifespan handler attached, so no app instance is created here.
model, tokenizer = None, None
# Function to load model and tokenizer
def load_model(
    model_path: str = "./Ai-Text-Detector/model",
    weights_path: str = "./Ai-Text-Detector/model_weights.pth",
) -> "tuple[GPT2LMHeadModel, GPT2TokenizerFast]":
    """Build the GPT-2 language model and its tokenizer from local files.

    Args:
        model_path: Location handed to ``from_pretrained`` for both the
            tokenizer and the model config.
        weights_path: File containing the state dict loaded into the model.

    Returns:
        A ``(model, tokenizer)`` pair; the model is already in eval mode.

    Raises:
        RuntimeError: If any loading step fails. The original exception is
            chained as ``__cause__`` so the real traceback is not lost.
    """
    try:
        tokenizer = GPT2TokenizerFast.from_pretrained(model_path)
        config = GPT2Config.from_pretrained(model_path)
        model = GPT2LMHeadModel(config)
        # Tensors are mapped to CPU so the checkpoint loads without a GPU.
        # NOTE(review): torch.load unpickles the file, so only trusted weight
        # files should be placed at weights_path; consider weights_only=True
        # if the deployed torch version supports it.
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval()  # Set model to evaluation mode
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e
    return model, tokenizer
# Load model on app startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Fill the module-level ``model`` and ``tokenizer`` before yielding.

    The code ahead of ``yield`` calls ``load_model()`` once and stores both
    results in the globals; nothing follows the ``yield``, so there is no
    teardown step.
    """
    global model, tokenizer
    loaded_pair = load_model()
    model, tokenizer = loaded_pair
    yield
# Attach startup loader: create the application with the lifespan handler
# registered. The route decorators below attach to this instance.
app = FastAPI(lifespan=lifespan)
# Input schema: request body accepted by the POST /analyze route.
class TextInput(BaseModel):
    # Raw text to classify. The route handler strips it and rejects the
    # request when nothing remains.
    text: str
# Sync text classification
def classify_text(
    sentence: str,
    ai_threshold: float = 60,
    probable_threshold: float = 80,
):
    """Label *sentence* as AI- or human-written based on its perplexity.

    The text is tokenized, scored by the module-level ``model`` with the
    input ids as labels, and the exponential of the returned loss is used
    as the perplexity.

    Args:
        sentence: Text to score.
        ai_threshold: Perplexity below this value yields "AI-generated".
        probable_threshold: Perplexity below this value (but not below
            ``ai_threshold``) yields "Probably AI-generated".

    Returns:
        A ``(label, perplexity)`` tuple; ``perplexity`` is a Python float.

    NOTE(review): for an input that tokenizes to a single token the loss is
    presumably NaN; NaN compares false against both thresholds, so the label
    falls through to "Human-written". Callers should check the perplexity
    is finite before using it.
    """
    # Only one sequence is encoded, so there is nothing to pad it against.
    # Asking for padding anyway raises ValueError on tokenizers that have no
    # pad token configured (the GPT-2 default), so it is not requested.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)

    perplexity = torch.exp(outputs.loss).item()

    if perplexity < ai_threshold:
        result = "AI-generated"
    elif perplexity < probable_threshold:
        result = "Probably AI-generated"
    else:
        result = "Human-written"
    return result, perplexity
# POST route to analyze text
@app.post("/analyze")
async def analyze_text(data: TextInput):
    """Classify the submitted text as AI-generated or human-written.

    Returns the label together with the perplexity rounded to two decimals.
    Responds with HTTP 400 when the text is empty after stripping, or when
    no finite perplexity could be computed for it.
    """
    user_input = data.text.strip()
    if not user_input:
        raise HTTPException(status_code=400, detail="Text cannot be empty")

    # Run classification in a worker thread so the blocking model call does
    # not stall the event loop.
    result, perplexity = await asyncio.to_thread(classify_text, user_input)

    # NaN/inf cannot be written into the JSON response and would surface as
    # an opaque server error; report it as a client error instead.
    if not math.isfinite(perplexity):
        raise HTTPException(
            status_code=400,
            detail="Could not compute a finite perplexity for this text",
        )

    return {
        "result": result,
        "perplexity": round(perplexity, 2),
    }
# Health check route
# Answers with a fixed payload; it reads neither the model nor the tokenizer.
@app.get("/health")
async def health_check():
    status_payload = {"status": "ok"}
    return status_payload
# Simple index route
# Returns a static greeting that points the caller at /docs.
@app.get("/")
def index():
    # Built key by key; insertion order matches the response field order.
    greeting = {"message": "FastAPI API is up."}
    greeting["try"] = "/docs to test the API."
    greeting["status"] = "OK"
    return greeting