# Author: Lord-Raven
# Commit 0cca822: "Trying to use ONNX model."
import json
import logging

import gradio
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from onnx_transformers import pipeline
from transformers import AutoTokenizer
from transformers import TokenClassificationPipeline
class OnnxTokenClassificationPipeline(TokenClassificationPipeline):
    """Placeholder subclass of ``transformers.TokenClassificationPipeline``.

    It adds no behavior of its own, and nothing in this file instantiates or
    references it.

    NOTE(review): this looks like a leftover from experimenting with ONNX
    pipelines (the app below uses ``onnx_transformers.pipeline`` for
    zero-shot classification instead). Consider removing it.
    """
# --- CORS configuration -----------------------------------------------------
# NOTE(review): `app` is not referenced anywhere else in this file, and the UI
# is started with `gradio_interface.launch()`, which presumably serves Gradio's
# own app. This middleware therefore likely has no effect unless `app` is
# mounted/served elsewhere. Confirm.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    # Support cookies on cross-origin requests from the allowed origin.
    allow_credentials=True,
    # Only this origin may call the service cross-origin.
    # Alternative origin, currently disabled:
    # "https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win"
    allow_origins=["https://jhuhman.com"],
)
# Hugging Face Hub model used for zero-shot (NLI-based) classification.
model_name = "xenova/mobilebert-uncased-mnli"

# Built once at import time, so every request reuses the same pipeline.
# `onnx=True` presumably makes onnx_transformers run the model via ONNX
# Runtime. Confirm this against the onnx_transformers version in use.
classifier = pipeline(
    model=model_name,
    task="zero-shot-classification",
    onnx=True,
)
def zero_shot_classification(data_string):
    """Run zero-shot classification on a JSON-encoded request.

    Args:
        data_string: Text of a JSON object. Required keys:
            ``sequence`` (the text to classify) and ``candidate_labels``.
            Optional keys: ``hypothesis_template`` and ``multi_label``.
            Optional keys that are absent are not forwarded, so the
            classifier's own defaults (if any) apply instead.

    Returns:
        The classifier's result serialized with ``json.dumps``.

    Raises:
        json.JSONDecodeError: If ``data_string`` is not valid JSON.
        KeyError: If ``sequence`` or ``candidate_labels`` is missing.
    """
    logger = logging.getLogger(__name__)
    # Debug level: avoid echoing every (user-supplied) request body to stdout.
    logger.debug("Raw request: %s", data_string)
    data = json.loads(data_string)
    logger.debug("Parsed request: %s", data)

    # Forward optional settings only when supplied, so requests may omit them
    # instead of failing with KeyError.
    options = {
        key: data[key]
        for key in ("hypothesis_template", "multi_label")
        if key in data
    }
    results = classifier(
        data["sequence"],
        candidate_labels=data["candidate_labels"],
        **options,
    )
    return json.dumps(results)
# Minimal web UI: one textbox takes the JSON request, the other shows the
# JSON string returned by zero_shot_classification.
gradio_interface = gradio.Interface(
    fn=zero_shot_classification,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(),
)

# Starts the server as soon as this module is executed (top-level call).
# NOTE(review): this presumably serves Gradio's own app rather than the
# FastAPI `app`, so that app's CORS middleware would not apply here. Confirm.
gradio_interface.launch()