# medical-test/app.py
import os
import pathlib

import torch
import gradio as gr
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- CONFIG ---
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical-v1.1"
SUBFOLDER = "checkpoint-2250"
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN in env")
# --- 1) Download the full repo ---
local_cache = snapshot_download(
repo_id=REPO_ID,
token=HF_TOKEN,
)
print("[DEBUG] snapshot_download β†’ local_cache:", local_cache)
print(
"[DEBUG] MODEL root contents:",
list(pathlib.Path(local_cache).glob(f"{SUBFOLDER}/*")),
)
# --- 2) Repo root contains tokenizer.json; model shards live in the checkpoint subfolder ---
MODEL_DIR = local_cache
MODEL_SUBFOLDER = SUBFOLDER
print("[DEBUG] MODEL_DIR:", MODEL_DIR)
print("[DEBUG] MODEL_DIR files:", os.listdir(MODEL_DIR))
print("[DEBUG] Checkpoint files:", os.listdir(os.path.join(MODEL_DIR, MODEL_SUBFOLDER)))
# --- 3) Load tokenizer & model from disk ---
tokenizer = AutoTokenizer.from_pretrained(
MODEL_DIR,
use_fast=True,
)
print("[DEBUG] Loaded fast tokenizer object:", tokenizer, "type:", type(tokenizer))
# Confirm tokenizer files are present
print("[DEBUG] Files in MODEL_DIR for tokenizer:", os.listdir(MODEL_DIR))
# Inspect tokenizer's initialization arguments
try:
print("[DEBUG] Tokenizer init_kwargs:", tokenizer.init_kwargs)
except AttributeError:
print("[DEBUG] No init_kwargs attribute on tokenizer.")
model = AutoModelForCausalLM.from_pretrained(
MODEL_DIR,
subfolder=MODEL_SUBFOLDER,
device_map="auto",
torch_dtype=torch.float16,
)
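# NOTE: device_map="auto" requires the `accelerate` package to be installed.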
model.eval()
print(
"[DEBUG] Loaded model object:",
model.__class__.__name__,
"device:",
next(model.parameters()).device,
)
# === Role Agent with instruction/input/output format ===
class RoleAgent:
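    """Wraps the shared tokenizer/model with a fixed role instruction and
    post-processes each generation into a short, clean answer."""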
def __init__(self, role_instruction, tokenizer, model):
self.tokenizer = tokenizer
self.model = model
self.role_instruction = role_instruction
def act(self, input_text):
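        """Generate a response for `input_text` and return a dict with
        "thinking" (currently always empty) and "output" (the cleaned answer)."""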
        # Placeholder; no separate reasoning text is extracted from the generation.
        thinking = ""
prompt = (
f"instruction: {self.role_instruction}\n"
f"input: {input_text}\n"
f"output:"
)
encoding = self.tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(self.model.device) for k, v in encoding.items()}
outputs = self.model.generate(
**inputs,
max_new_tokens=128,
do_sample=True,
temperature=0.3,
pad_token_id=self.tokenizer.eos_token_id,
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"[DEBUG] Generated response: {response}")
# Extract only the new generated content after the prompt
prompt_length = len(prompt)
if len(response) > prompt_length:
generated_text = response[prompt_length:].strip()
else:
generated_text = response.strip()
# Clean up the response - remove any repeated instruction/input/output patterns
lines = generated_text.split("\n")
clean_lines = []
for line in lines:
line = line.strip()
# Skip lines that look like instruction formatting
if (
line.startswith("instruction:")
or line.startswith("input:")
or line.startswith("output:")
or line == ""
):
continue
clean_lines.append(line)
# Join the clean lines and take the first substantial response
if clean_lines:
answer = clean_lines[0]
# If there are multiple clean lines, take the first one that's substantial
for line in clean_lines:
if len(line) > 20: # Arbitrary threshold for substantial content
answer = line
break
        else:
            # Fallback: extract whatever follows the last "output:" marker
            # (case-insensitive), preserving the original casing of the text.
            lower_text = generated_text.lower()
            if "output:" in lower_text:
                answer = generated_text[lower_text.rfind("output:") + len("output:") :].strip()
            else:
                answer = generated_text
# Additional cleanup - remove any remaining instruction artifacts
answer = (
answer.replace("instruction:", "")
.replace("input:", "")
.replace("output:", "")
.strip()
)
# If answer is still messy, try to extract the actual medical content
if "patient" in answer.lower() and len(answer) > 100:
# Look for sentences that contain medical information
sentences = answer.split(".")
medical_sentences = []
for sentence in sentences:
sentence = sentence.strip()
if len(sentence) > 10 and any(
word in sentence.lower()
for word in [
"patient",
"pain",
"symptom",
"diagnosis",
"treatment",
"knee",
"reports",
"experiencing",
]
):
medical_sentences.append(sentence)
if medical_sentences:
answer = ". ".join(
medical_sentences[:2]
) # Take first 2 medical sentences
if not answer.endswith("."):
answer += "."
print(
f"[CLEANED RESPONSE] Original length: {len(response)}, Cleaned: '{answer}'"
)
# Return both thinking and answer
return {"thinking": thinking, "output": answer}
# === Agents ===
# --- Instantiate RoleAgents for each of the eight roles ---
summarizer = RoleAgent(
    role_instruction=(
        "You are a clinical summarizer. Given a transcript of a doctor–patient dialogue, "
        "extract a structured clinical vignette summarizing the key symptoms, relevant history, "
        "and any diagnostic clues."
    ),
tokenizer=tokenizer,
model=model,
)
treatment_agent = RoleAgent(
role_instruction=(
"You are a board-certified clinician. Based on the provided diagnosis and patient vignette, "
"propose a realistic, evidence-based treatment plan suitable for initiation by a primary care "
"physician or psychiatrist."
),
tokenizer=tokenizer,
model=model,
)
diagnoser_early = RoleAgent(
role_instruction=(
"You are a diagnostic reasoning model (Early Stage). Based on the patient vignette and "
"early-stage observations, generate a list of plausible diagnoses with reasoning. Focus on "
"broad differentials, considering common and uncommon conditions."
),
tokenizer=tokenizer,
model=model,
)
diagnoser_middle = RoleAgent(
role_instruction=(
"You are a diagnostic reasoning model (Middle Stage). Given the current vignette, prior dialogue, "
"and diagnostic hypothesis, refine the list of possible diagnoses with concise justifications for each. "
"Aim to reduce diagnostic uncertainty."
),
tokenizer=tokenizer,
model=model,
)
diagnoser_late = RoleAgent(
role_instruction=(
"You are a diagnostic reasoning model (Late Stage). Based on the final patient vignette summary and full conversation, "
"provide the most likely diagnosis with structured reasoning. Confirm diagnostic certainty and include END if no more questioning is necessary."
),
tokenizer=tokenizer,
model=model,
)
questioner_early = RoleAgent(
role_instruction=(
"You are a questioning agent (Early Stage). Your task is to propose highly relevant early-stage questions "
"that can open the differential diagnosis widely. Use epidemiology, demographics, and vague presenting symptoms as guides."
),
tokenizer=tokenizer,
model=model,
)
questioner_middle = RoleAgent(
role_instruction=(
"You are a questioning agent (Middle Stage). Using the current diagnosis, past questions, and patient vignette, "
"generate a specific question to refine the current differential diagnosis. Return your reasoning and next question."
),
tokenizer=tokenizer,
model=model,
)
questioner_late = RoleAgent(
role_instruction=(
"You are a questioning agent (Late Stage). Based on narrowed differentials and previous dialogue, "
"generate a focused question that would help confirm or eliminate the final 1-2 suspected diagnoses."
),
tokenizer=tokenizer,
model=model,
)
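# Example debug trace from an earlier run, kept for reference: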
"""[DEBUG] prompt: Instruction: You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.
Input: Doctor: What brings you in today?
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee?
Previous Vignette:
Output:
Instruction: You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues.
Input: Doctor: What brings you in today?
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee?
Previous Vignette:
Output: The patient is a 15-year-old male presenting with knee pain."""
# === Inference State ===
conversation_history = []
summary = ""
diagnosis = ""
# === Gradio Inference ===
def simulate_interaction(user_input, conversation_history=None):
"""Single turn interaction - no iterations, uses accumulated history"""
if conversation_history is None:
history = [f"Doctor: What brings you in today?", f"Patient: {user_input}"]
else:
history = conversation_history.copy()
history.append(f"Patient: {user_input}")
# Summarize the full conversation history
sum_in = "\n".join(history)
sum_out = summarizer.act(sum_in)
summary = sum_out["output"]
# Diagnose based on summary
diag_out = diagnoser_middle.act(summary)
diagnosis = diag_out["output"]
# Generate next question based on current understanding
q_in = f"Vignette: {summary}\nCurrent Estimated Diagnosis: {diagnosis}"
q_out = questioner_middle.act(q_in)
# Add doctor's response to history
history.append(f"Doctor: {q_out['output']}")
# Generate treatment plan (but don't end conversation)
treatment_out = treatment_agent.act(f"Diagnosis: {diagnosis}\nVignette: {summary}")
return {
"summary": sum_out,
"diagnosis": diag_out,
"question": q_out,
"treatment": treatment_out,
"conversation": history, # Return full history list
}
# === Gradio UI ===
def ui_fn(user_input):
"""Non-stateful version for testing"""
res = simulate_interaction(user_input)
return f"""πŸ“‹ Vignette Summary:
πŸ’­ THINKING: {res['summary']['thinking']}
πŸ“ SUMMARY: {res['summary']['output']}
🩺 Diagnosis:
πŸ’­ THINKING: {res['diagnosis']['thinking']}
πŸ” DIAGNOSIS: {res['diagnosis']['output']}
❓ Follow-up Question:
πŸ’­ THINKING: {res['question']['thinking']}
πŸ‘¨β€βš•οΈ DOCTOR: {res['question']['output']}
πŸ’Š Treatment Plan:
πŸ’­ THINKING: {res['treatment']['thinking']}
πŸ“‹ TREATMENT: {res['treatment']['output']}
πŸ’¬ Full Conversation:
{chr(10).join(res['conversation'])}
"""
# === Stateful Gradio UI ===
def stateful_ui_fn(user_input, history):
"""Proper stateful conversation handler"""
# Initialize history if first interaction
if history is None:
history = []
# Run one turn of interaction with accumulated history
res = simulate_interaction(user_input, history)
# Get the updated conversation history
updated_history = res["conversation"]
# Format the display output
display_output = f"""πŸ’¬ Conversation:
{chr(10).join(updated_history)}
πŸ“‹ Current Assessment:
πŸ” Diagnosis: {res['diagnosis']['output']}
πŸ’Š Treatment Plan: {res['treatment']['output']}
"""
# Return display text and updated history for next turn
return display_output, updated_history
def chat_interface(user_input, history):
"""Alternative chat-style interface"""
if history is None:
history = []
# Run interaction
res = simulate_interaction(user_input, history)
updated_history = res["conversation"]
# Return just the doctor's latest response and updated history
doctor_response = res["question"]["output"]
return doctor_response, updated_history
# Create two different interfaces
demo_stateful = gr.Interface(
fn=stateful_ui_fn,
inputs=[
gr.Textbox(
label="Patient Response",
placeholder="Describe your symptoms or answer the doctor's question...",
),
gr.State(), # holds the conversation history
],
outputs=[
gr.Textbox(label="Medical Consultation", lines=15),
gr.State(), # returns the updated history
],
title="🧠 AI Doctor - Full Medical Consultation",
description="Have a conversation with an AI doctor. Each response builds on the previous conversation.",
)
demo_chat = gr.Interface(
fn=chat_interface,
inputs=[
gr.Textbox(label="Your Response", placeholder="Tell me about your symptoms..."),
gr.State(),
],
outputs=[
gr.Textbox(label="Doctor", lines=5),
gr.State(),
],
title="🩺 AI Doctor Chat",
description="Simple chat interface with the AI doctor.",
)
if __name__ == "__main__":
# Launch the stateful version by default
demo_stateful.launch(share=True)
# Uncomment the line below to use the chat version instead:
# demo_chat.launch(share=True)