thecuong's picture
feat: update fast API
ebb0dae
raw
history blame
2.08 kB
from typing import List, Literal
from pydantic import BaseModel, Field
import gradio as gr
from fastapi import FastAPI, APIRouter, Request
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
import uvicorn
import requests
# Khởi tạo FastAPI
app = FastAPI()
# Thêm middleware CORS để cho phép yêu cầu từ Gradio
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Cho phép tất cả các nguồn
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Tải mô hình
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base')
# Định nghĩa mô hình dữ liệu cho yêu cầu
class PostEmbeddings(BaseModel):
type: Literal['default', 'disease', 'gte'] = Field(default='default')
sentences: List[str]
# Tạo router cho API
router = APIRouter(
prefix="/retrieval",
tags=["retrieval"],
responses={404: {"description": "Not found"}},
)
@app.post("/retrieval/embeddings")
def post_embeddings(data: PostEmbeddings):
embeddings = model.encode(data.sentences)
return {
'data': {
'embeddings': embeddings.tolist(),
'type': data.type
}
}
# Hàm Gradio để gọi API FastAPI
def call_api(sentences: List[str]):
response = requests.post("http://127.0.0.1:8000/retrieval/embeddings", json={"sentences": sentences})
return response.json()["data"]
# Tạo giao diện Gradio
demo = gr.Interface(
fn=call_api,
inputs=gr.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."),
outputs=gr.JSON(label="Kết quả mã hóa"),
title="Mô hình GTE Multilingual",
description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual."
)
# Khởi động server
if __name__ == "__main__":
import threading
# Khởi động FastAPI trong một thread riêng
threading.Thread(target=uvicorn.run, args=(app,), kwargs={"host": "0.0.0.0", "port": 8000}).start()
# Khởi động Gradio
demo.launch()