santhoshraghu commited on
Commit
6e6ac11
·
verified ·
1 Parent(s): 8a4905f

Upload dermo.py

Browse files

Streamlit UI of DermBOT V-3 ( Pre-trained Vit - SKINCON/Dermnet & RAG v_3( Qwen & Qdrant)

Files changed (1) hide show
  1. dermo.py +223 -0
dermo.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import transforms
6
+ from torchvision.models import vit_b_16, ViT_B_16_Weights
7
+ import pandas as pd
8
+ import io
9
+ import os
10
+ import base64
11
+ from fpdf import FPDF
12
+ from sqlalchemy import create_engine
13
+ from langchain.chains import RetrievalQA
14
+ from langchain.prompts import PromptTemplate
15
+ from qdrant_client import QdrantClient
16
+ from qdrant_client.http.models import Distance, VectorParams
17
+ from sentence_transformers import SentenceTransformer
18
+ from langchain_community.vectorstores.pgvector import PGVector
19
+ from langchain_postgres import PGVector
20
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
21
+ from langchain_community.vectorstores import Qdrant
22
+ from langchain_community.embeddings import HuggingFaceEmbeddings
23
+ from langchain_community.embeddings import SentenceTransformerEmbeddings
24
+
25
+
26
+
27
+ import nest_asyncio
28
+ nest_asyncio.apply()
29
+
30
+ st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
31
+
32
+ #os.environ["PGVECTOR_CONNECTION_STRING"] = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
33
+
34
+ # === Model Selection ===
35
+ available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
36
+ st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
37
+
38
+
39
+ # === Qdrant DB Setup ===
40
+ qdrant_client = QdrantClient(
41
+ url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
42
+ api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q"
43
+ )
44
+ collection_name = "ks_collection_1.5BE"
45
+ #embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
46
+ #embedding_model.max_seq_length = 8192
47
+ #local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
48
+
49
+
50
+ local_embedding = HuggingFaceEmbeddings(
51
+ model_name="D:/DR/RAG/gte-Qwen2-1.5B-instruct",
52
+ model_kwargs={"trust_remote_code": True, "device":"cpu"}
53
+ )
54
+ print(" Qwen2-1.5B local embedding model loaded.")
55
+
56
+
57
+ vector_store = Qdrant(
58
+ client=qdrant_client,
59
+ collection_name=collection_name,
60
+ embeddings=local_embedding
61
+ )
62
+ retriever = vector_store.as_retriever()
63
+
64
+ '''
65
+ # === Init LLM and Vector DB ===
66
+
67
+ CONNECTION_STRING = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
68
+ engine = create_engine(CONNECTION_STRING)
69
+ embedding_model = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
70
+ '''
71
+ # Dynamically initialize LLM based on selection
72
+ OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
73
+ selected_model = st.session_state["selected_model"]
74
+ if "OpenAI" in selected_model:
75
+ llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=OPENAI_API_KEY)
76
+ elif "LLaMA" in selected_model:
77
+ st.warning("LLaMA integration is not implemented yet.")
78
+ st.stop()
79
+ elif "Gemini" in selected_model:
80
+ st.warning("Gemini integration is not implemented yet.")
81
+ st.stop()
82
+ else:
83
+ st.error("Unsupported model selected.")
84
+ st.stop()
85
+
86
+ '''
87
+ vector_store = PGVector.from_existing_index(
88
+ embedding=embedding_model,
89
+ connection=engine,
90
+ collection_name="documents"
91
+ )
92
+ '''
93
+ #retriever = vector_store.as_retriever()
94
+
95
+ AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
96
+ You provide accurate, compassionate, and detailed explanations while using correct medical terminology.
97
+
98
+ Guidelines:
99
+ 1. Symptoms - Explain in simple terms with proper medical definitions.
100
+ 2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
101
+ 3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
102
+ 4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
103
+ 5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.
104
+
105
+ Query: {question}
106
+ Relevant Information: {context}
107
+ Answer:
108
+ """
109
+ prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
110
+
111
+ rag_chain = RetrievalQA.from_chain_type(
112
+ llm=llm,
113
+ retriever=retriever,
114
+ chain_type="stuff",
115
+ chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
116
+ )
117
+
118
+ # === Class Names ===
119
+ multilabel_class_names = [
120
+ "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
121
+ "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
122
+ "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
123
+ "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
124
+ "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
125
+ "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
126
+ "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
127
+ ]
128
+
129
+ multiclass_class_names = [
130
+ "systemic", "hair", "drug_reactions", "uriticaria", "acne", "light",
131
+ "autoimmune", "papulosquamous", "eczema", "skincancer",
132
+ "benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
133
+ ]
134
+
135
+ # === Load Models ===
136
+ class SkinViT(nn.Module):
137
+ def __init__(self, num_classes):
138
+ super(SkinViT, self).__init__()
139
+ self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
140
+ in_features = self.model.heads[0].in_features
141
+ self.model.heads[0] = nn.Linear(in_features, num_classes)
142
+ def forward(self, x):
143
+ return self.model(x)
144
+
145
+ multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
146
+ multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
147
+ multilabel_model.eval()
148
+ multiclass_model.eval()
149
+
150
+ # === Session Init ===
151
+ if "messages" not in st.session_state:
152
+ st.session_state.messages = []
153
+
154
+ # === Image Processing Function ===
155
+ def run_inference(image):
156
+ transform = transforms.Compose([
157
+ transforms.Resize((224, 224)),
158
+ transforms.ToTensor(),
159
+ transforms.Normalize([0.5], [0.5])
160
+ ])
161
+ input_tensor = transform(image).unsqueeze(0)
162
+ with torch.no_grad():
163
+ probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy()
164
+ predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
165
+ pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()
166
+ predicted_single = multiclass_class_names[pred_idx]
167
+ return predicted_multi, predicted_single
168
+
169
+ # === PDF Export ===
170
+ def export_chat_to_pdf(messages):
171
+ pdf = FPDF()
172
+ pdf.add_page()
173
+ pdf.set_font("Arial", size=12)
174
+ for msg in messages:
175
+ role = "You" if msg["role"] == "user" else "AI"
176
+ pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n")
177
+ buf = io.BytesIO()
178
+ pdf.output(buf)
179
+ buf.seek(0)
180
+ return buf
181
+
182
+ # === App UI ===
183
+
184
+ st.title("🧬 DermBOT — Skin AI Assistant")
185
+ st.caption(f"🧠 Using model: {selected_model}")
186
+ uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
187
+
188
+ if uploaded_file:
189
+ st.image(uploaded_file, caption="Uploaded image", use_column_width=True)
190
+ image = Image.open(uploaded_file).convert("RGB")
191
+
192
+
193
+ predicted_multi, predicted_single = run_inference(image)
194
+
195
+ # Show predictions clearly to the user
196
+ st.markdown(f" Skin Issues : {', '.join(predicted_multi)}")
197
+ st.markdown(f" Most Likely Diagnosis : {predicted_single}")
198
+
199
+ query = f"What are my treatment options for {predicted_multi} and {predicted_single}?"
200
+ st.session_state.messages.append({"role": "user", "content": query})
201
+
202
+ with st.spinner("Analyzing the image and retrieving response..."):
203
+ response = rag_chain.invoke(query)
204
+ st.session_state.messages.append({"role": "assistant", "content": response['result']})
205
+
206
+ with st.chat_message("assistant"):
207
+ st.markdown(response['result'])
208
+
209
+ # === Chat Interface ===
210
+ if prompt := st.chat_input("Ask a follow-up..."):
211
+ st.session_state.messages.append({"role": "user", "content": prompt})
212
+ with st.chat_message("user"):
213
+ st.markdown(prompt)
214
+
215
+ response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages])
216
+ st.session_state.messages.append({"role": "assistant", "content": response.content})
217
+ with st.chat_message("assistant"):
218
+ st.markdown(response.content)
219
+
220
+ # === PDF Button ===
221
+ if st.button("📄 Download Chat as PDF"):
222
+ pdf_file = export_chat_to_pdf(st.session_state.messages)
223
+ st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")