import os
import logging
from typing import Optional
from datetime import datetime
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Depends, Security, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for the model (populated during startup)
model = None
tokenizer = None
model_loaded = False
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    global model, tokenizer, model_loaded
    logger.info("Real LLM AI Assistant starting up...")

    try:
        # Try to load the actual LLM model
        from transformers import AutoTokenizer, AutoModelForCausalLM
        import torch

        # Use a conversational model by default
        model_name = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
        logger.info(f"Loading real LLM model: {model_name}")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Load model with CPU-friendly settings
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            pad_token_id=tokenizer.eos_token_id
        )

        model_loaded = True
        logger.info("Real LLM model loaded successfully!")
    except Exception as e:
        logger.warning(f"Could not load LLM model: {e}")
        logger.info("Will use fallback responses")
        model_loaded = False

    yield

    # Shutdown
    logger.info("AI Assistant shutting down...")
# Initialize FastAPI app with lifespan
app = FastAPI(
    title="Real LLM AI Agent API",
    description="AI Agent powered by actual LLM models",
    version="4.0.0",
    lifespan=lifespan
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Security
security = HTTPBearer()

# Configuration: API keys (overridable via environment) mapped to user names
API_KEYS = {
    os.getenv("API_KEY_1", "27Eud5J73j6SqPQAT2ioV-CtiCg-p0WNqq6I4U0Ig6E"): "user1",
    os.getenv("API_KEY_2", "QbzG2CqHU1Nn6F1EogZ1d3dp8ilRTMJQBzS-U"): "user2",
}
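# Example (illustrative, not part of the original file): clients authenticate by
# sending one of the keys above as a Bearer token. The host and route path below
# are assumptions about the deployment:
#
#   curl -H "Authorization: Bearer $API_KEY_1" \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello"}' \
#        https://<your-space>.hf.space/chat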
# Request/Response models
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=2000)
    max_length: Optional[int] = Field(200, ge=50, le=500)
    temperature: Optional[float] = Field(0.8, ge=0.1, le=1.5)
    top_p: Optional[float] = Field(0.9, ge=0.1, le=1.0)
    do_sample: Optional[bool] = Field(True)

class ChatResponse(BaseModel):
    response: str
    model_used: str
    timestamp: str
    processing_time: float
    tokens_used: int
    model_loaded: bool

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    timestamp: str
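# Illustrative request/response shapes for the chat endpoint (values are examples only):
#   request:  {"message": "What is FastAPI?", "max_length": 200, "temperature": 0.8}
#   response: {"response": "...", "model_used": "microsoft/DialoGPT-medium",
#              "timestamp": "2024-01-01T12:00:00", "processing_time": 1.23,
#              "tokens_used": 42, "model_loaded": true}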
def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)) -> str:
    """Verify API key authentication"""
    api_key = credentials.credentials
    if api_key not in API_KEYS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid API key"
        )
    return API_KEYS[api_key]
def generate_llm_response(message: str, max_length: int = 200, temperature: float = 0.8,
                          top_p: float = 0.9, do_sample: bool = True) -> tuple:
    """Generate a response using the actual LLM model"""
    global model, tokenizer, model_loaded

    if not model_loaded or model is None or tokenizer is None:
        return (
            "I'm currently running in demo mode. The LLM model couldn't be loaded, "
            "but I'm still here to help! Please try asking your question again.",
            "demo_mode",
            0
        )

    try:
        import torch  # safe here: model_loaded is only True if torch imported at startup

        # Prepare input with a simple conversation format
        input_text = f"Human: {message}\nAssistant:"

        # Tokenize input
        inputs = tokenizer.encode(input_text, return_tensors="pt")

        # Generate response
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=inputs.shape[1] + max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
                repetition_penalty=1.1,
                length_penalty=1.0
            )

        # Decode response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the assistant's reply
        if "Assistant:" in response:
            response = response.split("Assistant:")[-1].strip()

        # Remove the input text if it is still present
        if input_text.replace("Assistant:", "").strip() in response:
            response = response.replace(input_text.replace("Assistant:", "").strip(), "").strip()

        # Clean up the response
        response = response.strip()
        if not response:
            response = ("I understand your question, but I'm having trouble generating a "
                        "proper response right now. Could you please rephrase your question?")

        # Count tokens
        tokens_used = len(tokenizer.encode(response))

        return response, os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium"), tokens_used
    except Exception as e:
        logger.error(f"Error generating LLM response: {str(e)}")
        return f"I encountered an issue while processing your request. Error: {str(e)}", "error_mode", 0
# Route paths are assumed here and below; adjust them to match your deployment.
@app.get("/", response_model=HealthResponse)
async def root():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        timestamp=datetime.now().isoformat()
    )

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Detailed health check"""
    return HealthResponse(
        status="healthy" if model_loaded else "demo_mode",
        model_loaded=model_loaded,
        timestamp=datetime.now().isoformat()
    )
@app.post("/chat", response_model=ChatResponse)
async def chat(
    request: ChatRequest,
    user: str = Depends(verify_api_key)
):
    """Main chat endpoint using the real LLM model"""
    start_time = datetime.now()

    try:
        # Generate response using the actual LLM
        response_text, model_used, tokens_used = generate_llm_response(
            request.message,
            request.max_length,
            request.temperature,
            request.top_p,
            request.do_sample
        )

        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()

        return ChatResponse(
            response=response_text,
            model_used=model_used,
            timestamp=datetime.now().isoformat(),
            processing_time=processing_time,
            tokens_used=tokens_used,
            model_loaded=model_loaded
        )
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating response: {str(e)}"
        )
@app.get("/model/info")
async def get_model_info(user: str = Depends(verify_api_key)):
    """Get information about the loaded model"""
    return {
        "model_name": os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium"),
        "model_loaded": model_loaded,
        "status": "active" if model_loaded else "demo_mode",
        "capabilities": [
            "Real LLM text generation",
            "Conversational AI responses",
            "Dynamic response generation",
            "Adjustable temperature and top_p",
            "Natural language understanding"
        ],
        "version": "4.0.0",
        "type": "Real LLM Model" if model_loaded else "Demo Mode"
    }
@app.post("/generate", response_model=ChatResponse)
async def generate_text(
    request: ChatRequest,
    user: str = Depends(verify_api_key)
):
    """Direct text generation endpoint (no conversation formatting)"""
    start_time = datetime.now()

    try:
        # Generate using the LLM without conversation formatting
        if model_loaded and model is not None and tokenizer is not None:
            import torch  # safe here: model_loaded is only True if torch imported at startup

            inputs = tokenizer.encode(request.message, return_tensors="pt")

            with torch.no_grad():
                outputs = model.generate(
                    inputs,
                    max_length=inputs.shape[1] + request.max_length,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    do_sample=request.do_sample,
                    pad_token_id=tokenizer.eos_token_id,
                    num_return_sequences=1
                )

            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Remove the echoed input text
            response = response[len(request.message):].strip()
            tokens_used = len(tokenizer.encode(response))
            model_used = os.getenv("MODEL_NAME", "microsoft/DialoGPT-medium")
        else:
            response = "Model not loaded. Running in demo mode."
            tokens_used = 0
            model_used = "demo_mode"

        processing_time = (datetime.now() - start_time).total_seconds()

        return ChatResponse(
            response=response,
            model_used=model_used,
            timestamp=datetime.now().isoformat(),
            processing_time=processing_time,
            tokens_used=tokens_used,
            model_loaded=model_loaded
        )
    except Exception as e:
        logger.error(f"Error in generate endpoint: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error generating text: {str(e)}"
        )
if __name__ == "__main__":
    # For Hugging Face Spaces (default port 7860)
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        reload=False
    )
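
# Example client call (illustrative sketch, not part of the original file; assumes the
# /chat route above and a server running locally on port 7860):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       headers={"Authorization": "Bearer <your-api-key>"},
#       json={"message": "Tell me about FastAPI", "temperature": 0.7},
#   )
#   print(resp.json()["response"])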