canspace / app.py
Pujan-Dev's picture
feat: add file parsing support and enforce 10,000-character limit
bc07cfe verified
raw
history blame
6 kB
from fastapi import FastAPI, HTTPException, Depends, UploadFile, File
from fastapi.security import HTTPBearer
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config
import torch
import os
import asyncio
from contextlib import asynccontextmanager
import logging
from io import BytesIO
import docx
import fitz # PyMuPDF
# Load environment variables
from dotenv import load_dotenv
load_dotenv()
# Shared secret that clients must send as a Bearer token. It is read from the
# environment after load_dotenv() has run above; os.getenv returns None when
# the variable is not set.
SECRET_TOKEN = os.getenv("SECRET_TOKEN")
bearer_scheme = HTTPBearer()

# Ai-Text-Detector: locations of the tokenizer/config directory and of the
# saved model weights, relative to the working directory.
MODEL_PATH = "./Ai-Text-Detector/model"
WEIGHTS_PATH = "./Ai-Text-Detector/model_weights.pth"

# NOTE: the FastAPI application is created once, further down, after
# `lifespan` has been defined. An extra bare `FastAPI()` instance used to be
# created here and was immediately discarded by that later assignment.

# Global model and tokenizer variables; they stay None until load_model()
# has run.
model, tokenizer = None, None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging setup
logging.basicConfig(level=logging.DEBUG)
# Load model and tokenizer function
def load_model():
    """Load the tokenizer and model and publish them in the module globals.

    The tokenizer and config are read from ``MODEL_PATH``, a
    ``GPT2LMHeadModel`` is built from that config, the state dict stored at
    ``WEIGHTS_PATH`` is loaded into it, and the model is moved to ``device``
    and switched to eval mode.

    The ``model`` / ``tokenizer`` globals are assigned only after every step
    has succeeded, so a failed load never leaves them half-initialised.

    Raises:
        RuntimeError: If any step fails. The underlying exception is chained
            as ``__cause__``.
    """
    global model, tokenizer
    try:
        loaded_tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_PATH)
        config = GPT2Config.from_pretrained(MODEL_PATH)
        model_instance = GPT2LMHeadModel(config)
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch version; only load weight files from a trusted
        # source.
        state_dict = torch.load(WEIGHTS_PATH, map_location=device)
        model_instance.load_state_dict(state_dict)
        model_instance.to(device)
        model_instance.eval()
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception(f"Error loading model: {str(e)}")
        raise RuntimeError(f"Error loading model: {str(e)}") from e

    model, tokenizer = model_instance, loaded_tokenizer
    logging.info("Model loaded successfully.")
# Load model on app startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context for the application: load the model, then hand over.

    ``load_model()`` is called before the ``yield``; there is no code after
    the ``yield``, so nothing is done on exit.
    """
    # Populate the global model/tokenizer before control is handed back.
    load_model()
    yield
# Create the application with the lifespan handler attached, so that
# load_model() is invoked through `lifespan`. The route decorators below
# register their handlers on this instance.
app = FastAPI(lifespan=lifespan)
# Input schema for text analysis
class TextInput(BaseModel):
    """Request body accepted by the ``/analyze`` endpoint."""

    # Raw text to be classified.
    text: str
# Function to classify text using the model
def classify_text(
    text: str,
    *,
    ai_threshold: float = 60,
    probable_ai_threshold: float = 80,
) -> tuple[str, float]:
    """Label *text* as AI- or human-written based on model perplexity.

    The text is tokenized (truncated to at most 512 tokens) and scored by the
    loaded language model using the input ids as their own labels. The
    exponential of the resulting loss is the perplexity, which is compared
    against the two thresholds.

    Args:
        text: The text to classify.
        ai_threshold: Perplexity strictly below this value is labelled
            ``"AI-generated"``.
        probable_ai_threshold: Perplexity strictly below this value (and not
            below ``ai_threshold``) is labelled ``"Probably AI-generated"``.

    Returns:
        A ``(label, perplexity)`` tuple.

    Raises:
        RuntimeError: If the model or tokenizer has not been loaded.
        ValueError: If the text tokenizes to fewer than two tokens.
    """
    # Compare against None explicitly: truthiness of these objects depends on
    # their own __bool__/__len__ and is not a reliable "is loaded" signal.
    if model is None or tokenizer is None:
        raise RuntimeError("Model or tokenizer not loaded.")

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # A causal LM loss needs at least one (context, next-token) pair. With
    # fewer than two tokens the loss is not a finite number, which would be
    # mislabelled as "Human-written" and cannot be serialized to JSON.
    if input_ids.shape[-1] < 2:
        raise ValueError("Text is too short to compute perplexity.")

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
    perplexity = torch.exp(outputs.loss).item()

    if perplexity < ai_threshold:
        return "AI-generated", perplexity
    if perplexity < probable_ai_threshold:
        return "Probably AI-generated", perplexity
    return "Human-written", perplexity
# POST route to analyze text with Bearer token
@app.post("/analyze")
async def analyze_text(data: TextInput, token: str = Depends(bearer_scheme)):
    """Classify the text in the request body.

    Args:
        data: Request body carrying the text to analyze.
        token: Value supplied by the ``bearer_scheme`` dependency. Despite the
            ``str`` annotation, the code reads the token from its
            ``.credentials`` attribute.

    Returns:
        ``{"result": <label>, "perplexity": <float rounded to 2 places>}``.

    Raises:
        HTTPException: 401 if the token does not match ``SECRET_TOKEN`` (or
            no secret is configured), 400 if the text is empty or has fewer
            than two words, 500 if classification fails.
    """
    import secrets  # local import: constant-time comparison helper

    # Verify token. compare_digest avoids leaking, through response timing,
    # how much of the token matched. An unset SECRET_TOKEN rejects everything.
    expected = SECRET_TOKEN
    if expected is None or not secrets.compare_digest(
        token.credentials.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    text = data.text.strip()

    # Input validation
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(text.split()) < 2:
        raise HTTPException(status_code=400, detail="Text must contain at least two words")

    try:
        # Run the blocking model call in a worker thread so the event loop
        # stays free.
        label, perplexity = await asyncio.to_thread(classify_text, text)
    except Exception as e:
        logging.exception(f"Error processing text: {str(e)}")
        raise HTTPException(status_code=500, detail="Model processing error") from e
    return {"result": label, "perplexity": round(perplexity, 2)}
# Function to parse .docx files
def parse_docx(file: BytesIO):
    """Return the text of a .docx file, one paragraph per line.

    Every paragraph's text is followed by a newline, including the last one.
    """
    document = docx.Document(file)
    return "".join(f"{paragraph.text}\n" for paragraph in document.paragraphs)
# Function to parse .pdf files
def parse_pdf(file: BytesIO):
    """Return the concatenated text of every page of a PDF.

    Args:
        file: In-memory PDF content, passed to ``fitz.open`` as ``stream``.

    Returns:
        The page texts joined in page order with no separator added.

    Raises:
        HTTPException: 500 if the document cannot be opened or read. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        # The context manager closes the document even if a page fails to
        # load, so the handle is not leaked.
        with fitz.open(stream=file, filetype="pdf") as doc:
            return "".join(
                doc.load_page(page_num).get_text()
                for page_num in range(doc.page_count)
            )
    except Exception as e:
        logging.exception(f"Error while processing PDF: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing PDF file") from e
