santhoshraghu commited on
Commit
9c8eb77
·
verified ·
1 Parent(s): 039752f

Create gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +156 -0
gradio_app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # dermbot_gradio_app.py
2
+ import gradio as gr
3
+ from PIL import Image
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import transforms
7
+ from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
8
+ from huggingface_hub import hf_hub_download
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.prompts import PromptTemplate
11
+ from qdrant_client import QdrantClient
12
+ from langchain_community.vectorstores import Qdrant
13
+ from langchain_community.embeddings import HuggingFaceEmbeddings
14
+ from langchain_openai import ChatOpenAI
15
+ import os
16
+ import io
17
+ from fpdf import FPDF
18
+
19
+ # === Constants ===
20
+ multilabel_class_names = [
21
+ "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
22
+ "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
23
+ "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
24
+ "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
25
+ "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
26
+ "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
27
+ "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
28
+ ]
29
+
30
+ multiclass_class_names = [
31
+ "systemic", "hair", "drug_reactions", "uriticaria", "acne", "light",
32
+ "autoimmune", "papulosquamous", "eczema", "skincancer",
33
+ "benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
34
+ ]
35
+
36
+ # === Models ===
37
+ class SkinViT(nn.Module):
38
+ def __init__(self, num_classes):
39
+ super().__init__()
40
+ self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
41
+ in_features = self.model.heads.head.in_features
42
+ self.model.heads.head = nn.Linear(in_features, num_classes)
43
+
44
+ def forward(self, x):
45
+ return self.model(x)
46
+
47
+ class DermNetViT(nn.Module):
48
+ def __init__(self, num_classes):
49
+ super().__init__()
50
+ self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
51
+ in_features = self.model.heads[0].in_features
52
+ self.model.heads = nn.Sequential(
53
+ nn.Linear(in_features, 1024),
54
+ nn.ReLU(),
55
+ nn.Linear(1024, num_classes)
56
+ )
57
+
58
+ def forward(self, x):
59
+ return self.model(x)
60
+
61
+ # === Load Model State Dicts ===
62
+ multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
63
+ multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
64
+
65
+ multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
66
+ multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))
67
+ multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
68
+ multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
69
+ multilabel_model.eval()
70
+ multiclass_model.eval()
71
+
72
+ # === RAG Setup ===
73
+ os.environ["OPENAI_API_KEY"] = "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"
74
+ llm = ChatOpenAI(model="gpt-4o", temperature=0.2)
75
+
76
+ qdrant_client = QdrantClient(
77
+ url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
78
+ api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
79
+ )
80
+
81
+ local_embedding = HuggingFaceEmbeddings(
82
+ model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
83
+ model_kwargs={"trust_remote_code": True, "device": "cpu"}
84
+ )
85
+
86
+ vector_store = Qdrant(
87
+ client=qdrant_client,
88
+ collection_name="ks_collection_1.5BE",
89
+ embeddings=local_embedding
90
+ )
91
+
92
+ retriever = vector_store.as_retriever()
93
+
94
+
95
+ AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
96
+ You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
97
+
98
+ Guidelines:
99
+ 1. Symptoms - Explain in simple terms with proper medical definitions.
100
+ 2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
101
+ 3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
102
+ 4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
103
+ 5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
104
+
105
+ Query: {question}
106
+ Relevant Information: {context}
107
+ Answer:
108
+ """
109
+
110
+
111
+
112
+ prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
113
+
114
+ rag_chain = RetrievalQA.from_chain_type(
115
+ llm=llm,
116
+ retriever=retriever,
117
+ chain_type="stuff",
118
+ chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
119
+ )
120
+
121
+ # === Inference ===
122
+ def run_diagnosis(image):
123
+ transform = transforms.Compose([
124
+ transforms.Resize((224, 224)),
125
+ transforms.ToTensor(),
126
+ transforms.Normalize([0.5], [0.5])
127
+ ])
128
+ input_tensor = transform(image).unsqueeze(0)
129
+ with torch.no_grad():
130
+ probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy()
131
+ predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
132
+ pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()
133
+ predicted_single = multiclass_class_names[pred_idx]
134
+ return predicted_multi, predicted_single
135
+
136
+ # === Chat Function ===
137
+ def chat_with_bot(image, history=[]):
138
+ predicted_multi, predicted_single = run_diagnosis(image)
139
+ query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
140
+ response = rag_chain.invoke(query)["result"]
141
+ history.append((f"User: {query}", f"AI: {response}"))
142
+ return response, history
143
+
144
+ # === Gradio App ===
145
+ with gr.Blocks() as demo:
146
+ gr.Markdown("# 🧬 DermBOT — Skin AI Assistant")
147
+ chatbot = gr.Chatbot()
148
+ img_input = gr.Image(type="pil")
149
+ output_text = gr.Textbox(label="DermBOT Response")
150
+ btn = gr.Button("Analyze & Diagnose")
151
+
152
+ state = gr.State([])
153
+ btn.click(fn=chat_with_bot, inputs=[img_input, state], outputs=[output_text, state])
154
+
155
+ if __name__ == "__main__":
156
+ demo.launch()