|
import streamlit as st
|
|
from PIL import Image
|
|
import torch
|
|
import torch.nn as nn
|
|
from torchvision import transforms
|
|
from torchvision.models import vit_b_16, ViT_B_16_Weights
|
|
import pandas as pd
|
|
import io
|
|
import os
|
|
import base64
|
|
from fpdf import FPDF
|
|
from sqlalchemy import create_engine
|
|
from langchain.chains import RetrievalQA
|
|
from langchain.prompts import PromptTemplate
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http.models import Distance, VectorParams
|
|
from sentence_transformers import SentenceTransformer
|
|
from langchain_community.vectorstores.pgvector import PGVector
|
|
from langchain_postgres import PGVector
|
|
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
|
from langchain_community.vectorstores import Qdrant
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings
|
|
from langchain_community.embeddings import SentenceTransformerEmbeddings
|
|
|
|
|
|
|
|
import nest_asyncio
|
|
nest_asyncio.apply()
|
|
|
|
# Browser-tab title/icon and page layout for the app.
st.set_page_config(
    page_title="DermBOT",
    page_icon="🧬",
    layout="centered",
)

# LLM back-ends offered in the sidebar. The chosen label is written to session
# state, where the LLM set-up code reads it back under "selected_model".
available_models = [
    "OpenAI GPT-4o",
    "LLaMA 3",
    "Gemini Pro",
]
_model_choice = st.sidebar.selectbox("Select LLM Model", available_models)
st.session_state["selected_model"] = _model_choice
|
|
|
|
|
|
|
|
# Qdrant Cloud connection.
#
# The API key is a credential and must not be committed to source control. It
# is read from Streamlit secrets (the same mechanism used for OPENAI_API_KEY)
# with an environment-variable fallback:
#   QDRANT_API_KEY - required
#   QDRANT_URL     - optional, overrides the default cluster URL below
# NOTE(review): the key that was previously hard-coded here has been exposed
# and should be rotated.
_DEFAULT_QDRANT_URL = "https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io"

qdrant_url = (
    st.secrets.get("QDRANT_URL")
    or os.environ.get("QDRANT_URL")
    or _DEFAULT_QDRANT_URL
)
qdrant_api_key = st.secrets.get("QDRANT_API_KEY") or os.environ.get("QDRANT_API_KEY")
if not qdrant_api_key:
    # Fail with a readable message instead of an opaque auth error later on.
    st.error(
        "Qdrant API key is not configured. Set QDRANT_API_KEY in "
        ".streamlit/secrets.toml or as an environment variable."
    )
    st.stop()

qdrant_client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

# Name of the Qdrant collection the retriever searches.
collection_name = "ks_collection_1.5BE"
|
|
|
|
|
|
|
|
|
|
|
|
# Local directory holding the gte-Qwen2-1.5B-instruct embedding model.
EMBEDDING_MODEL_PATH = "D:/DR/RAG/gte-Qwen2-1.5B-instruct"


@st.cache_resource
def _load_local_embedding(model_path):
    """Load the local HuggingFace embedding model once per server process.

    Streamlit re-executes the script on each interaction; without caching the
    model would be re-instantiated every time. ``st.cache_resource`` keeps the
    loaded object and hands the same instance back on later runs.

    Args:
        model_path: Path (or model id) passed to ``HuggingFaceEmbeddings``.

    Returns:
        The ``HuggingFaceEmbeddings`` instance, configured for CPU.
    """
    embedding = HuggingFaceEmbeddings(
        model_name=model_path,
        # trust_remote_code lets the model directory supply its own Python
        # code; only point this at a model directory you trust.
        model_kwargs={"trust_remote_code": True, "device": "cpu"},
    )
    # Emitted only when the model is actually loaded, not on cached reruns.
    print(" Qwen2-1.5B local embedding model loaded.")
    return embedding


local_embedding = _load_local_embedding(EMBEDDING_MODEL_PATH)

