# DermBOT / app.py
import streamlit as st
from PIL import Image
import torch
import cohere
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
import pandas as pd
from huggingface_hub import hf_hub_download
from langchain_huggingface import HuggingFaceEmbeddings
import io
import os
import base64
from fpdf import FPDF
from sqlalchemy import create_engine
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from sentence_transformers import SentenceTransformer
#from langchain_community.vectorstores.pgvector import PGVector
#from langchain_postgres import PGVector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import SentenceTransformerEmbeddings
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
import nest_asyncio
torch.cuda.empty_cache()
nest_asyncio.apply()
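# Cohere client is used later in rerank_with_cohere() to rerank retrieved passages before answering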
co = cohere.Client(st.secrets["COHERE_API_KEY"])
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
# === Model Selection ===
available_models = ["GPT-4o", "LLaMA 4 Maverick", "Gemini 2.5 Pro", "All"]
st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
# === Qdrant DB Setup ===
qdrant_client = QdrantClient(
    url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
    api_key=st.secrets["QDRANT_API_KEY"]  # read from .streamlit/secrets.toml (key name assumed) rather than hard-coding the credential
)
collection_name = "ks_collection_1.5BE"
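# Qdrant collection searched below; it is queried with the gte-Qwen2-1.5B embeddings loaded further down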
#embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
#embedding_model.max_seq_length = 8192
#local_embedding = SentenceTransformerEmbeddings(model=embedding_model)
device = "cuda" if torch.cuda.is_available() else "cpu"
def get_safe_embedding_model():
model_name = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
try:
print("Trying to load embedding model on CUDA...")
embedding = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={
"trust_remote_code": True,
"device": "cuda"
}
)
print("Loaded embedding model on GPU.")
return embedding
except RuntimeError as e:
if "CUDA out of memory" in str(e):
print("CUDA OOM. Falling back to CPU.")
else:
print(" Error loading model on CUDA:", str(e))
print("Loading embedding model on CPU...")
return HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs={
"trust_remote_code": True,
"device": "cpu"
}
)
# Load the embedding model once at startup (GPU when available, CPU fallback)
local_embedding = get_safe_embedding_model()
print("Qwen2-1.5B local embedding model loaded.")
vector_store = Qdrant(
client=qdrant_client,
collection_name=collection_name,
embeddings=local_embedding
)
retriever = vector_store.as_retriever()
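# Raw retriever hits are reranked with Cohere (see rerank_with_cohere) before being passed to the LLM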
pair_ranker = pipeline(
"text-classification",
model="llm-blender/PairRM",
tokenizer="llm-blender/PairRM",
return_all_scores=True
)
gen_fuser = pipeline(
"text-generation",
model="llm-blender/gen_fuser_3b",
tokenizer="llm-blender/gen_fuser_3b",
max_length=2048,
do_sample=False
)
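# PairRM scores candidate answers and gen_fuser_3b fuses the top-ranked ones; both are used only on the "All" (ensemble) path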
selected_model = st.session_state["selected_model"]
if "GPT-4o" in selected_model:
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
elif "LLaMA" in selected_model:
from groq import Groq
client = Groq(api_key=st.secrets["GROQ_API_KEY"]) # Store in `.streamlit/secrets.toml`
def get_llama_response(prompt):
completion = client.chat.completions.create(
model="meta-llama/llama-4-maverick-17b-128e-instruct",
messages=[{"role": "user", "content": prompt}],
temperature=1,
max_completion_tokens=1024,
top_p=1,
stream=False
)
return completion.choices[0].message.content
llm = get_llama_response # use this in place of llm.invoke()
elif "Gemini" in selected_model:
import google.generativeai as genai
genai.configure(api_key=st.secrets["GEMINI_API_KEY"]) # Store in `.streamlit/secrets.toml`
gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
def get_gemini_response(prompt):
response = gemini_model.generate_content(prompt)
return response.text
llm = get_gemini_response
elif "All" in selected_model:
from groq import Groq
import google.generativeai as genai
genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
def get_all_model_responses(prompt):
openai_resp = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"]).invoke(
[{"role": "system", "content": prompt}]).content
gemini = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
gemini_resp = gemini.generate_content(prompt).text
llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
llama_resp = llama.chat.completions.create(
model="meta-llama/llama-4-maverick-17b-128e-instruct",
messages=[{"role": "user", "content": prompt}],
temperature=1, max_completion_tokens=1024, top_p=1, stream=False
).choices[0].message.content
return [openai_resp, gemini_resp, llama_resp]
def rank_and_fuse(prompt, responses):
ranked = [(resp, pair_ranker(f"{prompt}\n\n{resp}")[0][1]['score']) for resp in responses]
ranked.sort(key=lambda x: x[1], reverse=True)
fusion_input = "\n\n".join([f"[Answer {i+1}]: {ans}" for i, (ans, _) in enumerate(ranked)])
return gen_fuser(f"Fuse these responses:\n{fusion_input}", return_full_text=False)[0]['generated_text']
else:
st.error("Unsupported model selected.")
st.stop()
#retriever = vector_store.as_retriever()
AI_PROMPT_TEMPLATE = """
You are DermBOT, a compassionate and knowledgeable AI Dermatology Assistant designed to educate users about skin-related health concerns with clarity, empathy, and precision.
Your goal is to respond like a well-informed human expert, balancing professionalism with warmth and reassurance.
When crafting responses:
- Begin with a clear, engaging summary of the condition or concern.
- Use short paragraphs for readability.
- Include bullet points or numbered lists where appropriate.
- Avoid overly technical terms unless explained simply.
- End with a helpful next step, such as lifestyle advice or when to see a doctor.
🩺 Response Structure:
1. **Overview** - Briefly introduce the condition or concern.
2. **Common Symptoms** - Describe noticeable signs in simple terms.
3. **Causes & Risk Factors** - Include genetic, lifestyle, and environmental aspects.
4. **Treatment Options** - Outline common OTC and prescription treatments.
5. **When to Seek Help** - Warn about symptoms that require urgent care.
Always encourage consulting a licensed dermatologist for personal diagnosis and treatment. For any breathing difficulties, serious infections, or rapid symptom worsening, advise calling emergency services immediately.
---
Query: {question}
Relevant Context: {context}
Your Response:
"""
prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
#rag_chain = RetrievalQA.from_chain_type(
# llm=llm,
# retriever=retriever,
# chain_type="stuff",
# chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
#)
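# The prompt template is formatted manually inside get_reranked_response(); the RetrievalQA chain above is kept only as a commented-out reference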
# === Class Names ===
multilabel_class_names = [
"Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
"Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
"Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
"Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
"Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
"Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
"Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
]
multiclass_class_names = [
"systemic", "hair", "drug_reactions", "uriticaria", "acne", "light",
"autoimmune", "papulosquamous", "eczema", "skincancer",
"benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
]
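# Both label lists must stay in the same order as the training labels of the corresponding classifier heads, since predictions are mapped back by index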
# === Load Models ===
class SkinViT(nn.Module):
def __init__(self, num_classes):
super(SkinViT, self).__init__()
self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
in_features = self.model.heads.head.in_features
self.model.heads.head = nn.Linear(in_features, num_classes)
def forward(self, x):
return self.model(x)
class DermNetViT(nn.Module):
def __init__(self, num_classes):
super(DermNetViT, self).__init__()
self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
in_features = self.model.heads[0].in_features
self.model.heads[0] = nn.Sequential(
nn.Dropout(0.3),
nn.Linear(in_features, num_classes)
)
def forward(self, x):
return self.model(x)
#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')
# === Load Model State Dicts ===
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")
def load_model_with_fallback(model_class, weight_path, num_classes, model_name):
    try:
        print(f"🔍 Loading {model_name} on GPU...")
        model = model_class(num_classes)
        model.load_state_dict(torch.load(weight_path, map_location="cuda"))
        model.to("cuda")
        print(f"✅ {model_name} loaded on GPU.")
        return model
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print(f"⚠️ {model_name} OOM. Falling back to CPU.")
        else:
            print(f"❌ Error loading {model_name} on CUDA: {e}")
        print(f"🔄 Loading {model_name} on CPU...")
        model = model_class(num_classes)
        model.load_state_dict(torch.load(weight_path, map_location="cpu"))
        model.to("cpu")
        return model
# Load both models with fallback
multilabel_model = load_model_with_fallback(SkinViT, multilabel_model_path, len(multilabel_class_names), "SkinViT")
multiclass_model = load_model_with_fallback(DermNetViT, multiclass_model_path, len(multiclass_class_names), "DermNetViT")
multilabel_model.eval()
multiclass_model.eval()
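# eval() disables dropout so inference is deterministic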
# === Session Init ===
if "messages" not in st.session_state:
st.session_state.messages = []
# === Image Processing Function ===
def run_inference(image):
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
input_tensor = transform(image).unsqueeze(0)
# Automatically match model device (GPU or CPU)
model_device = next(multilabel_model.parameters()).device
input_tensor = input_tensor.to(model_device)
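    # Multilabel head: independent sigmoid per finding (0.5 threshold); multiclass head: single argmax diagnosis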
with torch.no_grad():
probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().cpu().numpy()
pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()
predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
predicted_single = multiclass_class_names[pred_idx]
return predicted_multi, predicted_single
# === PDF Export ===
def export_chat_to_pdf(messages):
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # Core PDF fonts are Latin-1 only; replace unsupported characters instead of crashing on emoji or markdown symbols
        text = f"{role}: {msg['content']}\n".encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, text)
    buf = io.BytesIO()
    pdf.output(buf)
    buf.seek(0)
    return buf
# Reranker utility
def rerank_with_cohere(query, documents, top_n=5):
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    response = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
    # Newer Cohere SDKs return the ranked hits under .results; each hit carries the index of the original document
    hits = getattr(response, "results", response)
    return [documents[hit.index] for hit in hits]
# Final answer generation using reranked context
def get_reranked_response(query):
docs = retriever.get_relevant_documents(query)
reranked_docs = rerank_with_cohere(query, docs)
context = "\n\n".join([doc.page_content for doc in reranked_docs])
prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
if selected_model == "All":
responses = get_all_model_responses(prompt)
fused = rank_and_fuse(prompt, responses)
return type("Obj", (), {"content": fused})
if callable(llm):
return type("Obj", (), {"content": llm(prompt)})
else:
return llm.invoke([{"role": "system", "content": prompt}])
# === App UI ===
st.title("🧬 DermBOT - Skin AI Assistant")
st.caption(f"🧠 Using model: {selected_model}")
uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])
if uploaded_file:
    st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
    image = Image.open(uploaded_file).convert("RGB")
    predicted_multi, predicted_single = run_inference(image)
    # Show predictions clearly to the user
    st.markdown(f"🧾 **Skin Issues**: {', '.join(predicted_multi)}")
    st.markdown(f"📌 **Most Likely Diagnosis**: {predicted_single}")
    query = f"What are my treatment options for {', '.join(predicted_multi)} and {predicted_single}?"
    st.session_state.messages.append({"role": "user", "content": query})
    with st.spinner("🔎 Analyzing and retrieving context..."):
        response = get_reranked_response(query)
        st.session_state.messages.append({"role": "assistant", "content": response.content})
        with st.chat_message("assistant"):
            st.markdown(response.content)
# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
response = get_reranked_response(prompt)
st.session_state.messages.append({"role": "assistant", "content": response.content})
with st.chat_message("assistant"):
st.markdown(response.content)
# === PDF Button ===
if st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")