import gradio
import json
import torch
from transformers import AutoTokenizer
from transformers import pipeline
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from optimum.onnxruntime import ORTModelForSequenceClassification
# CORS Config
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://jhuhman.com"],  # ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
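# Quantized ONNX checkpoint loaded through optimum's ORTModelForSequenceClassification,
# so the zero-shot pipeline below runs on ONNX Runtime rather than a PyTorch backend.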
model_name = "xenova/deberta-v3-base-tasksource-nli" # "xenova/mobilebert-uncased-mnli"
file_name = "onnx/model_quantized.onnx"
tokenizer_name = "sileod/deberta-v3-base-tasksource-nli" # "typeform/mobilebert-uncased-mnli"
model = ORTModelForSequenceClassification.from_pretrained(model_name, file_name=file_name)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, model_max_length=512)
# file = cached_download("https://huggingface.co/" + model_name + "")
# sess = InferenceSession(file)
classifier = pipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer)
def zero_shot_classification(data_string, request: gradio.Request):
    if request:
        print("Request headers dictionary:", request.headers)
        # Block requests that do not come from an allowed front-end origin.
        if request.origin not in ["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win", "https://jhuhman-statosphere-backend.hf.space"]:
            return ""
    print(data_string)
    data = json.loads(data_string)
    print(data)
    results = classifier(data['sequence'],
                         candidate_labels=data['candidate_labels'],
                         hypothesis_template=data['hypothesis_template'],
                         multi_label=data['multi_label'])
    response_string = json.dumps(results)
    return response_string
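# Illustrative payload for the JSON textbox below. The field names come from the
# lookups in zero_shot_classification above; the example values are invented.
#
# example_input = json.dumps({
#     "sequence": "The sky is clear and the sun is out.",
#     "candidate_labels": ["weather", "sports", "politics"],
#     "hypothesis_template": "This text is about {}.",
#     "multi_label": False
# })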
gradio_interface = gradio.Interface(
    fn=zero_shot_classification,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox()
)
gradio_interface.launch()