Create gradio_app.py
Browse files- gradio_app.py +156 -0
gradio_app.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# dermbot_gradio_app.py
|
2 |
+
import gradio as gr
|
3 |
+
from PIL import Image
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torchvision import transforms
|
7 |
+
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
from langchain.chains import RetrievalQA
|
10 |
+
from langchain.prompts import PromptTemplate
|
11 |
+
from qdrant_client import QdrantClient
|
12 |
+
from langchain_community.vectorstores import Qdrant
|
13 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
14 |
+
from langchain_openai import ChatOpenAI
|
15 |
+
import os
|
16 |
+
import io
|
17 |
+
from fpdf import FPDF
|
18 |
+
|
19 |
+
# === Constants ===
|
20 |
+
multilabel_class_names = [
|
21 |
+
"Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
|
22 |
+
"Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
|
23 |
+
"Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
|
24 |
+
"Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
|
25 |
+
"Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
|
26 |
+
"Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
|
27 |
+
"Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
|
28 |
+
]
|
29 |
+
|
30 |
+
multiclass_class_names = [
|
31 |
+
"systemic", "hair", "drug_reactions", "uriticaria", "acne", "light",
|
32 |
+
"autoimmune", "papulosquamous", "eczema", "skincancer",
|
33 |
+
"benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
|
34 |
+
]
|
35 |
+
|
36 |
+
# === Models ===
|
37 |
+
class SkinViT(nn.Module):
|
38 |
+
def __init__(self, num_classes):
|
39 |
+
super().__init__()
|
40 |
+
self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
|
41 |
+
in_features = self.model.heads.head.in_features
|
42 |
+
self.model.heads.head = nn.Linear(in_features, num_classes)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
return self.model(x)
|
46 |
+
|
47 |
+
class DermNetViT(nn.Module):
|
48 |
+
def __init__(self, num_classes):
|
49 |
+
super().__init__()
|
50 |
+
self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
|
51 |
+
in_features = self.model.heads[0].in_features
|
52 |
+
self.model.heads = nn.Sequential(
|
53 |
+
nn.Linear(in_features, 1024),
|
54 |
+
nn.ReLU(),
|
55 |
+
nn.Linear(1024, num_classes)
|
56 |
+
)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
return self.model(x)
|
60 |
+
|
61 |
+
# === Load Model State Dicts ===
|
62 |
+
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
|
63 |
+
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
|
64 |
+
|
65 |
+
multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
|
66 |
+
multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
|
67 |
+
multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
|
68 |
+
multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
|
69 |
+
multilabel_model.eval()
|
70 |
+
multiclass_model.eval()
|
71 |
+
|
72 |
+
# === RAG Setup ===
|
73 |
+
os.environ["OPENAI_API_KEY"] = "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"
|
74 |
+
llm = ChatOpenAI(model="gpt-4o", temperature=0.2)
|
75 |
+
|
76 |
+
qdrant_client = QdrantClient(
|
77 |
+
url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
|
78 |
+
api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
|
79 |
+
)
|
80 |
+
|
81 |
+
local_embedding = HuggingFaceEmbeddings(
|
82 |
+
model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
|
83 |
+
model_kwargs={"trust_remote_code": True, "device": "cpu"}
|
84 |
+
)
|
85 |
+
|
86 |
+
vector_store = Qdrant(
|
87 |
+
client=qdrant_client,
|
88 |
+
collection_name="ks_collection_1.5BE",
|
89 |
+
embeddings=local_embedding
|
90 |
+
)
|
91 |
+
|
92 |
+
retriever = vector_store.as_retriever()
|
93 |
+
|
94 |
+
|
95 |
+
AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
|
96 |
+
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
|
97 |
+
|
98 |
+
Guidelines:
|
99 |
+
1. Symptoms - Explain in simple terms with proper medical definitions.
|
100 |
+
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
|
101 |
+
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
|
102 |
+
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
|
103 |
+
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
|
104 |
+
|
105 |
+
Query: {question}
|
106 |
+
Relevant Information: {context}
|
107 |
+
Answer:
|
108 |
+
"""
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
|
113 |
+
|
114 |
+
rag_chain = RetrievalQA.from_chain_type(
|
115 |
+
llm=llm,
|
116 |
+
retriever=retriever,
|
117 |
+
chain_type="stuff",
|
118 |
+
chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
|
119 |
+
)
|
120 |
+
|
121 |
+
# === Inference ===
|
122 |
+
def run_diagnosis(image):
|
123 |
+
transform = transforms.Compose([
|
124 |
+
transforms.Resize((224, 224)),
|
125 |
+
transforms.ToTensor(),
|
126 |
+
transforms.Normalize([0.5], [0.5])
|
127 |
+
])
|
128 |
+
input_tensor = transform(image).unsqueeze(0)
|
129 |
+
with torch.no_grad():
|
130 |
+
probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy()
|
131 |
+
predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
|
132 |
+
pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()
|
133 |
+
predicted_single = multiclass_class_names[pred_idx]
|
134 |
+
return predicted_multi, predicted_single
|
135 |
+
|
136 |
+
# === Chat Function ===
|
137 |
+
def chat_with_bot(image, history=[]):
|
138 |
+
predicted_multi, predicted_single = run_diagnosis(image)
|
139 |
+
query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
|
140 |
+
response = rag_chain.invoke(query)["result"]
|
141 |
+
history.append((f"User: {query}", f"AI: {response}"))
|
142 |
+
return response, history
|
143 |
+
|
144 |
+
# === Gradio App ===
|
145 |
+
with gr.Blocks() as demo:
|
146 |
+
gr.Markdown("# 🧬 DermBOT — Skin AI Assistant")
|
147 |
+
chatbot = gr.Chatbot()
|
148 |
+
img_input = gr.Image(type="pil")
|
149 |
+
output_text = gr.Textbox(label="DermBOT Response")
|
150 |
+
btn = gr.Button("Analyze & Diagnose")
|
151 |
+
|
152 |
+
state = gr.State([])
|
153 |
+
btn.click(fn=chat_with_bot, inputs=[img_input, state], outputs=[output_text, state])
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
demo.launch()
|