import uvicorn from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel from transformers import AutoModelForCausalLM, AutoTokenizer import torch # Khởi tạo FastAPI app = FastAPI() # Tải model và tokenizer khi ứng dụng khởi động model_name = "Qwen/Qwen2.5-0.5B" try: tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype="auto", device_map="auto", attn_implementation="eager" # Tránh cảnh báo sdpa ) print("Model and tokenizer loaded successfully!") except Exception as e: print(f"Error loading model: {e}") raise # Định nghĩa request body class TextInput(BaseModel): prompt: str max_length: int = 100 # API endpoint để sinh văn bản @app.post("/generate") async def generate_text(input: TextInput): try: # Mã hóa đầu vào inputs = tokenizer(input.prompt, return_tensors="pt").to(model.device) # Sinh văn bản outputs = model.generate( inputs["input_ids"], max_length=input.max_length, num_return_sequences=1, no_repeat_ngram_size=2, do_sample=True, top_k=50, top_p=0.95 ) # Giải mã kết quả generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) return {"generated_text": generated_text} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # Endpoint kiểm tra sức khỏe @app.get("/") async def root(): return {"message": "Qwen2.5-0.5B API is running!"} # Endpoint hiển thị API URL đầy đủ @app.get("/api_link") async def get_api_link(request: Request): # Lấy host từ request host = request.client.host # Lấy port từ server (nếu chạy local thì mặc định là 7860) port = request.url.port if request.url.port else 7860 # Tạo URL đầy đủ base_url = f"http://{host}:{port}" return { "api_url": base_url, "endpoints": { "health_check": f"{base_url}/", "generate_text": f"{base_url}/generate", "api_link": f"{base_url}/api_link" } } # Chạy server khi file được gọi trực tiếp if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)