import os
import pathlib

import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# ─── CONFIG ───
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical-v1.1"
SUBFOLDER = "checkpoint-2250"
HF_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("Missing HUGGINGFACE_HUB_TOKEN in env")
# ─── 1) Download the full repo ───
local_cache = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
)
print("[DEBUG] snapshot_download → local_cache:", local_cache)
print(
    "[DEBUG] MODEL root contents:",
    list(pathlib.Path(local_cache).glob(f"{SUBFOLDER}/*")),
)
# ─── 2) Repo root contains tokenizer.json; model shards live in the checkpoint subfolder ───
MODEL_DIR = local_cache
MODEL_SUBFOLDER = SUBFOLDER
print("[DEBUG] MODEL_DIR:", MODEL_DIR)
print("[DEBUG] MODEL_DIR files:", os.listdir(MODEL_DIR))
print("[DEBUG] Checkpoint files:", os.listdir(os.path.join(MODEL_DIR, MODEL_SUBFOLDER)))
# ─── 3) Load tokenizer & model from disk ───
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=True,
)
print("[DEBUG] Loaded fast tokenizer object:", tokenizer, "type:", type(tokenizer))

# Confirm tokenizer files are present
print("[DEBUG] Files in MODEL_DIR for tokenizer:", os.listdir(MODEL_DIR))

# Inspect the tokenizer's initialization arguments
try:
    print("[DEBUG] Tokenizer init_kwargs:", tokenizer.init_kwargs)
except AttributeError:
    print("[DEBUG] No init_kwargs attribute on tokenizer.")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    subfolder=MODEL_SUBFOLDER,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
print(
    "[DEBUG] Loaded model object:",
    model.__class__.__name__,
    "device:",
    next(model.parameters()).device,
)
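
# Optional smoke test with a hypothetical prompt, handy when debugging the Space locally:
#   probe = tokenizer("instruction: say hello\ninput:\noutput:", return_tensors="pt").to(model.device)
#   print(tokenizer.decode(model.generate(**probe, max_new_tokens=8)[0], skip_special_tokens=True))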
# === Role Agent with instruction/input/output format ===
class RoleAgent:
    def __init__(self, role_instruction, tokenizer, model):
        self.tokenizer = tokenizer
        self.model = model
        self.role_instruction = role_instruction
    def act(self, input_text):
        # "thinking" stays empty: nothing below fills it in, but it is kept in the
        # return value so callers can display a reasoning slot uniformly.
        thinking = ""

        prompt = (
            f"instruction: {self.role_instruction}\n"
            f"input: {input_text}\n"
            f"output:"
        )
        encoding = self.tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in encoding.items()}
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            temperature=0.3,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"[DEBUG] Generated response: {response}") | |
# Extract only the new generated content after the prompt | |
prompt_length = len(prompt) | |
if len(response) > prompt_length: | |
generated_text = response[prompt_length:].strip() | |
else: | |
generated_text = response.strip() | |
# Clean up the response - remove any repeated instruction/input/output patterns | |
lines = generated_text.split("\n") | |
clean_lines = [] | |
for line in lines: | |
line = line.strip() | |
# Skip lines that look like instruction formatting | |
if ( | |
line.startswith("instruction:") | |
or line.startswith("input:") | |
or line.startswith("output:") | |
or line == "" | |
): | |
continue | |
clean_lines.append(line) | |
# Join the clean lines and take the first substantial response | |
if clean_lines: | |
answer = clean_lines[0] | |
# If there are multiple clean lines, take the first one that's substantial | |
for line in clean_lines: | |
if len(line) > 20: # Arbitrary threshold for substantial content | |
answer = line | |
break | |
else: | |
# Fallback: try to extract after "output:" if present | |
if "output:" in generated_text.lower(): | |
parts = generated_text.lower().split("output:") | |
if len(parts) > 1: | |
answer = parts[-1].strip() | |
else: | |
answer = generated_text | |
else: | |
answer = generated_text | |
# Additional cleanup - remove any remaining instruction artifacts | |
answer = ( | |
answer.replace("instruction:", "") | |
.replace("input:", "") | |
.replace("output:", "") | |
.strip() | |
) | |
        # If the answer is still messy, try to extract the actual medical content
        if "patient" in answer.lower() and len(answer) > 100:
            # Keep only sentences that contain medical keywords
            sentences = answer.split(".")
            medical_sentences = []
            for sentence in sentences:
                sentence = sentence.strip()
                if len(sentence) > 10 and any(
                    word in sentence.lower()
                    for word in [
                        "patient",
                        "pain",
                        "symptom",
                        "diagnosis",
                        "treatment",
                        "knee",
                        "reports",
                        "experiencing",
                    ]
                ):
                    medical_sentences.append(sentence)
            if medical_sentences:
                # Take the first two medical sentences
                answer = ". ".join(medical_sentences[:2])
                if not answer.endswith("."):
                    answer += "."

        print(
            f"[CLEANED RESPONSE] Original length: {len(response)}, Cleaned: '{answer}'"
        )

        # Return both thinking and answer
        return {"thinking": thinking, "output": answer}
# === Agents ===
# ─── Instantiate RoleAgents for each of the eight roles ───
summarizer = RoleAgent(
    role_instruction=(
        "You are a clinical summarizer. Given a transcript of a doctor–patient dialogue, "
        "extract a structured clinical vignette summarizing the key symptoms, relevant history, "
        "and any diagnostic clues."
    ),
    tokenizer=tokenizer,
    model=model,
)
treatment_agent = RoleAgent(
    role_instruction=(
        "You are a board-certified clinician. Based on the provided diagnosis and patient vignette, "
        "propose a realistic, evidence-based treatment plan suitable for initiation by a primary care "
        "physician or psychiatrist."
    ),
    tokenizer=tokenizer,
    model=model,
)
diagnoser_early = RoleAgent(
    role_instruction=(
        "You are a diagnostic reasoning model (Early Stage). Based on the patient vignette and "
        "early-stage observations, generate a list of plausible diagnoses with reasoning. Focus on "
        "broad differentials, considering common and uncommon conditions."
    ),
    tokenizer=tokenizer,
    model=model,
)
diagnoser_middle = RoleAgent(
    role_instruction=(
        "You are a diagnostic reasoning model (Middle Stage). Given the current vignette, prior dialogue, "
        "and diagnostic hypothesis, refine the list of possible diagnoses with concise justifications for each. "
        "Aim to reduce diagnostic uncertainty."
    ),
    tokenizer=tokenizer,
    model=model,
)
diagnoser_late = RoleAgent(
    role_instruction=(
        "You are a diagnostic reasoning model (Late Stage). Based on the final patient vignette summary and full conversation, "
        "provide the most likely diagnosis with structured reasoning. Confirm diagnostic certainty and include END if no more questioning is necessary."
    ),
    tokenizer=tokenizer,
    model=model,
)
questioner_early = RoleAgent(
    role_instruction=(
        "You are a questioning agent (Early Stage). Your task is to propose highly relevant early-stage questions "
        "that can open the differential diagnosis widely. Use epidemiology, demographics, and vague presenting symptoms as guides."
    ),
    tokenizer=tokenizer,
    model=model,
)
questioner_middle = RoleAgent(
    role_instruction=(
        "You are a questioning agent (Middle Stage). Using the current diagnosis, past questions, and patient vignette, "
        "generate a specific question to refine the current differential diagnosis. Return your reasoning and next question."
    ),
    tokenizer=tokenizer,
    model=model,
)
questioner_late = RoleAgent(
    role_instruction=(
        "You are a questioning agent (Late Stage). Based on narrowed differentials and previous dialogue, "
        "generate a focused question that would help confirm or eliminate the final 1-2 suspected diagnoses."
    ),
    tokenizer=tokenizer,
    model=model,
)
"""[DEBUG] prompt: Instruction: You are a clinical summarizer trained to extract structured vignettes from doctorβpatient dialogues. | |
Input: Doctor: What brings you in today? | |
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee? | |
Previous Vignette: | |
Output: | |
Instruction: You are a clinical summarizer trained to extract structured vignettes from doctorβpatient dialogues. | |
Input: Doctor: What brings you in today? | |
Patient: I am a male. I am 15. My knee hurts. What may be the issue with my knee? | |
Previous Vignette: | |
Output: The patient is a 15-year-old male presenting with knee pain.""" | |
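
# The sample output above shows the model echoing the full prompt before its answer,
# which is why RoleAgent.act slices off the prompt text and filters out repeated
# instruction/input/output lines before returning a cleaned response.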
# === Inference State ===
conversation_history = []
summary = ""
diagnosis = ""
# === Gradio Inference ===
def simulate_interaction(user_input, conversation_history=None):
    """Single-turn interaction: no iterations, uses the accumulated history."""
    if conversation_history is None:
        history = ["Doctor: What brings you in today?", f"Patient: {user_input}"]
    else:
        history = conversation_history.copy()
        history.append(f"Patient: {user_input}")

    # Summarize the full conversation history
    sum_in = "\n".join(history)
    sum_out = summarizer.act(sum_in)
    summary = sum_out["output"]

    # Diagnose based on the summary
    diag_out = diagnoser_middle.act(summary)
    diagnosis = diag_out["output"]

    # Generate the next question based on the current understanding
    q_in = f"Vignette: {summary}\nCurrent Estimated Diagnosis: {diagnosis}"
    q_out = questioner_middle.act(q_in)

    # Add the doctor's response to the history
    history.append(f"Doctor: {q_out['output']}")

    # Generate a treatment plan (but do not end the conversation)
    treatment_out = treatment_agent.act(f"Diagnosis: {diagnosis}\nVignette: {summary}")

    return {
        "summary": sum_out,
        "diagnosis": diag_out,
        "question": q_out,
        "treatment": treatment_out,
        "conversation": history,  # return the full history list
    }
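
# Quick (hypothetical) end-to-end check without the UI; the second turn reuses the
# history list returned by the first:
#   first = simulate_interaction("I am 15 and my knee hurts.")
#   second = simulate_interaction("It hurts more when I climb stairs.", first["conversation"])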
# === Gradio UI ===
def ui_fn(user_input):
    """Non-stateful version for testing."""
    res = simulate_interaction(user_input)
    return f"""📋 Vignette Summary:
💭 THINKING: {res['summary']['thinking']}
📝 SUMMARY: {res['summary']['output']}
🩺 Diagnosis:
💭 THINKING: {res['diagnosis']['thinking']}
📝 DIAGNOSIS: {res['diagnosis']['output']}
❓ Follow-up Question:
💭 THINKING: {res['question']['thinking']}
👨‍⚕️ DOCTOR: {res['question']['output']}
💊 Treatment Plan:
💭 THINKING: {res['treatment']['thinking']}
📝 TREATMENT: {res['treatment']['output']}
💬 Full Conversation:
{chr(10).join(res['conversation'])}
"""
# === Stateful Gradio UI ===
def stateful_ui_fn(user_input, history):
    """Proper stateful conversation handler."""
    # Initialize history on the first interaction
    if history is None:
        history = []

    # Run one turn of interaction with the accumulated history
    res = simulate_interaction(user_input, history)

    # Get the updated conversation history
    updated_history = res["conversation"]

    # Format the display output
    display_output = f"""💬 Conversation:
{chr(10).join(updated_history)}
📋 Current Assessment:
📝 Diagnosis: {res['diagnosis']['output']}
💊 Treatment Plan: {res['treatment']['output']}
"""

    # Return the display text and the updated history for the next turn
    return display_output, updated_history
def chat_interface(user_input, history):
    """Alternative chat-style interface."""
    if history is None:
        history = []

    # Run one turn of interaction
    res = simulate_interaction(user_input, history)
    updated_history = res["conversation"]

    # Return just the doctor's latest response and the updated history
    doctor_response = res["question"]["output"]
    return doctor_response, updated_history
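
# Note: in both interfaces below, the gr.State() output receives the updated history
# list and Gradio passes it back in as the `history` argument on the next submission.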
# Create two different interfaces
demo_stateful = gr.Interface(
    fn=stateful_ui_fn,
    inputs=[
        gr.Textbox(
            label="Patient Response",
            placeholder="Describe your symptoms or answer the doctor's question...",
        ),
        gr.State(),  # holds the conversation history
    ],
    outputs=[
        gr.Textbox(label="Medical Consultation", lines=15),
        gr.State(),  # returns the updated history
    ],
    title="🧠 AI Doctor - Full Medical Consultation",
    description="Have a conversation with an AI doctor. Each response builds on the previous conversation.",
)
demo_chat = gr.Interface(
    fn=chat_interface,
    inputs=[
        gr.Textbox(label="Your Response", placeholder="Tell me about your symptoms..."),
        gr.State(),
    ],
    outputs=[
        gr.Textbox(label="Doctor", lines=5),
        gr.State(),
    ],
    title="🩺 AI Doctor Chat",
    description="Simple chat interface with the AI doctor.",
)
if __name__ == "__main__":
    # Launch the stateful version by default
    demo_stateful.launch(share=True)
    # Uncomment the line below to use the chat version instead:
    # demo_chat.launch(share=True)