import torch
import gradio as gr
from transformers import BertForSequenceClassification, BertTokenizer

MODEL_NAME = "bert-base-uncased"
CHECKPOINT_PATH = "bert_model_complete.pth"

# Load the tokenizer and the model architecture (BERT base with a 2-class head)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Load the fine-tuned weights. This assumes the checkpoint is a state dict
# saved with torch.save(model.state_dict(), ...). map_location lets a
# GPU-trained checkpoint load on a CPU-only machine.
state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"))

# strict=False tolerates harmless extra keys (e.g. buffers that older
# transformers versions saved), but missing weights must not be ignored:
# otherwise the classifier head stays randomly initialized and the app
# serves meaningless predictions.
result = model.load_state_dict(state_dict, strict=False)
if result.missing_keys:
    raise RuntimeError(f"Checkpoint is missing weights: {result.missing_keys}")
if result.unexpected_keys:
    print(f"Ignoring unexpected keys in checkpoint: {result.unexpected_keys}")

model.eval()  # Disable dropout for inference


def predict(text):
    """Classify `text` and return a {label: probability} dict for gr.Label."""
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Softmax over the class dimension of the single input in the batch
    probs = torch.softmax(logits, dim=-1)[0]
    # id2label defaults to LABEL_0 / LABEL_1 unless the config defines names
    return {model.config.id2label[i]: float(p) for i, p in enumerate(probs)}


# Set up the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="label",
    title="BERT Text Classification",
)

# Launch the interface
if __name__ == "__main__":
    interface.launch()