import os
import pathlib

import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# ——— CONFIG ———
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical-v1.1"
SUBFOLDER = "checkpoint-2250"
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN in env")

# ——— 1) Download the full repo ———
local_cache = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
)
print("[DEBUG] snapshot_download → local_cache:", local_cache)
print(
    "[DEBUG] MODEL root contents:",
    list(pathlib.Path(local_cache).glob(f"{SUBFOLDER}/*")),
)

# ——— 2) Repo root contains tokenizer.json; model shards live in the checkpoint subfolder ———
MODEL_DIR = local_cache
MODEL_SUBFOLDER = SUBFOLDER
print("[DEBUG] MODEL_DIR:", MODEL_DIR)
print("[DEBUG] MODEL_DIR files:", os.listdir(MODEL_DIR))
print("[DEBUG] Checkpoint files:", os.listdir(os.path.join(MODEL_DIR, MODEL_SUBFOLDER)))

# ——— 3) Load tokenizer & model from disk ———
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=True,
)
print("[DEBUG] Loaded fast tokenizer object:", tokenizer, "type:", type(tokenizer))

# Confirm tokenizer files are present
print("[DEBUG] Files in MODEL_DIR for tokenizer:", os.listdir(MODEL_DIR))

# Inspect tokenizer's initialization arguments
try:
    print("[DEBUG] Tokenizer init_kwargs:", tokenizer.init_kwargs)
except AttributeError:
    print("[DEBUG] No init_kwargs attribute on tokenizer.")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    subfolder=MODEL_SUBFOLDER,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
print(
    "[DEBUG] Loaded model object:",
    model.__class__.__name__,
    "device:",
    next(model.parameters()).device,
)
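
# NOTE: a minimal alternative loading sketch (an assumption, not part of the original flow):
# AutoModelForCausalLM.from_pretrained can also resolve the checkpoint directly from the Hub
# via the same `subfolder` argument, which skips the explicit snapshot_download step above.
# Kept commented out so the cached `local_cache` path remains the single loading path.
# model = AutoModelForCausalLM.from_pretrained(
#     REPO_ID,
#     subfolder=SUBFOLDER,
#     token=HF_TOKEN,
#     device_map="auto",
#     torch_dtype=torch.float16,
# )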

# === Role Agent with instruction/input/output format ===
class RoleAgent:
    def __init__(self, role_instruction, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.role_instruction = role_instruction

    def act(self, input_text):
        # Initialize thinking variable at the start of the method
        thinking = ""

        prompt = (
            f"instruction: {self.role_instruction}\n"
            f"input: {input_text}\n"
            "output:"
        )
        encoding = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in encoding.items()}
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.3,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"[DEBUG] Generated response: {response}")

        # Extract only the new generated content after the prompt
        prompt_length = len(prompt)
        if len(response) > prompt_length:
            generated_text = response[prompt_length:].strip()
        else:
            generated_text = response.strip()

        # Clean up the response - remove any repeated instruction/input/output patterns
        lines = generated_text.split("\n")
        clean_lines = []
        for line in lines:
            line = line.strip()
            # Skip lines that look like instruction formatting
            if (
                line.startswith("instruction:")
                or line.startswith("input:")
                or line.startswith("output:")
                or line == ""
            ):
                continue
            clean_lines.append(line)

        # Join the clean lines and take the first substantial response
        if clean_lines:
            answer = clean_lines[0]
            # If there are multiple clean lines, take the first one that's substantial
            for line in clean_lines:
                if len(line) > 20:  # Arbitrary threshold for substantial content
                    answer = line
                    break
        else:
            # Fallback: try to extract after "output:" if present
            if "output:" in generated_text.lower():
                parts = generated_text.lower().split("output:")
                if len(parts) > 1:
                    answer = parts[-1].strip()
                else:
                    answer = generated_text
            else:
                answer = generated_text

        # Additional cleanup - remove any remaining instruction artifacts
        answer = (
            answer.replace("instruction:", "")
            .replace("input:", "")
            .replace("output:", "")
            .strip()
        )

        # If answer is still messy, try to extract the actual medical content
        if "patient" in answer.lower() and len(answer) > 100:
            # Look for sentences that contain medical information
            sentences = answer.split(".")
            medical_sentences = []
            for sentence in sentences:
                sentence = sentence.strip()
                if len(sentence) > 10 and any(
                    word in sentence.lower()
                    for word in [
                        "patient",
                        "pain",
                        "symptom",
                        "diagnosis",
                        "treatment",
                        "knee",
                        "reports",
                        "experiencing",
                    ]
                ):
                    medical_sentences.append(sentence)
            if medical_sentences:
                answer = ". ".join(
                    medical_sentences[:2]
                )  # Take first 2 medical sentences
                if not answer.endswith("."):
                    answer += "."

        print(
            f"[CLEANED RESPONSE] Original length: {len(response)}, Cleaned: '{answer}'"
        )

        # Return both thinking and answer
        return {"thinking": thinking, "output": answer}


# === Agents ===
# ——— Instantiate RoleAgents for each of your eight roles ———
summarizer = RoleAgent(
    role_instruction=(
        "You are a clinical summarizer. Given a transcript of a doctor–patient dialogue, "
        "extract a structured clinical vignette summarizing the key symptoms, relevant history, "
        "and any diagnostic clues."
    ),
    tokenizer=tokenizer,
    model=model,
)

treatment_agent = RoleAgent(
    role_instruction=(
        "You are a board-certified clinician. Based on the provided diagnosis and patient vignette, "
        "propose a realistic, evidence-based treatment plan suitable for initiation by a primary care "
        "physician or psychiatrist."
    ),
    tokenizer=tokenizer,
    model=model,
)

diagnoser_early = RoleAgent(
    role_instruction=(
        "You are a diagnostic reasoning model (Early Stage). Based on the patient vignette and "
        "early-stage observations, generate a list of plausible diagnoses with reasoning. Focus on "
        "broad differentials, considering common and uncommon conditions."
    ),
    tokenizer=tokenizer,
    model=model,
)

diagnoser_middle = RoleAgent(
    role_instruction=(
        "You are a diagnostic reasoning model (Middle Stage). Given the current vignette, prior dialogue, "
        "and diagnostic hypothesis, refine the list of possible diagnoses with concise justifications for each. "
        "Aim to reduce diagnostic uncertainty."
    ),
    tokenizer=tokenizer,
    model=model,
)

diagnoser_late = RoleAgent(
    role_instruction=(
        "You are a diagnostic reasoning model (Late Stage). Based on the final patient vignette summary "
        "and full conversation, provide the most likely diagnosis with structured reasoning. Confirm "
        "diagnostic certainty and include END if no more questioning is necessary."
    ),
    tokenizer=tokenizer,
    model=model,
)

questioner_early = RoleAgent(
    role_instruction=(
        "You are a questioning agent (Early Stage). Your task is to propose highly relevant early-stage questions "
        "that can open the differential diagnosis widely. Use epidemiology, demographics, and vague presenting "
        "symptoms as guides."
    ),
    tokenizer=tokenizer,
    model=model,
)

questioner_middle = RoleAgent(
    role_instruction=(
        "You are a questioning agent (Middle Stage). Using the current diagnosis, past questions, and patient vignette, "
        "generate a specific question to refine the current differential diagnosis. Return your reasoning and next question."
    ),
    tokenizer=tokenizer,
    model=model,
)

questioner_late = RoleAgent(
    role_instruction=(
        "You are a questioning agent (Late Stage). Based on narrowed differentials and previous dialogue, "
        "generate a focused question that would help confirm or eliminate the final 1-2 suspected diagnoses."
    ),
    tokenizer=tokenizer,
    model=model,
)
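
# Quick manual smoke test (illustrative sketch only, not executed on import): from a REPL,
#   summarizer.act("Doctor: What brings you in today?\nPatient: I am 15 and my knee hurts.")
# is expected to return a dict with "thinking" and "output" keys, as produced by RoleAgent.act.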

"""[DEBUG] prompt:
Instruction: You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.
Input: Doctor: What brings you in today?
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee?
Previous Vignette:
Output: Instruction: You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.
Input: Doctor: What brings you in today?
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee?
Previous Vignette:
Output: The patient is a 15-year-old male presenting with knee pain."""

# === Inference State ===
conversation_history = []
summary = ""
diagnosis = ""


# === Gradio Inference ===
def simulate_interaction(user_input, conversation_history=None):
    """Single turn interaction - no iterations, uses accumulated history."""
    if conversation_history is None:
        history = ["Doctor: What brings you in today?", f"Patient: {user_input}"]
    else:
        history = conversation_history.copy()
        history.append(f"Patient: {user_input}")

    # Summarize the full conversation history
    sum_in = "\n".join(history)
    sum_out = summarizer.act(sum_in)
    summary = sum_out["output"]

    # Diagnose based on summary
    diag_out = diagnoser_middle.act(summary)
    diagnosis = diag_out["output"]

    # Generate next question based on current understanding
    q_in = f"Vignette: {summary}\nCurrent Estimated Diagnosis: {diagnosis}"
    q_out = questioner_middle.act(q_in)

    # Add doctor's response to history
    history.append(f"Doctor: {q_out['output']}")

    # Generate treatment plan (but don't end conversation)
    treatment_out = treatment_agent.act(f"Diagnosis: {diagnosis}\nVignette: {summary}")

    return {
        "summary": sum_out,
        "diagnosis": diag_out,
        "question": q_out,
        "treatment": treatment_out,
        "conversation": history,  # Return full history list
    }


# === Gradio UI ===
def ui_fn(user_input):
    """Non-stateful version for testing."""
    res = simulate_interaction(user_input)
    return f"""📋 Vignette Summary:
💭 THINKING: {res['summary']['thinking']}
📝 SUMMARY: {res['summary']['output']}

🩺 Diagnosis:
💭 THINKING: {res['diagnosis']['thinking']}
🔍 DIAGNOSIS: {res['diagnosis']['output']}

❓ Follow-up Question:
💭 THINKING: {res['question']['thinking']}
👨‍⚕️ DOCTOR: {res['question']['output']}

💊 Treatment Plan:
💭 THINKING: {res['treatment']['thinking']}
📋 TREATMENT: {res['treatment']['output']}

💬 Full Conversation:
{chr(10).join(res['conversation'])}
"""


# === Stateful Gradio UI ===
def stateful_ui_fn(user_input, history):
    """Proper stateful conversation handler."""
    # Initialize history if first interaction
    if history is None:
        history = []

    # Run one turn of interaction with accumulated history
    res = simulate_interaction(user_input, history)

    # Get the updated conversation history
    updated_history = res["conversation"]

    # Format the display output
    display_output = f"""💬 Conversation:
{chr(10).join(updated_history)}

📋 Current Assessment:
🔍 Diagnosis: {res['diagnosis']['output']}
💊 Treatment Plan: {res['treatment']['output']}
"""
    # Return display text and updated history for next turn
    return display_output, updated_history


def chat_interface(user_input, history):
    """Alternative chat-style interface."""
    if history is None:
        history = []

    # Run interaction
    res = simulate_interaction(user_input, history)
    updated_history = res["conversation"]

    # Return just the doctor's latest response and updated history
    doctor_response = res["question"]["output"]
    return doctor_response, updated_history
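
# Expected turn-by-turn call pattern for simulate_interaction (illustrative sketch with
# hypothetical patient inputs; in the Gradio apps below, gr.State threads the same history):
#   first = simulate_interaction("My knee hurts when I run.")
#   second = simulate_interaction("It started two weeks ago.", first["conversation"])
# Each call summarizes the history, diagnoses, drafts a treatment plan, and appends the
# next doctor question to the returned "conversation" list.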

# Create two different interfaces
demo_stateful = gr.Interface(
    fn=stateful_ui_fn,
    inputs=[
        gr.Textbox(
            label="Patient Response",
            placeholder="Describe your symptoms or answer the doctor's question...",
        ),
        gr.State(),  # holds the conversation history
    ],
    outputs=[
        gr.Textbox(label="Medical Consultation", lines=15),
        gr.State(),  # returns the updated history
    ],
    title="🧠 AI Doctor - Full Medical Consultation",
    description="Have a conversation with an AI doctor. Each response builds on the previous conversation.",
)

demo_chat = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="Your Response", placeholder="Tell me about your symptoms..."),
        gr.State(),
    ],
    outputs=[
        gr.Textbox(label="Doctor", lines=5),
        gr.State(),
    ],
    title="🩺 AI Doctor Chat",
    description="Simple chat interface with the AI doctor.",
)

if __name__ == "__main__":
    # Launch the stateful version by default
    demo_stateful.launch(share=True)
    # Uncomment the line below to use the chat version instead:
    # demo_chat.launch(share=True)
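
# Deployment note (an assumption, e.g. for a container or hosted Space rather than local use):
# launching on an explicit host/port instead of share=True would look like
#   demo_stateful.launch(server_name="0.0.0.0", server_port=7860)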