# Vector store over the existing Qdrant collection, queried with the local
# embedding model; `retriever` is what the RAG chain consumes.
vector_store = Qdrant(
    client=qdrant_client,
    collection_name=collection_name,
    embeddings=local_embedding,
)
retriever = vector_store.as_retriever()
|
|
|
|
# --- Disabled: Postgres/PGVector engine + OpenAI embedding set-up ------------
# Kept for reference only; none of this is executed.
#
# # === Init LLM and Vector DB ===
#
# CONNECTION_STRING = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB"
# engine = create_engine(CONNECTION_STRING)
# embedding_model = OpenAIEmbeddings(api_key=OPENAI_API_KEY)

OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
selected_model = st.session_state["selected_model"]

# Back-ends that appear in the picker but have no implementation yet, mapped
# to the notice shown before the run is halted. Checked in insertion order.
_NOT_YET_SUPPORTED = {
    "LLaMA": "LLaMA integration is not implemented yet.",
    "Gemini": "Gemini integration is not implemented yet.",
}

if "OpenAI" in selected_model:
    llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=OPENAI_API_KEY)
else:
    for _backend, _notice in _NOT_YET_SUPPORTED.items():
        if _backend in selected_model:
            st.warning(_notice)
            st.stop()
    # Reached only when the selection matches none of the known back-ends.
    st.error("Unsupported model selected.")
    st.stop()
|
|
|
|
# --- Disabled: PGVector-backed vector store ----------------------------------
# Kept for reference only; the active `vector_store` is the Qdrant one.
#
# vector_store = PGVector.from_existing_index(
#     embedding=embedding_model,
#     connection=engine,
#     collection_name="documents"
# )
|
|
|
|
|
|
# Prompt text for the dermatology assistant.
# Placeholders:
#   {question} - the query sent to the chain
#   {context}  - the retrieved documents (the chain is configured to inject
#                them under the variable name "context")
# The bold marker in guideline 5 is now closed; it was previously left open
# ("**advise ..." with no terminator), leaking unbalanced Markdown into the
# prompt.
AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases.
You provide accurate, compassionate, and detailed explanations while using correct medical terminology.

Guidelines:
1. Symptoms - Explain in simple terms with proper medical definitions.
2. Causes - Include genetic, environmental, and lifestyle-related risk factors.
3. Medications & Treatments - Provide common prescription and over-the-counter treatments.
4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist.
5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately.**

Query: {question}
Relevant Information: {context}
Answer:
"""
|
|
# Wrap the raw template text, declaring the two placeholders it contains.
prompt_template = PromptTemplate(
    input_variables=["question", "context"],
    template=AI_PROMPT_TEMPLATE,
)

# Retrieval-QA chain ("stuff" chain type): documents from `retriever` are
# passed to the prompt under the variable name "context".
_qa_chain_kwargs = {
    "prompt": prompt_template,
    "document_variable_name": "context",
}
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs=_qa_chain_kwargs,
)
|
|
|
|
|
|
# Lesion-descriptor labels for the multi-label model. Order matters: entry i
# is the label reported for output position i of that model.
multilabel_class_names = [
    "Vesicle",
    "Papule",
    "Macule",
    "Plaque",
    "Abscess",
    "Pustule",
    "Bulla",
    "Patch",
    "Nodule",
    "Ulcer",
    "Crust",
    "Erosion",
    "Excoriation",
    "Atrophy",
    "Exudate",
    "Purpura/Petechiae",
    "Fissure",
    "Induration",
    "Xerosis",
    "Telangiectasia",
    "Scale",
    "Scar",
    "Friable",
    "Sclerosis",
    "Pedunculated",
    "Exophytic/Fungating",
    "Warty/Papillomatous",
    "Dome-shaped",
    "Flat topped",
    "Brown(Hyperpigmentation)",
    "Translucent",
    "White(Hypopigmentation)",
    "Purple",
    "Yellow",
    "Black",
    "Erythema",
    "Comedo",
    "Lichenification",
    "Blue",
    "Umbilicated",
    "Poikiloderma",
    "Salmon",
    "Wheal",
    "Acuminate",
    "Burrow",
    "Gray",
    "Pigmented",
    "Cyst",
]
|
|
|
|
# Disease-category labels for the single-label model. Order matters: entry i
# is the label reported when the model's arg-max is i.
# NOTE(review): "uriticaria" and "bacteria_parasetic_infections" are kept
# exactly as written; presumably they mirror the training label names, so
# confirm before correcting the spelling.
multiclass_class_names = [
    "systemic",
    "hair",
    "drug_reactions",
    "uriticaria",
    "acne",
    "light",
    "autoimmune",
    "papulosquamous",
    "eczema",
    "skincancer",
    "benign_tumors",
    "bacteria_parasetic_infections",
    "fungal_infections",
    "viral_skin_infections",
]
|
|
|
|
|
|
class SkinViT(nn.Module):
    """ViT-B/16 network whose first head layer is replaced by a new linear layer.

    NOTE(review): the visible code never instantiates this class; it is
    presumably kept so ``torch.load`` can unpickle checkpoints that were saved
    as whole ``SkinViT`` objects. Confirm before removing or renaming it.
    """

    def __init__(self, num_classes):
        """Build the backbone and install a ``num_classes``-wide linear head.

        Args:
            num_classes: Number of outputs produced by the replacement head.
        """
        super().__init__()
        backbone = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        # Keep the head's input width, change only its output width.
        head_width = backbone.heads[0].in_features
        backbone.heads[0] = nn.Linear(head_width, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped model's output for input ``x``."""
        return self.model(x)
