import streamlit as st
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import vit_b_16, vit_l_16, ViT_B_16_Weights, ViT_L_16_Weights
import pandas as pd
from huggingface_hub import hf_hub_download
from langchain_huggingface import HuggingFaceEmbeddings
import io
import os
import base64
from fpdf import FPDF
from sqlalchemy import create_engine
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from sentence_transformers import SentenceTransformer
#from langchain_community.vectorstores.pgvector import PGVector
#from langchain_postgres import PGVector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Qdrant

torch.cuda.empty_cache()

import nest_asyncio
nest_asyncio.apply()
st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered")

# === Model Selection ===
available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"]
st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models)

# === Qdrant DB Setup ===
# The API key is read from Streamlit secrets rather than hard-coded in source
# (assumes a QDRANT_API_KEY entry in .streamlit/secrets.toml, alongside the
# OPENAI_API_KEY used further down).
qdrant_client = QdrantClient(
    url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io",
    api_key=st.secrets["QDRANT_API_KEY"]
)
collection_name = "ks_collection_1.5BE"
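
# Sketch, not part of the original flow: the collection is assumed to already
# exist in this cluster. If it were missing, it could be created with the
# Distance/VectorParams imports above (1536 is gte-Qwen2-1.5B-instruct's
# embedding width; verify against your checkpoint):
existing = [c.name for c in qdrant_client.get_collections().collections]
if collection_name not in existing:
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
    )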

# Local-path alternative (kept for reference):
#embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True)
#embedding_model.max_seq_length = 8192
#local_embedding = SentenceTransformerEmbeddings(model=embedding_model)

device = "cuda" if torch.cuda.is_available() else "cpu"
local_embedding = HuggingFaceEmbeddings(
    model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
    model_kwargs={
        "trust_remote_code": True,
        "device": device
    }
)
print("Qwen2-1.5B local embedding model loaded.")
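
# Streamlit reruns this whole script on every interaction, so the 1.5B-parameter
# embedding model above is reloaded each time. A sketch of the usual fix
# (not wired in, to keep the original flow intact):
#
#   @st.cache_resource
#   def load_embeddings():
#       return HuggingFaceEmbeddings(
#           model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct",
#           model_kwargs={"trust_remote_code": True, "device": device},
#       )
#   local_embedding = load_embeddings()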

vector_store = Qdrant(
    client=qdrant_client,
    collection_name=collection_name,
    embeddings=local_embedding
)
retriever = vector_store.as_retriever()
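# as_retriever() defaults to similarity search with k=4; an explicit variant
# (hypothetical tuning) would be:
#   retriever = vector_store.as_retriever(search_kwargs={"k": 4})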

# Dynamically initialize the LLM based on the sidebar selection
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
selected_model = st.session_state["selected_model"]

if "OpenAI" in selected_model:
    llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=OPENAI_API_KEY)
elif "LLaMA" in selected_model:
    st.warning("LLaMA integration is not implemented yet.")
    st.stop()
elif "Gemini" in selected_model:
    st.warning("Gemini integration is not implemented yet.")
    st.stop()
else:
    st.error("Unsupported model selected.")
    st.stop()
AI_PROMPT_TEMPLATE = """
You are DermBOT, a compassionate and knowledgeable AI Dermatology Assistant designed to educate users about skin-related health concerns with clarity, empathy, and precision.
Your goal is to respond like a well-informed human expert, balancing professionalism with warmth and reassurance.

When crafting responses:
- Begin with a clear, engaging summary of the condition or concern.
- Use short paragraphs for readability.
- Include bullet points or numbered lists where appropriate.
- Avoid overly technical terms unless explained simply.
- End with a helpful next step, such as lifestyle advice or when to see a doctor.

🩺 Response Structure:
1. **Overview** – Briefly introduce the condition or concern.
2. **Common Symptoms** – Describe noticeable signs in simple terms.
3. **Causes & Risk Factors** – Include genetic, lifestyle, and environmental aspects.
4. **Treatment Options** – Outline common OTC and prescription treatments.
5. **When to Seek Help** – Warn about symptoms that require urgent care.

Always encourage consulting a licensed dermatologist for personal diagnosis and treatment. For any breathing difficulties, serious infections, or rapid symptom worsening, advise calling emergency services immediately.

---
Query: {question}
Relevant Context: {context}
Your Response:
"""

prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"])

rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"}
)
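
# Usage sketch: RetrievalQA exposes a single input key ("query"), so a plain
# string works, and the answer comes back under "result", e.g. (hypothetical
# question) rag_chain.invoke("What causes eczema flare-ups?")["result"].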

# === Class Names ===
multilabel_class_names = [
    "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch",
    "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae",
    "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis",
    "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped",
    "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow",
    "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma",
    "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst"
]

multiclass_class_names = [
    "systemic", "hair", "drug_reactions", "urticaria", "acne", "light",
    "autoimmune", "papulosquamous", "eczema", "skincancer",
    "benign_tumors", "bacteria_parasitic_infections", "fungal_infections", "viral_skin_infections"
]

# === Load Models ===
class SkinViT(nn.Module):
    """ViT-B/16 backbone with a multi-label head for the 48 morphology labels."""
    def __init__(self, num_classes):
        super(SkinViT, self).__init__()
        self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
        in_features = self.model.heads.head.in_features
        self.model.heads.head = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.model(x)


class DermNetViT(nn.Module):
    """ViT-L/16 backbone with a dropout + linear head for single-label diagnosis."""
    def __init__(self, num_classes):
        super(DermNetViT, self).__init__()
        self.model = vit_l_16(weights=ViT_L_16_Weights.DEFAULT)
        in_features = self.model.heads[0].in_features
        self.model.heads[0] = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.model(x)
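
# Sanity-check sketch (hypothetical, not executed here): a dummy batch through
# SkinViT should produce one logit per morphology label:
#   SkinViT(num_classes=48)(torch.randn(1, 3, 224, 224)).shape  # torch.Size([1, 48])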

# Full-model alternatives (kept for reference):
#multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu')
#multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu')

# === Load Model State Dicts ===
multilabel_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="skin_vit_fold10_sd.pth")
multiclass_model_path = hf_hub_download(repo_id="santhoshraghu/DermBOT", filename="best_dermnet_vit_sd.pth")

multilabel_model = SkinViT(num_classes=len(multilabel_class_names))
multiclass_model = DermNetViT(num_classes=len(multiclass_class_names))

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
multilabel_model.load_state_dict(torch.load(multilabel_model_path, map_location="cpu"))
multiclass_model.load_state_dict(torch.load(multiclass_model_path, map_location="cpu"))
multilabel_model.eval()
multiclass_model.eval()
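
# Note: load_state_dict defaults to strict=True, so any mismatch between the
# class definitions above and the downloaded checkpoints fails loudly here
# rather than silently producing wrong logits.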

# === Session Init ===
if "messages" not in st.session_state:
    st.session_state.messages = []

# === Image Processing Function ===
def run_inference(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])  # single value broadcasts across all RGB channels
    ])
    input_tensor = transform(image).unsqueeze(0)
    with torch.no_grad():
        # Multi-label head: sigmoid per class, keep everything above 0.5
        probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy()
        predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5]
        # Single-label head: argmax over the 14 disease categories
        pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item()
        predicted_single = multiclass_class_names[pred_idx]
    return predicted_multi, predicted_single
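
# Example usage (hypothetical file name and output):
#   labels, diagnosis = run_inference(Image.open("rash.jpg").convert("RGB"))
#   might return (["Papule", "Erythema"], "eczema")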

# === PDF Export ===
def export_chat_to_pdf(messages):
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    for msg in messages:
        role = "You" if msg["role"] == "user" else "AI"
        # The built-in Arial font only covers Latin-1, so replace anything
        # outside it (e.g. emoji in model output) instead of raising.
        text = f"{role}: {msg['content']}\n".encode("latin-1", "replace").decode("latin-1")
        pdf.multi_cell(0, 10, text)
    buf = io.BytesIO()
    pdf.output(buf)
    buf.seek(0)
    return buf
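
# Sketch (assumption: a DejaVuSans.ttf file shipped with the app): full Unicode
# output would instead register a TrueType font via fpdf2, e.g.
#   pdf.add_font("DejaVu", "", "DejaVuSans.ttf")
#   pdf.set_font("DejaVu", size=12)
# in which case the latin-1 replacement above becomes unnecessary.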

# === App UI ===
st.title("🧬 DermBOT – Skin AI Assistant")
st.caption(f"🧠 Using model: {selected_model}")

uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"])

if uploaded_file:
    st.image(uploaded_file, caption="Uploaded image", use_container_width=True)
    image = Image.open(uploaded_file).convert("RGB")
    predicted_multi, predicted_single = run_inference(image)

    # Show predictions clearly to the user
    st.markdown(f"**Skin Issues:** {', '.join(predicted_multi)}")
    st.markdown(f"**Most Likely Diagnosis:** {predicted_single}")

    query = f"What are my treatment options for {', '.join(predicted_multi)} and {predicted_single}?"
    st.session_state.messages.append({"role": "user", "content": query})

    with st.spinner("Analyzing the image and retrieving response..."):
        response = rag_chain.invoke(query)

    st.session_state.messages.append({"role": "assistant", "content": response['result']})
    with st.chat_message("assistant"):
        st.markdown(response['result'])

# === Chat Interface ===
if prompt := st.chat_input("Ask a follow-up..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages])
    st.session_state.messages.append({"role": "assistant", "content": response.content})
    with st.chat_message("assistant"):
        st.markdown(response.content)
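
# Note: ChatOpenAI.invoke accepts OpenAI-style {"role", "content"} dicts
# directly, so the full session history is replayed as chat context here;
# the RAG retriever is only used for the image-triggered query above.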

# === PDF Button ===
if st.button("📄 Download Chat as PDF"):
    pdf_file = export_chat_to_pdf(st.session_state.messages)
    st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")