# handler.py

import requests
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import Optional

# Setup logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

# Your Hugging Face model URL (must be public)
MODEL_URL = "https://api-inference.huggingface.co/models/CLASSIFIED-HEX/X"
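
# Note: the URL above works unauthenticated only because the model is public.
# For a private or gated model, the Inference API expects a bearer token; the
# HF_TOKEN environment variable below is an assumed name, not in the original:
#
#   import os
#   HEADERS = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
#
# Each requests.post(...) call would then pass headers=HEADERS.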

# Input model
class PromptInput(BaseModel):
    prompt: str
    max_tokens: Optional[int] = 250
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.95
    top_k: Optional[int] = 50
    repetition_penalty: Optional[float] = 1.2
    trim_output: Optional[bool] = False  # Optionally strip the prompt from the returned text

# Root health check
@router.get("/")
async def root():
    return {"message": "AI text generation backend is running ๐Ÿš€"}

# Ping model check
@router.get("/ping-model")
async def ping_model():
    try:
        # Timeout keeps a hung Inference API call from blocking the route forever
        response = requests.post(MODEL_URL, json={"inputs": "ping test"}, timeout=30)
        if response.status_code == 200:
            return {"status": "Model is online ✅"}
        else:
            return {"status": "Model responded with error ❌", "details": response.text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Could not reach model: {str(e)}")

# Main generation route
@router.post("/generate")
async def generate_text(input_data: PromptInput):
    payload = {
        "inputs": input_data.prompt,
        "parameters": {
            "max_new_tokens": input_data.max_tokens,
            "temperature": input_data.temperature,
            "top_p": input_data.top_p,
            "top_k": input_data.top_k,
            "repetition_penalty": input_data.repetition_penalty
        }
    }

    try:
        logger.info(f"Sending prompt to model: {input_data.prompt}")
        response = requests.post(MODEL_URL, json=payload)

        if response.status_code != 200:
            logger.error(f"Model error: {response.status_code} - {response.text}")
            raise HTTPException(status_code=response.status_code, detail=response.json())

        result = response.json()
        raw_output = result[0].get("generated_text") if isinstance(result, list) else result.get("generated_text", "")

        # Optionally trim prompt from beginning
        if input_data.trim_output and raw_output.startswith(input_data.prompt):
            raw_output = raw_output[len(input_data.prompt):].lstrip()

        return {
            "status": "success",
            "output": raw_output
        }

    except HTTPException:
        # Re-raise model errors as-is instead of masking them as a generic 500
        raise
    except Exception as e:
        logger.exception("Text generation failed")
        raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")
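
Since handler.py only defines an APIRouter, it needs a FastAPI application to serve it. A minimal wiring sketch, assuming the file above is importable as handler and the entry point is named main.py (both names are assumptions, not part of the original):

# main.py
from fastapi import FastAPI
from handler import router

app = FastAPI()
app.include_router(router)

# Start the server with: uvicorn main:app --reload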
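
With the server running on uvicorn's default address, the routes can be exercised with a small client. A usage sketch, assuming http://localhost:8000 as the base URL:

# client_example.py (hypothetical usage sketch)
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "Write a haiku about model servers.",
        "max_tokens": 64,
        "trim_output": True,  # strip the echoed prompt from the output
    },
    timeout=60,
)
print(resp.json())  # e.g. {"status": "success", "output": "..."}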