import json

import gradio
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import (
    InferenceSession, SessionOptions, GraphOptimizationLevel
)
from transformers import (
    TokenClassificationPipeline, AutoTokenizer, AutoModelForTokenClassification
)

class OnnxTokenClassificationPipeline(TokenClassificationPipeline):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _forward(self, model_inputs):
        """
        Forward pass through the model. This method is not to be called by the user directly and is only used
        by the pipeline to perform the actual predictions.
        This is where we will define the actual process to do inference with the ONNX model and the session created
        before.
        """

        # This comes from the original implementation of the pipeline
        special_tokens_mask = model_inputs.pop("special_tokens_mask")
        offset_mapping = model_inputs.pop("offset_mapping", None)
        sentence = model_inputs.pop("sentence")

        inputs = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()} # dict of numpy arrays
        outputs_name = session.get_outputs()[0].name # get the name of the output tensor

        logits = session.run(output_names=[outputs_name], input_feed=inputs)[0] # run the session
        logits = torch.tensor(logits) # convert to torch tensor to be compatible with the original implementation

        return {
            "logits": logits,
            "special_tokens_mask": special_tokens_mask,
            "offset_mapping": offset_mapping,
            "sentence": sentence,
            **model_inputs,
        }

    # We need to override the preprocess method because the ONNX model expects the attention mask
    # as an input alongside the token IDs.
    def preprocess(self, sentence, offset_mapping=None):
        truncation = bool(self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0)
        model_inputs = self.tokenizer(
            sentence,
            return_attention_mask=True, # This is the only difference from the original implementation
            return_tensors=self.framework,
            truncation=truncation,
            return_special_tokens_mask=True,
            return_offsets_mapping=self.tokenizer.is_fast,
        )
        if offset_mapping:
            model_inputs["offset_mapping"] = offset_mapping

        model_inputs["sentence"] = sentence

        return model_inputs

# CORS Config
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://jhuhman.com"], #["https://statosphere-3704059fdd7e.c5v4v4jx6pq5.win"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

session = InferenceSession("onnx/model.onnx", sess_options=options, providers=["CPUExecutionProvider"])

session.disable_fallback()
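# The session above loads onnx/model.onnx with all graph optimizations enabled, restricts
# execution to the CPU provider, and disables ONNX Runtime's run-time fallback to other providers.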

model_name = "xenova/mobilebert-uncased-mnli"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

classifier = OnnxTokenClassificationPipeline(task="zero-shot-classification", model=model, tokenizer=tokenizer, framework="pt", aggregation_strategy="simple")

def zero_shot_classification(data_string):
    print(data_string)
    data = json.loads(data_string)
    print(data)
    results = classifier(data['sequence'], candidate_labels=data['candidate_labels'], hypothesis_template=data['hypothesis_template'], multi_label=data['multi_label'])
    response_string = json.dumps(results)
    return response_string
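
# Example payload for zero_shot_classification (illustrative values; the field names are the ones
# read out of the JSON above):
#   {"sequence": "I love this movie.",
#    "candidate_labels": ["positive", "negative"],
#    "hypothesis_template": "This text is {}.",
#    "multi_label": false}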

gradio_interface = gradio.Interface(
    fn=zero_shot_classification,
    inputs=gradio.Textbox(label="JSON Input"),
    outputs=gradio.Textbox(),
)
gradio_interface.launch()
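
# A minimal client-side sketch, kept commented out so it does not run inside the app itself.
# It assumes the gradio_client package is available and that the URL below points at the
# running Space; both are illustrative, not part of this app:
#
#   from gradio_client import Client
#
#   client = Client("https://<your-space-url>")  # hypothetical URL of the deployed Space
#   payload = json.dumps({
#       "sequence": "I love this movie.",
#       "candidate_labels": ["positive", "negative"],
#       "hypothesis_template": "This text is {}.",
#       "multi_label": False,
#   })
#   print(client.predict(payload, api_name="/predict"))  # JSON string from zero_shot_classification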