# Function to parse .txt files
def parse_txt(file: BytesIO):
    """Return the content of a plain-text file decoded as UTF-8.

    A leading UTF-8 byte-order mark, if present, is dropped so it does not
    end up as a stray ``\\ufeff`` character at the start of the text.

    Raises:
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    # "utf-8-sig" decodes exactly like "utf-8" but skips an initial BOM.
    return file.read().decode("utf-8-sig")
# POST route to upload files and analyze content
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...), token: str = Depends(bearer_scheme)):
    """Extract the text of an uploaded .docx, .pdf or .txt file and classify it.

    Args:
        file: The uploaded file; its ``content_type`` selects the parser.
        token: Value supplied by the ``bearer_scheme`` dependency. Despite the
            ``str`` annotation, the token is read from its ``.credentials``
            attribute.

    Returns:
        ``{"result": <label>, "perplexity": <float rounded to 2 places>}``, or
        ``{"message": ...}`` when the extracted text exceeds 10,000
        characters.

    Raises:
        HTTPException: 401 if the token does not match ``SECRET_TOKEN`` (or
            no secret is configured), 400 for an unsupported content type or
            when fewer than two words were extracted, 500 if parsing or
            classification fails.
    """
    import secrets  # local import: constant-time comparison helper

    # Verify token, as /analyze does. Without this check any bearer token
    # would be accepted. An unset SECRET_TOKEN rejects everything.
    expected = SECRET_TOKEN
    if expected is None or not secrets.compare_digest(
        token.credentials.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    # Map each accepted MIME type to its parser.
    parsers = {
        'application/vnd.openxmlformats-officedocument.wordprocessingml.document': parse_docx,
        'application/pdf': parse_pdf,
        'text/plain': parse_txt,
    }
    parser = parsers.get(file.content_type)
    if parser is None:
        raise HTTPException(status_code=400, detail="Invalid file type. Only .docx, .pdf, and .txt are allowed.")

    try:
        file_contents = parser(BytesIO(await file.read()))
        # Log only the size: the document itself may be large or private.
        logging.debug("Extracted %d characters from %s", len(file_contents), file.filename)

        # Check if the text length exceeds 10,000 characters
        if len(file_contents) > 10000:
            return {"message": "File contains more than 10,000 characters."}

        # Collapse every run of whitespace (newlines, tabs, ...) into a single
        # space. Deleting the newlines outright would glue the last word of
        # one line to the first word of the next.
        cleaned_text = " ".join(file_contents.split())

        # Same input validation as /analyze.
        if not cleaned_text:
            raise HTTPException(status_code=400, detail="Text cannot be empty")
        if len(cleaned_text.split()) < 2:
            raise HTTPException(status_code=400, detail="Text must contain at least two words")

        # Run the blocking model call in a worker thread.
        label, perplexity = await asyncio.to_thread(classify_text, cleaned_text)
        return {"result": label, "perplexity": round(perplexity, 2)}
    except HTTPException:
        # Deliberate HTTP errors must keep their own status code instead of
        # being rewritten to a 500 by the generic handler below.
        raise
    except Exception as e:
        logging.exception(f"Error processing file: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing the file") from e
# Health check route
@app.get("/health")
async def health_check():
    """Report that the service is up; the model state is not inspected."""
    status_payload = {"status": "ok"}
    return status_payload
# Simple index route
@app.get("/")
def index():
    """Return a short banner with a pointer to the usable endpoints."""
    return dict(
        message="FastAPI AI Text Detector is running.",
        usage="Use /docs or /analyze to test the API.",
    )