Spaces:
Sleeping
Sleeping
from typing import List, Literal | |
from pydantic import BaseModel, Field | |
import gradio as gr | |
from fastapi import FastAPI, APIRouter, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from sentence_transformers import SentenceTransformer | |
import uvicorn | |
import requests | |
# Khởi tạo FastAPI | |
app = FastAPI() | |
# Thêm middleware CORS để cho phép yêu cầu từ Gradio | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], # Cho phép tất cả các nguồn | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# Tải mô hình | |
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base') | |
# Định nghĩa mô hình dữ liệu cho yêu cầu | |
class PostEmbeddings(BaseModel): | |
type: Literal['default', 'disease', 'gte'] = Field(default='default') | |
sentences: List[str] | |
# Tạo router cho API | |
router = APIRouter( | |
prefix="/retrieval", | |
tags=["retrieval"], | |
responses={404: {"description": "Not found"}}, | |
) | |
def post_embeddings(data: PostEmbeddings): | |
embeddings = model.encode(data.sentences) | |
return { | |
'data': { | |
'embeddings': embeddings.tolist(), | |
'type': data.type | |
} | |
} | |
# Hàm Gradio để gọi API FastAPI | |
def call_api(sentences: List[str]): | |
response = requests.post("http://127.0.0.1:8000/retrieval/embeddings", json={"sentences": sentences}) | |
return response.json()["data"] | |
# Tạo giao diện Gradio | |
demo = gr.Interface( | |
fn=call_api, | |
inputs=gr.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."), | |
outputs=gr.JSON(label="Kết quả mã hóa"), | |
title="Mô hình GTE Multilingual", | |
description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual." | |
) | |
# Khởi động server | |
if __name__ == "__main__": | |
import threading | |
# Khởi động FastAPI trong một thread riêng | |
threading.Thread(target=uvicorn.run, args=(app,), kwargs={"host": "0.0.0.0", "port": 8000}).start() | |
# Khởi động Gradio | |
demo.launch() | |