# app.py (author: alinikkhah)
# Last commit: d289606 "Update app.py" (verified), 1.7 kB
import gradio as gr
import torch
from transformers import BertForSequenceClassification, BertTokenizer

# Build a two-class BERT classifier on top of 'bert-base-uncased', then
# overwrite its weights with the state dict saved in bert_model_complete.pth.
# map_location places every tensor on the CPU regardless of where it was saved.
# NOTE(review): the model and tokenizer are loaded a second time further down
# this file; only one load should be kept.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
)
model.load_state_dict(
    torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
)
model.eval()  # inference mode: disables dropout

# Tokenizer matching the 'bert-base-uncased' vocabulary the model was built on.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def predict(text):
    """Classify text with the module-level BERT sequence classifier.

    Relies on the module-level ``tokenizer`` and ``model`` objects.

    Args:
        text: A single string (what a Gradio "text" input passes in) or a
            list of strings.

    Returns:
        The predicted class index as an ``int`` for a single string, or a
        list of class indices (one per input string) for a list input.
    """
    # Truncate to the tokenizer's maximum length; padding only matters when
    # several strings are batched together.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(**inputs).logits
    # Argmax over the class axis (not the flattened tensor) so each example's
    # result is a valid class index even when a batch is passed.
    predicted = logits.argmax(dim=-1)
    if isinstance(text, str):
        return predicted.item()
    return predicted.tolist()
# Gradio UI: free-text box in, predicted class shown as a label.
# NOTE(review): this Interface is rebound later in the file and never
# launched; it can be removed once the duplicated script is consolidated.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)
import torch
from transformers import BertForSequenceClassification, BertTokenizer
import gradio as gr

# Load model and tokenizer.
# num_labels=2 is passed explicitly so the classification head's shape does
# not silently depend on the library's default config.
# map_location moves every tensor onto the CPU so a checkpoint saved from a
# GPU session still loads on a CPU-only host.
# NOTE(review): the same checkpoint is also loaded earlier in this file;
# keeping a single load would cut startup time and peak memory.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.load_state_dict(
    torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
)
model.eval()  # inference mode: disables dropout
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Define prediction function
def predict(text):
    """Classify text with the module-level BERT sequence classifier.

    Relies on the module-level ``tokenizer`` and ``model`` objects.

    Args:
        text: A single string (what a Gradio "text" input passes in) or a
            list of strings.

    Returns:
        The predicted class index as an ``int`` for a single string, or a
        list of class indices (one per input string) for a list input.
    """
    # Truncate to the tokenizer's maximum length; padding only matters when
    # several strings are batched together.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(**inputs).logits
    # Argmax over the class axis (not the flattened tensor) so each example's
    # result is a valid class index even when a batch is passed.
    predicted = logits.argmax(dim=-1)
    if isinstance(text, str):
        return predicted.item()
    return predicted.tolist()
# Wire the prediction function into a Gradio UI: a text box in, the
# predicted class rendered as a label.
interface = gr.Interface(
    title="BERT Text Classification",
    fn=predict,
    inputs="text",
    outputs="label",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()