|
|
|
|
# Checkpoint files for the two classifiers.
MULTILABEL_CHECKPOINT = "D:/DR/RAG/BestModels2703/skin_vit_fold10.pth"
MULTICLASS_CHECKPOINT = "D:/DR/RAG/BestModels2703/best_dermnet_vit.pth"


@st.cache_resource
def _load_eval_model(checkpoint_path):
    """Load a pickled model onto the CPU, switch it to eval mode, and cache it.

    Streamlit re-executes the script on each interaction; ``st.cache_resource``
    ensures each checkpoint is read from disk only once per server process.

    Args:
        checkpoint_path: Path to a checkpoint that stores a whole module
            object (``.eval()`` is called on the loaded value).

    Returns:
        The loaded module, in evaluation mode.
    """
    # weights_only=False is stated explicitly: these files are full-module
    # pickles, and PyTorch 2.6 changed the default to True, which rejects them.
    # SECURITY: unpickling can execute arbitrary code, so only load
    # checkpoints from a trusted source.
    model = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.eval()
    return model


multilabel_model = _load_eval_model(MULTILABEL_CHECKPOINT)
multiclass_model = _load_eval_model(MULTICLASS_CHECKPOINT)
|
|
|
|
|
|
# Chat transcript for this browser session: a list of
# {"role": ..., "content": ...} dicts, created empty on first run only.
st.session_state.setdefault("messages", [])
|
|
|
|
|
|
def run_inference(image, threshold=0.5):
    """Run both classifiers on one image.

    Args:
        image: Image accepted by the torchvision transforms below (the caller
            in this file passes a PIL image converted to RGB).
        threshold: Sigmoid probability a multi-label output must exceed for
            its label to be reported. Defaults to 0.5, the previously
            hard-coded cut-off.

    Returns:
        tuple: ``(predicted_multi, predicted_single)`` where
        ``predicted_multi`` is the list of ``multilabel_class_names`` entries
        whose probability exceeds ``threshold`` and ``predicted_single`` is
        the ``multiclass_class_names`` entry at the multi-class arg-max.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # NOTE(review): a single mean/std value is given for a multi-channel
        # image; presumably this mirrors the training pipeline. Confirm.
        transforms.Normalize([0.5], [0.5]),
    ])
    batch = preprocess(image).unsqueeze(0)  # add the batch axis

    with torch.no_grad():
        # squeeze(0) removes only the batch axis added above; a bare
        # squeeze() would also collapse a size-1 label axis.
        label_probs = torch.sigmoid(multilabel_model(batch)).squeeze(0).numpy()
        class_scores = multiclass_model(batch)

    predicted_multi = [
        multilabel_class_names[idx]
        for idx, prob in enumerate(label_probs)
        if prob > threshold
    ]
    best_idx = torch.argmax(class_scores, dim=1).item()
    predicted_single = multiclass_class_names[best_idx]
    return predicted_multi, predicted_single
|
|
|
|
|
|
def export_chat_to_pdf(messages):
    """Render the chat transcript into an in-memory PDF.

    Args:
        messages: Iterable of dicts with ``"role"`` and ``"content"`` keys.
            Entries whose role is ``"user"`` are labelled "You"; everything
            else is labelled "AI".

    Returns:
        io.BytesIO: Buffer containing the PDF, rewound to position 0.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        line = f"{role}: {msg['content']}\n"
        # The built-in (core) PDF fonts only cover Latin-1. Model replies
        # routinely contain curly quotes, dashes, bullets or emoji, which
        # would otherwise abort the export with an encoding error, so
        # unsupported characters are replaced with "?".
        printable = line.encode("latin-1", errors="replace").decode("latin-1")
        # Start every entry at the left margin. Some fpdf2 releases leave the
        # cursor at the right edge after multi_cell, which makes the next
        # full-width cell fail; this is a no-op where the cursor is already
        # reset.
        pdf.set_x(pdf.l_margin)
        pdf.multi_cell(0, 10, printable)
    buf = io.BytesIO()
    pdf.output(buf)
    buf.seek(0)
    return buf
