Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import streamlit | |
| import torch | |
| from transformers import AutoModelForSequenceClassification,AutoTokenizer | |
| import numpy as np | |
| import plotly.express as px | |
# Local checkpoint directory holding both the classifier weights and the
# matching tokenizer files; loaded once at import time and reused as the
# default model/tokenizer for zero_shot_classification.
_CHECKPOINT_DIR = 'zero_shot_clf/'
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT_DIR)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_DIR)
def zero_shot_classification(premise: str, labels: str, model=model, tokenizer=tokenizer):
    """Score a premise against comma-separated labels and plot the result.

    For each label, the hypothesis ``"this is an example of <label>"`` is paired
    with the premise and passed through ``model``. A softmax over two of the
    output logits (indices 0 and 2) yields one probability per label. Those
    probabilities are rescaled to percentages that sum to ~100 and drawn as a
    bar chart.

    Args:
        premise: Text to classify; lower-cased before encoding.
        labels: Comma-separated candidate labels, e.g. ``"science, sports, museum"``.
            Each label is whitespace-stripped and lower-cased; empty entries
            (from stray commas) are dropped.
        model: Sequence-classification model whose output supports ``['logits']``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        A plotly bar figure of label vs. normalized probability (percent,
        rounded to one decimal).

    Raises:
        ValueError: if ``labels`` is not a string or yields fewer than two labels.
    """
    error_msg = "please pass at least 2 comma-separated labels to classify"
    try:
        raw_labels = labels.split(',')
    except AttributeError as err:  # labels is not a string (e.g. None)
        raise ValueError(error_msg) from err
    # Strip so "a, b" gives "b" rather than " b" (which would put a double
    # space into the hypothesis), and drop empty entries.
    labels = [label.strip().lower() for label in raw_labels if label.strip()]
    if len(labels) < 2:
        # With a single label the normalized probability is always 100%.
        raise ValueError(error_msg)

    premise = premise.lower()
    labels_prob = []
    # Inference only: no need to build autograd graphs.
    with torch.no_grad():
        for label in labels:
            hypothesis = f'this is an example of {label}'
            input_ids = tokenizer.encode(
                premise,
                hypothesis,
                return_tensors='pt',
                # Truncate only the premise; the short hypothesis is kept intact.
                truncation='only_first',
            )
            logits = model(input_ids)['logits']
            # NOTE(review): assumes logit 0 = contradiction and logit 2 =
            # entailment (common MNLI head layout) — confirm against
            # model.config.id2label for this checkpoint.
            entail_contra_prob = logits[:, [0, 2]].softmax(dim=1)[:, 1].item()
            labels_prob.append(entail_contra_prob)

    total = np.sum(labels_prob)
    # Guard against an all-zero total (possible only via float underflow).
    labels_prob_norm = [
        np.round(100 * p / total, 1) if total > 0 else 0.0 for p in labels_prob
    ]

    df = pd.DataFrame({'labels': labels, 'Probability': labels_prob_norm})
    fig = px.bar(
        data_frame=df,
        x='Probability',
        y='labels',
        text='Probability',
        title='Zero Shot Normalized Probabilities',
    )
    return fig
| # zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable', | |
| # labels='science, sports, museum') | |