import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
from PIL import Image
import gradio as gr

# Hugging Face model ids for the two encoders. The tokenizer and image
# processor below are loaded from the same ids so preprocessing matches them.
VISION_CHECKPOINT = "google/vit-base-patch16-224-in21k"
TEXT_CHECKPOINT = "bert-base-uncased"
MODEL_WEIGHTS_PATH = "best_model.pth"
MAX_TEXT_LENGTH = 256
# Index i of the classifier output is reported as CLASS_LABELS[i].
CLASS_LABELS = ("Benign", "Malignant")


class VisionLanguageModel(nn.Module):
    """Late-fusion classifier over a ViT image encoder and a BERT text encoder.

    The pooled outputs of both encoders are concatenated and passed through a
    single linear layer that yields one logit per class.

    Args:
        num_classes: Number of output logits. Defaults to 2 (benign vs.
            malignant), the configuration ``best_model.pth`` is loaded into.
        vision_checkpoint: Hugging Face model id for the ViT encoder.
        text_checkpoint: Hugging Face model id for the BERT encoder.
    """

    def __init__(
        self,
        num_classes=2,
        vision_checkpoint=VISION_CHECKPOINT,
        text_checkpoint=TEXT_CHECKPOINT,
    ):
        super().__init__()
        # Attribute names form the state-dict key prefixes; keep them stable so
        # saved checkpoints keep loading.
        self.vision_model = ViTModel.from_pretrained(vision_checkpoint)
        self.language_model = BertModel.from_pretrained(text_checkpoint)
        self.classifier = nn.Linear(
            self.vision_model.config.hidden_size
            + self.language_model.config.hidden_size,
            num_classes,
        )

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return class logits for a batch of (text, image) pairs.

        Args:
            input_ids: Token ids produced by the BERT tokenizer.
            attention_mask: Attention mask matching ``input_ids``.
            pixel_values: Images preprocessed by the ViT image processor.

        Returns:
            Logits from the linear head. Presumably shaped
            (batch, num_classes), since the pooled outputs are concatenated
            along dim 1 before the head.
        """
        vision_pooled = self.vision_model(pixel_values=pixel_values).pooler_output
        text_pooled = self.language_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output
        combined_features = torch.cat((vision_pooled, text_pooled), dim=1)
        return self.classifier(combined_features)


# Build the architecture, then overwrite its parameters from the fine-tuned
# checkpoint. weights_only=True limits torch.load to tensors and plain
# containers instead of unpickling arbitrary objects.
model = VisionLanguageModel()
model.load_state_dict(
    torch.load(
        MODEL_WEIGHTS_PATH, map_location=torch.device("cpu"), weights_only=True
    )
)
model.eval()  # inference only (e.g. disables dropout)

tokenizer = BertTokenizerFast.from_pretrained(TEXT_CHECKPOINT)
feature_extractor = ViTImageProcessor.from_pretrained(VISION_CHECKPOINT)


def predict(image, text_input):
    """Classify a skin-lesion image together with free-text clinical notes.

    Args:
        image: Image from the Gradio ``Image(type="pil")`` input, or ``None``
            when the form is submitted without an upload.
        text_input: Clinical notes; ``None`` is treated as an empty string.

    Returns:
        ``"Malignant"`` or ``"Benign"``.

    Raises:
        gr.Error: If no image was provided, so Gradio shows a readable message
            instead of a processor stack trace.
    """
    if image is None:
        raise gr.Error("Please upload a skin lesion image.")
    if isinstance(image, Image.Image):
        # Uploads may be RGBA (PNG) or greyscale; the ViT processor expects
        # 3-channel RGB input.
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    encoding = tokenizer(
        text_input or "",
        add_special_tokens=True,
        max_length=MAX_TEXT_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            pixel_values=pixel_values,
        )
    predicted_index = logits.argmax(dim=1).item()
    return CLASS_LABELS[predicted_index]

# Gradio UI: an image upload plus a free-text box, fed positionally into
# predict(image, text_input); the returned label is shown as plain text.
lesion_image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
clinical_notes_input = gr.Textbox(
    label="Clinical Information (e.g., patient age, symptoms)"
)

iface = gr.Interface(
    fn=predict,
    inputs=[lesion_image_input, clinical_notes_input],
    outputs="text",
    title="Skin Lesion Classification Demo",
    description=(
        "This model classifies skin lesions as benign or malignant "
        "based on an image and clinical information."
    ),
)

iface.launch()