|
|
|
|
|
|
|
|
st.title("🧬 DermBOT — Skin AI Assistant")
st.caption(f"🧠 Using model: {selected_model}")
uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    st.image(uploaded_file, caption="Uploaded image", use_column_width=True)

    # Streamlit re-executes the script on each interaction (follow-up
    # question, button click, ...). Without a guard, every such rerun would
    # repeat model inference and the RAG call and append duplicate entries to
    # the transcript. The analysis is therefore stored in session state and
    # recomputed only when a different file is uploaded.
    # NOTE(review): uploads are told apart by (name, size); switch to a
    # content hash if same-name/same-size collisions are a concern.
    upload_key = (uploaded_file.name, uploaded_file.size)
    analysis = st.session_state.get("image_analysis")
    if analysis is None or analysis["upload_key"] != upload_key:
        image = Image.open(uploaded_file).convert("RGB")
        labels, diagnosis = run_inference(image)
        analysis = {
            "upload_key": upload_key,
            "predicted_multi": labels,
            "predicted_single": diagnosis,
            "answer": None,  # filled in once the RAG chain has answered
        }

    predicted_multi = analysis["predicted_multi"]
    predicted_single = analysis["predicted_single"]
    lesion_summary = ", ".join(predicted_multi)

    st.markdown(f" Skin Issues : {lesion_summary}")
    st.markdown(f" Most Likely Diagnosis : {predicted_single}")

    if analysis["answer"] is None:
        # Build the question from readable text; interpolating the list
        # itself would put a Python repr ("['Papule', ...]") into the prompt
        # and the saved transcript.
        if lesion_summary:
            query = f"What are my treatment options for {lesion_summary} and {predicted_single}?"
        else:
            query = f"What are my treatment options for {predicted_single}?"

        with st.spinner("Analyzing the image and retrieving response..."):
            response = rag_chain.invoke(query)
        analysis["answer"] = response["result"]

        # Record the exchange only after the chain succeeded, so a failed
        # call does not leave an unanswered question in the transcript.
        st.session_state.messages.append({"role": "user", "content": query})
        st.session_state.messages.append({"role": "assistant", "content": analysis["answer"]})
        st.session_state["image_analysis"] = analysis

    with st.chat_message("assistant"):
        st.markdown(analysis["answer"])
|
|
|
|
|
|
# Follow-up questions: the full transcript so far is sent straight to the LLM
# (not through the retrieval chain) and the reply is appended to it.
prompt = st.chat_input("Ask a follow-up...")
if prompt:
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Pass only the role/content pair of each stored message.
    history = [
        {"role": entry["role"], "content": entry["content"]}
        for entry in st.session_state.messages
    ]
    response = llm.invoke(history)
    reply_text = response.content

    st.session_state.messages.append({"role": "assistant", "content": reply_text})
    with st.chat_message("assistant"):
        st.markdown(reply_text)
|
|
|
|
|
|
# Two-step export: the first button builds the PDF from the transcript, then a
# download button serving that PDF is shown.
export_clicked = st.button("📄 Download Chat as PDF")
if export_clicked:
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button(
        "Download PDF",
        data=pdf_file,
        file_name="chat_history.pdf",
        mime="application/pdf",
    )