import json

import gradio
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import (
    GraphOptimizationLevel, InferenceSession, SessionOptions
)
from transformers import (
    AutoModelForSequenceClassification, AutoTokenizer, ZeroShotClassificationPipeline
)

class OnnxZeroShotClassificationPipeline(ZeroShotClassificationPipeline):
    """
    A zero-shot classification pipeline that runs its forward pass through an
    ONNX Runtime session instead of the wrapped PyTorch model. The parent
    class still handles preprocessing (building one premise/hypothesis pair
    per candidate label, with attention masks) and postprocessing (turning
    entailment logits into per-label scores), so only _forward is overridden.
    """

    def _forward(self, inputs):
        """
        Forward pass through the model. This method is not to be called by the
        user directly; the pipeline uses it to perform the actual predictions.
        This is where we run inference with the ONNX model and the session
        created below.
        """
        candidate_label = inputs["candidate_label"]
        sequence = inputs["sequence"]
        # Feed the session only the tensors the ONNX graph expects, converted
        # to numpy (the tokenizer returns torch tensors since framework="pt").
        model_inputs = {
            k: inputs[k].cpu().detach().numpy()
            for k in self.tokenizer.model_input_names
        }
        output_name = session.get_outputs()[0].name  # name of the logits tensor
        logits = session.run([output_name], model_inputs)[0]
        return {
            "candidate_label": candidate_label,
            "sequence": sequence,
            "is_last": inputs["is_last"],  # chunk bookkeeping used by the parent class
            # Convert back to a torch tensor so postprocess() works unchanged.
            "logits": torch.from_numpy(logits),
        }

# CORS Config
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["https://jhuhman.com"], #["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Build an optimized CPU inference session for the exported ONNX model.
options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
session = InferenceSession("onnx/model.onnx", sess_options=options, providers=["CPUExecutionProvider"])
session.disable_fallback()  # fail loudly instead of silently switching providers
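
# Sanity check: the tensor names baked into the exported graph must line up
# with what the tokenizer produces (names can vary by export tool), so log
# them once at startup.
print("ONNX inputs:", [i.name for i in session.get_inputs()])
print("ONNX outputs:", [o.name for o in session.get_outputs()])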

model_name = "xenova/mobilebert-uncased-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The MNLI checkpoint is a sequence classification model; the pipeline mainly
# needs its config (label2id etc.), since inference goes through the session.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
classifier = OnnxZeroShotClassificationPipeline(
    task="zero-shot-classification", model=model, tokenizer=tokenizer, framework="pt"
)
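
# Quick smoke test of the pipeline API (illustrative inputs; commented out so
# the Space does not pay the extra startup cost):
# print(classifier(
#     "I had a great day at the beach.",
#     candidate_labels=["travel", "cooking", "finance"],
#     hypothesis_template="This example is about {}.",
#     multi_label=False,
# ))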

def zero_shot_classification(data_string):
    """Run zero-shot classification on a JSON-encoded request string."""
    print(data_string)  # log the raw request for debugging
    data = json.loads(data_string)
    results = classifier(
        data['sequence'],
        candidate_labels=data['candidate_labels'],
        hypothesis_template=data['hypothesis_template'],
        multi_label=data['multi_label'],
    )
    return json.dumps(results)
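
# Example request this function expects (illustrative values only):
# {
#     "sequence": "I love hiking in the mountains.",
#     "candidate_labels": ["travel", "cooking", "sports"],
#     "hypothesis_template": "This example is about {}.",
#     "multi_label": false
# }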

gradio_interface = gradio.Interface(
    fn=zero_shot_classification,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(label="JSON Output"),
)
# Mount the Gradio UI on the FastAPI app so its routes pass through the CORS
# middleware configured above, then serve the combined app.
app = gradio.mount_gradio_app(app, gradio_interface, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)