Spaces:
Sleeping
Sleeping
File size: 2,079 Bytes
ebb0dae d13dd09 ebb0dae b0926d9 ebb0dae d13dd09 b0926d9 ebb0dae d13dd09 ebb0dae b0926d9 ebb0dae b0926d9 ebb0dae b0926d9 ebb0dae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from typing import List, Literal
from pydantic import BaseModel, Field
import gradio as gr
from fastapi import FastAPI, APIRouter, Request
from fastapi.middleware.cors import CORSMiddleware
from sentence_transformers import SentenceTransformer
import uvicorn
import requests
# Khởi tạo FastAPI
app = FastAPI()
# Thêm middleware CORS để cho phép yêu cầu từ Gradio
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Cho phép tất cả các nguồn
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Tải mô hình
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base')
# Định nghĩa mô hình dữ liệu cho yêu cầu
class PostEmbeddings(BaseModel):
type: Literal['default', 'disease', 'gte'] = Field(default='default')
sentences: List[str]
# Tạo router cho API
router = APIRouter(
prefix="/retrieval",
tags=["retrieval"],
responses={404: {"description": "Not found"}},
)
@app.post("/retrieval/embeddings")
def post_embeddings(data: PostEmbeddings):
embeddings = model.encode(data.sentences)
return {
'data': {
'embeddings': embeddings.tolist(),
'type': data.type
}
}
# Hàm Gradio để gọi API FastAPI
def call_api(sentences: List[str]):
response = requests.post("http://127.0.0.1:8000/retrieval/embeddings", json={"sentences": sentences})
return response.json()["data"]
# Tạo giao diện Gradio
demo = gr.Interface(
fn=call_api,
inputs=gr.Textbox(lines=5, placeholder="Nhập các câu ở đây, mỗi câu trên một dòng..."),
outputs=gr.JSON(label="Kết quả mã hóa"),
title="Mô hình GTE Multilingual",
description="Nhập các câu để nhận mã hóa từ mô hình GTE Multilingual."
)
# Khởi động server
if __name__ == "__main__":
import threading
# Khởi động FastAPI trong một thread riêng
threading.Thread(target=uvicorn.run, args=(app,), kwargs={"host": "0.0.0.0", "port": 8000}).start()
# Khởi động Gradio
demo.launch()
|