import streamlit as st
from PIL import Image
import torch
import cohere
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
import pandas as pd
from huggingface_hub import hf_hub_download
from langchain_huggingface import HuggingFaceEmbeddings
import io
import os
import base64
from fpdf import FPDF
from sqlalchemy import create_engine
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from sentence_transformers import SentenceTransformer

from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import SentenceTransformerEmbeddings
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
import nest_asyncio
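
# Runtime setup: release any cached GPU memory, allow nested event loops under Streamlit,
# and create the Cohere client used later to rerank retrieved passages.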
torch.cuda.empty_cache()
nest_asyncio.apply()
co = cohere.Client(st.secrets["COHERE_API_KEY"])

device = "cuda" if torch.cuda.is_available() else "cpu"

st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")
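
# Sidebar model selection: a single backend, or "All" to query every backend and fuse the answers.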
available_models = ["All", "GPT-4o", "LLaMA 4 Maverick", "Gemini 2.5 Pro"]
st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)
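
# Hosted Qdrant instance holding the knowledge-base embeddings used for retrieval.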
qdrant_client = QdrantClient(
    url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
    # Read from Streamlit secrets rather than hardcoding credentials; assumes a QDRANT_API_KEY entry in secrets.toml.
    api_key=st.secrets["QDRANT_API_KEY"]
)
collection_name = "ks_collection_1.5BE"
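
# Local embedding model (gte-Qwen2-1.5B-instruct) used to embed queries for retrieval;
# tries the GPU first and falls back to CPU on out-of-memory errors.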
def get_safe_embedding_model():
    model_name = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"

    try:
        print("Trying to load embedding model on CUDA...")
        embedding = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={
                "trust_remote_code": True,
                "device": "cuda" if torch.cuda.is_available() else "cpu"
            }
        )
        print("Loaded embedding model on GPU.")
        return embedding
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print("CUDA OOM. Falling back to CPU.")
        else:
            print("Error loading model on CUDA:", str(e))
        print("Loading embedding model on CPU...")
        return HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={
                "trust_remote_code": True,
                "device": "cpu"
            }
        )

local_embedding = get_safe_embedding_model()
print("Qwen2-1.5B local embedding model loaded.")
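
# Expose the existing Qdrant collection as a LangChain vector store / retriever.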
vector_store = Qdrant(
    client=qdrant_client,
    collection_name=collection_name,
    embeddings=local_embedding
)
retriever = vector_store.as_retriever()
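
# LLM-Blender components, used only in "All" mode: PairRM scores candidate answers and
# gen_fuser_3b fuses the top-ranked ones into a single response.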
pair_ranker = pipeline(
    "text-classification",
    model="llm-blender/PairRM",
    tokenizer="llm-blender/PairRM",
    return_all_scores=True
)

gen_fuser = pipeline(
    "text-generation",
    model="llm-blender/gen_fuser_3b",
    tokenizer="llm-blender/gen_fuser_3b",
    max_length=2048,
    do_sample=False
)
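
# Build the selected LLM backend. "All" mode defines helpers instead and calls each provider per query.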
selected_model = st.session_state["selected_model"]

if "GPT" in selected_model:
    llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])

elif "LLaMA" in selected_model:
    from groq import Groq

    client = Groq(api_key=st.secrets["GROQ_API_KEY"])

    def get_llama_response(prompt):
        completion = client.chat.completions.create(
            model="meta-llama/llama-4-maverick-17b-128e-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
            max_completion_tokens=1024,
            top_p=1,
            stream=False
        )
        return completion.choices[0].message.content

    llm = get_llama_response

elif "Gemini" in selected_model:
    import google.generativeai as genai
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])

    gemini_model = genai.GenerativeModel("gemini-2.5-pro-preview-05-06")

    def get_gemini_response(prompt):
        response = gemini_model.generate_content(prompt)
        return response.text

    llm = get_gemini_response

elif "All" in selected_model:
    from groq import Groq
    import google.generativeai as genai
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])

    def get_all_model_responses(prompt):
        openai_resp = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"]).invoke(
            [{"role": "system", "content": prompt}]).content

        gemini = genai.GenerativeModel("gemini-2.5-pro-preview-05-06")
        gemini_resp = gemini.generate_content(prompt).text

        llama = Groq(api_key=st.secrets["GROQ_API_KEY"])
        llama_resp = llama.chat.completions.create(
            model="meta-llama/llama-4-maverick-17b-128e-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=1, max_completion_tokens=1024, top_p=1, stream=False
        ).choices[0].message.content

        return [openai_resp, gemini_resp, llama_resp]

    def rank_and_fuse(prompt, responses):
        ranked = [(resp, pair_ranker(f"{prompt}\n\n{resp}")[0][1]['score']) for resp in responses]
        ranked.sort(key=lambda x: x[1], reverse=True)
        fusion_input = "\n\n".join([f"[Answer {i+1}]: {ans}" for i, (ans, _) in enumerate(ranked)])
        return gen_fuser(f"Fuse these responses:\n{fusion_input}", return_full_text=False)[0]['generated_text']

else:
    st.error("Unsupported model selected.")
    st.stop()
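
# System prompt used for every answer; {question} and {context} are filled in at query time.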
AI_PROMPT_TEMPLATE = """
You are DermBOT, a compassionate and knowledgeable AI Dermatology Assistant designed to educate users about skin-related health concerns with clarity, empathy, and precision.

Your goal is to respond like a well-informed human expert, balancing professionalism with warmth and reassurance.

When crafting responses:
- Begin with a clear, engaging summary of the condition or concern.
- Use short paragraphs for readability.
- Include bullet points or numbered lists where appropriate.
- Avoid overly technical terms unless explained simply.
- End with a helpful next step, such as lifestyle advice or when to see a doctor.

🩺 Response Structure:
1. **Overview** - Briefly introduce the condition or concern.
2. **Common Symptoms** - Describe noticeable signs in simple terms.
3. **Causes & Risk Factors** - Include genetic, lifestyle, and environmental aspects.
4. **Treatment Options** - Outline common OTC and prescription treatments.
5. **When to Seek Help** - Warn about symptoms that require urgent care.

Always encourage consulting a licensed dermatologist for personal diagnosis and treatment. For any breathing difficulties, serious infections, or rapid symptom worsening, advise calling emergency services immediately.

---

Query: {question}
Relevant Context: {context}

Your Response:
"""

prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])
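
# Label sets for the two image classifiers: 48 multi-label lesion descriptors and 14 disease categories.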
multilabel_class_names = [
    "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
    "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
    "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
    "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
    "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
    "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
    "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
]

multiclass_class_names = [
    "systemic", "hair", "drug_reactions", "uriticaria", "acne", "light",
    "autoimmune", "papulosquamous", "eczema", "skincancer",
    "benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections"
]
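
# ViT-based classifiers: SkinViT (ViT-B/16) for multi-label lesion descriptors,
# DermNetViT (ViT-L/16) for the single most likely disease category.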
class SkinViT(nn.Module):
    def __init__(self, num_classes):
        super(SkinViT, self).__init__()
        self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        in_features = self.model.heads.head.in_features
        self.model.heads.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)


class DermNetViT(nn.Module):
    def __init__(self, num_classes):
        super(DermNetViT, self).__init__()
        self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
        in_features = self.model.heads[0].in_features
        self.model.heads[0] = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.model(x)
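
# Download the fine-tuned classifier weights from the Hugging Face Hub.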
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")


def load_model_with_fallback(model_class, weight_path, num_classes, model_name):
    print(f"Loading {model_name}...")
    model = model_class(num_classes)
    # Load the weights on CPU first so an out-of-memory error can only occur when moving to the GPU.
    model.load_state_dict(torch.load(weight_path, map_location="cpu"))
    try:
        model.to(device)
        print(f"{model_name} loaded on {device}.")
    except RuntimeError as e:
        if "CUDA out of memory" not in str(e):
            raise
        print(f"CUDA OOM while loading {model_name}. Falling back to CPU.")
        model.to("cpu")
    return model
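
# Instantiate both classifiers in inference mode; chat history is kept in Streamlit session state.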
multilabel_model = load_model_with_fallback(SkinViT, multilabel_model_path, len(multilabel_class_names), "SkinViT")
multiclass_model = load_model_with_fallback(DermNetViT, multiclass_model_path, len(multiclass_class_names), "DermNetViT")

multilabel_model.eval()
multiclass_model.eval()

if "messages" not in st.session_state:
    st.session_state.messages = []
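
# Classify an uploaded image: sigmoid over the multi-label head (0.5 threshold) plus an argmax disease prediction.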
def run_inference(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    input_tensor = transform(image).unsqueeze(0)

    model_device = next(multilabel_model.parameters()).device
    input_tensor = input_tensor.to(model_device)

    with torch.no_grad():
        probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().cpu().numpy()
        pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()

    predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
    predicted_single = multiclass_class_names[pred_idx]

    return predicted_multi, predicted_single
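
# Export the conversation to a simple PDF for download.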
def export_chat_to_pdf(messages):
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        text = f"{role}: {msg['content']}\n"
        # The core Arial font only supports latin-1, so replace unsupported characters (e.g. emoji).
        pdf.multi_cell(0, 10, text.encode("latin-1", "replace").decode("latin-1"))
    buf = io.BytesIO()
    pdf.output(buf)
    buf.seek(0)
    return buf
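
# RAG answer pipeline: retrieve from Qdrant, rerank with Cohere, then prompt the selected LLM(s).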
def rerank_with_cohere(query, documents, top_n=5):
    if not documents:
        return []
    raw_texts = [doc.page_content for doc in documents]
    results = co.rerank(query=query, documents=raw_texts, top_n=min(top_n, len(raw_texts)), model="rerank-v3.5")
    return [documents[result.index] for result in results.results]


def get_reranked_response(query):
    docs = retriever.get_relevant_documents(query)
    reranked_docs = rerank_with_cohere(query, docs)
    context = "\n\n".join([doc.page_content for doc in reranked_docs])
    prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)

    if selected_model == "All":
        responses = get_all_model_responses(prompt)
        fused = rank_and_fuse(prompt, responses)
        return type("Obj", (), {"content": fused})

    if callable(llm):
        return type("Obj", (), {"content": llm(prompt)})
    else:
        return llm.invoke([{"role": "system", "content": prompt}])
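
# --- Streamlit UI: image upload, classification, RAG-backed answer, follow-up chat, PDF export ---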
st.title("🧬 DermBOT - Skin AI Assistant")
st.caption(f"Using model: {selected_model}")
uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
    image = Image.open(uploaded_file).convert("RGB")

    predicted_multi, predicted_single = run_inference(image)

    st.markdown(f"**Skin Issues**: {', '.join(predicted_multi)}")
    st.markdown(f"**Most Likely Diagnosis**: {predicted_single}")

    query = f"What are my treatment options for {', '.join(predicted_multi)} and {predicted_single}?"
    st.session_state.messages.append({"role": "user", "content": query})

    with st.spinner("Analyzing and retrieving context..."):
        response = get_reranked_response(query)
        st.session_state.messages.append({"role": "assistant", "content": response.content})

    with st.chat_message("assistant"):
        st.markdown(response.content)


if prompt := st.chat_input("Ask a follow-up..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    response = get_reranked_response(prompt)
    st.session_state.messages.append({"role": "assistant", "content": response.content})
    with st.chat_message("assistant"):
        st.markdown(response.content)


if st.button("Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")