import streamlit as st from PIL import Image import torch import torch.nn as nn from torchvision import transforms from torchvision.models import vit_b_16, ViT_B_16_Weights import pandas as pd import io import os import base64 from fpdf import FPDF from sqlalchemy import create_engine from langchain.chains import RetrievalQA from langchain.prompts import PromptTemplate from qdrant_client import QdrantClient from qdrant_client.http.models import Distance, VectorParams from sentence_transformers import SentenceTransformer #from langchain_community.vectorstores.pgvector import PGVector #from langchain_postgres import PGVector from langchain_openai import OpenAIEmbeddings, ChatOpenAI from langchain_community.vectorstores import Qdrant from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.embeddings import SentenceTransformerEmbeddings import nest_asyncio nest_asyncio.apply() st.set_page_config(page_title="DermBOT", page_icon="🧬", layout="centered") #os.environ["PGVECTOR_CONNECTION_STRING"] = "postgresql+psycopg2://postgres:postgres@localhost:5432/VectorDB" # === Model Selection === available_models = ["OpenAI GPT-4o", "LLaMA 3", "Gemini Pro"] st.session_state["selected_model"] = st.sidebar.selectbox("Select LLM Model", available_models) # === Qdrant DB Setup === qdrant_client = QdrantClient( url="https://2715ddd8-647f-40ee-bca4-9027d193e8aa.us-east-1-0.aws.cloud.qdrant.io", api_key="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.HXzezXdWMFeeR16F7zvqgjzsqrcm8hqa-StXdToFP9Q" ) collection_name = "ks_collection_1.5BE" #embedding_model = SentenceTransformer("D:\DR\RAG\gte-Qwen2-1.5B-instruct", trust_remote_code=True) #embedding_model.max_seq_length = 8192 #local_embedding = SentenceTransformerEmbeddings(model=embedding_model) #local_embedding = HuggingFaceEmbeddings( # model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct", # model_kwargs={"trust_remote_code": True, "device":"cpu"} #) local_embedding = HuggingFaceEmbeddings( model_name="Alibaba-NLP/gte-Qwen2-1.5B-instruct", model_kwargs={ "trust_remote_code": True, "device": "cpu", "torch_dtype": "float32" } ) print(" Qwen2-1.5B local embedding model loaded.") vector_store = Qdrant( client=qdrant_client, collection_name=collection_name, embeddings=local_embedding ) retriever = vector_store.as_retriever() # Dynamically initialize LLM based on selection OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"] selected_model = st.session_state["selected_model"] if "OpenAI" in selected_model: llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=OPENAI_API_KEY) elif "LLaMA" in selected_model: st.warning("LLaMA integration is not implemented yet.") st.stop() elif "Gemini" in selected_model: st.warning("Gemini integration is not implemented yet.") st.stop() else: st.error("Unsupported model selected.") st.stop() ''' vector_store = PGVector.from_existing_index( embedding=embedding_model, connection=engine, collection_name="documents" ) ''' #retriever = vector_store.as_retriever() AI_PROMPT_TEMPLATE = """You are an AI-assisted Dermatology Chatbot, specializing in diagnosing and educating users about skin diseases. You provide accurate, compassionate, and detailed explanations while using correct medical terminology. Guidelines: 1. Symptoms - Explain in simple terms with proper medical definitions. 2. Causes - Include genetic, environmental, and lifestyle-related risk factors. 3. Medications & Treatments - Provide common prescription and over-the-counter treatments. 4. Warnings & Emergencies - Always recommend consulting a licensed dermatologist. 5. Emergency Note - If symptoms worsen or include difficulty breathing, **advise calling 911 immediately. Query: {question} Relevant Information: {context} Answer: """ prompt_template = PromptTemplate(template=AI_PROMPT_TEMPLATE, input_variables=["question", "context"]) rag_chain = RetrievalQA.from_chain_type( llm=llm, retriever=retriever, chain_type="stuff", chain_type_kwargs={"prompt": prompt_template, "document_variable_name": "context"} ) # === Class Names === multilabel_class_names = [ "Vesicle", "Papule", "Macule", "Plaque", "Abscess", "Pustule", "Bulla", "Patch", "Nodule", "Ulcer", "Crust", "Erosion", "Excoriation", "Atrophy", "Exudate", "Purpura/Petechiae", "Fissure", "Induration", "Xerosis", "Telangiectasia", "Scale", "Scar", "Friable", "Sclerosis", "Pedunculated", "Exophytic/Fungating", "Warty/Papillomatous", "Dome-shaped", "Flat topped", "Brown(Hyperpigmentation)", "Translucent", "White(Hypopigmentation)", "Purple", "Yellow", "Black", "Erythema", "Comedo", "Lichenification", "Blue", "Umbilicated", "Poikiloderma", "Salmon", "Wheal", "Acuminate", "Burrow", "Gray", "Pigmented", "Cyst" ] multiclass_class_names = [ "systemic", "hair", "drug_reactions", "uriticaria", "acne", "light", "autoimmune", "papulosquamous", "eczema", "skincancer", "benign_tumors", "bacteria_parasetic_infections", "fungal_infections", "viral_skin_infections" ] # === Load Models === class SkinViT(nn.Module): def __init__(self, num_classes): super(SkinViT, self).__init__() self.model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT) in_features = self.model.heads[0].in_features self.model.heads[0] = nn.Linear(in_features, num_classes) def forward(self, x): return self.model(x) #multilabel_model = torch.load("D:/DR/RAG/BestModels2703/skin_vit_fold10.pth", map_location='cpu') #multiclass_model = torch.load("D:/DR/RAG/BestModels2703/best_dermnet_vit.pth", map_location='cpu') multilabel_model = SkinViT(num_classes=len(multilabel_class_names)) multiclass_model = SkinViT(num_classes=len(multiclass_class_names)) multilabel_model.load_state_dict(torch.hub.load_state_dict_from_url( "https://huggingface.co/santhoshraghu/DermBOT/resolve/main/skin_vit_fold10.pth", map_location="cpu" )) multiclass_model.load_state_dict(torch.hub.load_state_dict_from_url( "https://huggingface.co/santhoshraghu/DermBOT/resolve/main/best_dermnet_vit.pth", map_location="cpu" )) multilabel_model.eval() multiclass_model.eval() # === Session Init === if "messages" not in st.session_state: st.session_state.messages = [] # === Image Processing Function === def run_inference(image): transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ]) input_tensor = transform(image).unsqueeze(0) with torch.no_grad(): probs_multi = torch.sigmoid(multilabel_model(input_tensor)).squeeze().numpy() predicted_multi = [multilabel_class_names[i] for i, p in enumerate(probs_multi) if p > 0.5] pred_idx = torch.argmax(multiclass_model(input_tensor), dim=1).item() predicted_single = multiclass_class_names[pred_idx] return predicted_multi, predicted_single # === PDF Export === def export_chat_to_pdf(messages): pdf = FPDF() pdf.add_page() pdf.set_font("Arial", size=12) for msg in messages: role = "You" if msg["role"] == "user" else "AI" pdf.multi_cell(0, 10, f"{role}: {msg['content']}\n") buf = io.BytesIO() pdf.output(buf) buf.seek(0) return buf # === App UI === st.title("🧬 DermBOT — Skin AI Assistant") st.caption(f"🧠 Using model: {selected_model}") uploaded_file = st.file_uploader("Upload a skin image", type=["jpg", "jpeg", "png"]) if uploaded_file: st.image(uploaded_file, caption="Uploaded image", use_column_width=True) image = Image.open(uploaded_file).convert("RGB") predicted_multi, predicted_single = run_inference(image) # Show predictions clearly to the user st.markdown(f" Skin Issues : {', '.join(predicted_multi)}") st.markdown(f" Most Likely Diagnosis : {predicted_single}") query = f"What are my treatment options for {predicted_multi} and {predicted_single}?" st.session_state.messages.append({"role": "user", "content": query}) with st.spinner("Analyzing the image and retrieving response..."): response = rag_chain.invoke(query) st.session_state.messages.append({"role": "assistant", "content": response['result']}) with st.chat_message("assistant"): st.markdown(response['result']) # === Chat Interface === if prompt := st.chat_input("Ask a follow-up..."): st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) response = llm.invoke([{"role": m["role"], "content": m["content"]} for m in st.session_state.messages]) st.session_state.messages.append({"role": "assistant", "content": response.content}) with st.chat_message("assistant"): st.markdown(response.content) # === PDF Button === if st.button("📄 Download Chat as PDF"): pdf_file = export_chat_to_pdf(st.session_state.messages) st.download_button("Download PDF", data=pdf_file, file_name="chat_history.pdf", mime="application/pdf")