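"""Gradio demo for skin lesion classification (benign vs. malignant).

A ViT image encoder and a BERT text encoder are fused by concatenating their
pooled outputs and passing them through a linear classifier.
"""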
import torch
import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel
import gradio as gr

# Model definition
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.language_model = BertModel.from_pretrained('bert-base-uncased')
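        # Fusion head: one linear layer over the concatenated pooled outputs
        # of both encoders (768 + 768 features for these checkpoints)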
        self.classifier = nn.Linear(
            self.vision_model.config.hidden_size + self.language_model.config.hidden_size,
            2  # Number of classes: benign or malignant
        )

    def forward(self, input_ids, attention_mask, pixel_values):
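        # Vision branch: pooler_output is ViT's pooled [CLS] representation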
        vision_outputs = self.vision_model(pixel_values=pixel_values)
        vision_pooled_output = vision_outputs.pooler_output

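        # Text branch: BERT over the tokenized clinical notes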
        language_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        language_pooled_output = language_outputs.pooler_output

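        # Late fusion: concatenate the two pooled vectors along the feature dim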
        combined_features = torch.cat(
            (vision_pooled_output, language_pooled_output),
            dim=1
        )

        logits = self.classifier(combined_features)
        return logits

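# Instantiate the model and load fine-tuned weights onto CPU;
# 'best_model.pth' is assumed to be a checkpoint saved from a prior training run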
model = VisionLanguageModel()
model.load_state_dict(torch.load('best_model.pth', map_location=torch.device('cpu'), weights_only=True))
model.eval()

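# Preprocessors matched to the two pretrained backbones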
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

def predict(image, text_input):
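    """Return 1 (malignant) or 0 (benign) for an image plus clinical notes."""
    # Guard against an empty upload so the processors never receive None
    if image is None:
        raise gr.Error("Please upload a skin lesion image first.")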
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    encoding = tokenizer(
        text_input,
        add_special_tokens=True,
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
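    # Inference only: no gradient tracking needed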
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values
        )
    _, prediction = torch.max(outputs, dim=1)
    return prediction.item()  # 1 for Malignant, 0 for Benign

# Enhanced UI with background image and color-coded prediction display
with gr.Blocks(css="""
    body { 
        background: url('./skin_cancer_detection/melanoma.png') no-repeat center center fixed; 
        background-size: cover;
    }
    .benign, .malignant { 
        background-color: white; 
        border: 1px solid lightgray; 
        padding: 10px; 
        border-radius: 5px; 
    }
    .benign.predicted, .malignant.predicted { 
        background-color: lightgreen; 
    }
""") as demo:
    gr.Markdown(
        """
        # 🩺 SKIN LESION CLASSIFICATION
        Upload an image of a skin lesion and provide clinical details to receive a benign or malignant prediction.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Skin Lesion Image")
            text_input = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")

        with gr.Column(scale=1):
            benign_output = gr.HTML("<div class='benign'>Benign</div>")
            malignant_output = gr.HTML("<div class='malignant'>Malignant</div>")
            gr.Markdown("## Example:")
            example_image = gr.Image(value="./skin_cancer_detection/Unknown-4.png", interactive=False)

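    # Wrap predict() and highlight whichever label the model selected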
    def display_prediction(image, text_input):
        prediction = predict(image, text_input)
        benign_html = "<div class='benign{}'>Benign</div>".format(" predicted" if prediction == 0 else "")
        malignant_html = "<div class='malignant{}'>Malignant</div>".format(" predicted" if prediction == 1 else "")
        return benign_html, malignant_html

    # Submit button and prediction outputs
    submit_btn = gr.Button("Get Prediction")
    submit_btn.click(display_prediction, inputs=[image_input, text_input], outputs=[benign_output, malignant_output])

demo.launch()