Spaces:
Runtime error
Runtime error
File size: 1,867 Bytes
02d2b5c 46eeeb0 02d2b5c 46eeeb0 d289606 46eeeb0 02d2b5c 46eeeb0 02d2b5c 56c36b4 02d2b5c 56c36b4 02d2b5c 56c36b4 02d2b5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import torch
from transformers import BertForSequenceClassification
import gradio as gr
from transformers import BertTokenizer
import torch
from transformers import BertForSequenceClassification, BertTokenizer
import gradio as gr
# Load your BERT model: the bert-base-uncased architecture with a 2-class
# classification head, then the fine-tuned weights from the local checkpoint.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
try:
    # map_location keeps a checkpoint that was saved on a GPU loadable on a
    # CPU-only host.
    # NOTE(review): on PyTorch versions where torch.load defaults to
    # weights_only=True, a whole pickled model is rejected here; prefer
    # re-saving the checkpoint as a state dict over disabling that check.
    state_dict = torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
    if isinstance(state_dict, torch.nn.Module):
        # The file holds an entire pickled model rather than only its weights.
        state_dict = state_dict.state_dict()
    # strict=False is used only so that a `position_ids` buffer key (stored by
    # some transformers versions and not others) does not abort loading; every
    # other missing/unexpected key is rejected below instead of being ignored.
    incompatible = model.load_state_dict(state_dict, strict=False)
except Exception as e:
    # Fail loudly: continuing would serve predictions from a randomly
    # initialised classification head.
    raise RuntimeError(f"Error loading state dict: {e}") from e
mismatched_keys = [
    key
    for key in (*incompatible.missing_keys, *incompatible.unexpected_keys)
    if not key.endswith('position_ids')
]
if mismatched_keys:
    raise RuntimeError(
        f"Checkpoint does not match the model architecture; mismatched keys: {mismatched_keys}"
    )
model.eval()  # Set the model to evaluation mode (disables dropout)
# Load the tokenizer matching the base checkpoint
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def predict(text):
    """Classify ``text`` with the module-level BERT model.

    The text is tokenized into PyTorch tensors (padded and truncated), run
    through ``model`` without gradient tracking, and the position of the
    largest logit is returned as a Python int. ``argmax`` is taken over the
    flattened logits, so for a single input string this is the class index.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        scores = model(**encoded).logits
    return scores.argmax().item()
# Gradio front end: a single text box feeding `predict`, with the returned
# class index rendered by a Label component.
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)
# Load model and tokenizer. This rebinds any `model`/`tokenizer` created
# earlier in the file; the objects bound here are the ones `predict` uses.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# map_location is required on a CPU-only host: without it, a checkpoint saved
# from a GPU fails to deserialize when CUDA is unavailable.
checkpoint = torch.load('bert_model_complete.pth', map_location=torch.device('cpu'))
if isinstance(checkpoint, torch.nn.Module):
    # The file holds an entire pickled model rather than only its weights.
    checkpoint = checkpoint.state_dict()
# strict=False only so that a `position_ids` buffer key (present in some
# transformers versions' state dicts and absent in others) does not abort
# loading; any other missing or unexpected key is still treated as fatal.
load_result = model.load_state_dict(checkpoint, strict=False)
mismatched_keys = [
    key
    for key in (*load_result.missing_keys, *load_result.unexpected_keys)
    if not key.endswith('position_ids')
]
if mismatched_keys:
    raise RuntimeError(
        f"Checkpoint does not match the model architecture; mismatched keys: {mismatched_keys}"
    )
model.eval()  # inference mode: disables dropout
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Prediction function used by the Gradio interface.
def predict(text):
    """Return the index of the highest-scoring logit for ``text``.

    Tokenizes the input into PyTorch tensors (with padding and truncation),
    runs the module-level ``model`` under ``torch.no_grad()``, and converts
    the argmax of the flattened logits to a plain Python int.
    """
    batch = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        result = model(**batch)
    best = result.logits.argmax()
    return best.item()
# Build the web UI: free-text input, predicted class shown as a label.
interface = gr.Interface(
    title="BERT Text Classification",
    fn=predict,
    inputs="text",
    outputs="label",
)
def main():
    """Launch the Gradio interface built earlier in this module."""
    interface.launch()


# Launch the interface only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|