# medical-test/app.py
# Loads a clinical LLM via a runtime LFS snapshot download and serves a
# multi-agent doctor simulation through Gradio.

import os

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM

# ——— CONFIG ———
REPO_ID = "CodCodingCode/llama-3.1-8b-clinical"
SUBFOLDER = "checkpoint-45000"
HF_TOKEN = os.environ["HUGGINGFACE_HUB_TOKEN"]  # set in Settings → Secrets

# ——— SNAPSHOT & LOAD ———
# Pull the .json and .safetensors files under checkpoint-45000. Note that
# snapshot_download() takes no `subfolder` argument, so the checkpoint is
# selected via allow_patterns and joined onto the snapshot root below.
local_dir = snapshot_download(
    repo_id=REPO_ID,
    token=HF_TOKEN,
    allow_patterns=[f"{SUBFOLDER}/*.json", f"{SUBFOLDER}/*.safetensors"],
)

# Point at the checkpoint folder inside the snapshot,
# e.g. ~/.cache/huggingface/…/checkpoint-45000:
MODEL_DIR = os.path.join(local_dir, SUBFOLDER)
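
# Optional fail-fast check, a minimal sketch: config.json is assumed to be
# present in any transformers checkpoint, so its absence means the snapshot
# did not fetch what we expected.
if not os.path.exists(os.path.join(MODEL_DIR, "config.json")):
    raise FileNotFoundError(f"config.json not found under {MODEL_DIR}")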

# Load tokenizer & model from the files just pulled:
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_DIR,
    use_fast=False,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model.eval()
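
# Alternative loading path (a sketch, not used by this app): an 8B model in
# float16 needs roughly 16 GB for the weights alone, so on smaller GPUs
# 4-bit quantization via the optional bitsandbytes package is one way to fit it:
#     from transformers import BitsAndBytesConfig
#     model = AutoModelForCausalLM.from_pretrained(
#         MODEL_DIR,
#         device_map="auto",
#         quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#         trust_remote_code=True,
#     )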

# === Role Agent with instruction/input/output format ===
class RoleAgent:
    def __init__(self, role_instruction):
        self.role_instruction = role_instruction

    def act(self, input_text):
        prompt = (
            f"Instruction: {self.role_instruction}\n"
            f"Input: {input_text}\n"
            f"Output:"
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract THINKING / ANSWER if present; otherwise fall back to the raw
        # decoded response. The ANSWER check runs on the THINKING…END block
        # itself so a stray "ANSWER:" elsewhere cannot raise an IndexError.
        thinking, answer = "", response
        if "THINKING:" in response and "END" in response:
            block = response.split("THINKING:")[1].split("END")[0]
            if "ANSWER:" in block:
                thinking = block.split("ANSWER:")[0].strip()
                answer = block.split("ANSWER:")[1].strip()

        return {
            "instruction": self.role_instruction,
            "input": input_text,
            "thinking": thinking,
            "output": answer,
        }
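
# Usage sketch (hypothetical transcript, shown only for the returned shape):
#     agent = RoleAgent("You are a clinical summarizer ...")
#     result = agent.act("Doctor: What brings you in today?\nPatient: ...")
#     result["thinking"]  # text between THINKING: and ANSWER:, if emitted
#     result["output"]    # text after ANSWER:, else the full decoded response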

# === Agents ===
summarizer = RoleAgent(
    "You are a clinical summarizer trained to extract structured vignettes from doctor–patient dialogues."
)
diagnoser = RoleAgent(
    "You are a board-certified diagnostician that diagnoses patients."
)
questioner = RoleAgent("You are a physician asking questions to diagnose a patient.")
treatment_agent = RoleAgent(
    "You are a board-certified clinician. Based on the diagnosis and patient vignette provided below, suggest a concise treatment plan that could realistically be initiated by a primary care physician or psychiatrist."
)

# === Gradio Inference ===
def simulate_interaction(user_input, iterations=1):
    history = ["Doctor: What brings you in today?", f"Patient: {user_input}"]
    summary, diagnosis = "", ""

    for _ in range(iterations):
        # Summarize
        sum_in = "\n".join(history) + f"\nPrevious Vignette: {summary}"
        sum_out = summarizer.act(sum_in)
        summary = sum_out["output"]

        # Diagnose
        diag_out = diagnoser.act(summary)
        diagnosis = diag_out["output"]

        # Question
        q_in = f"Vignette: {summary}\nCurrent Estimated Diagnosis: {diag_out['thinking']} {diagnosis}"
        q_out = questioner.act(q_in)
        history.append(f"Doctor: {q_out['output']}")
        history.append("Patient: (awaiting response)")

    # Treatment
    treatment_out = treatment_agent.act(
        f"Diagnosis: {diagnosis}\nVignette: {summary}"
    )

    return {
        "summary": sum_out,
        "diagnosis": diag_out,
        "question": q_out,
        "treatment": treatment_out,
        "conversation": "\n".join(history),
    }
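
# Example call (sketch; the patient text is a made-up placeholder):
#     res = simulate_interaction("I've had a pounding headache for three days.")
#     res["diagnosis"]["output"]   # the diagnoser's ANSWER text
#     res["conversation"]          # the running Doctor/Patient transcript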

# === Gradio UI ===
def ui_fn(user_input):
    res = simulate_interaction(user_input)
    return f"""📋 Vignette Summary:
💭 THINKING: {res['summary']['thinking']}
ANSWER: {res['summary']['output']}

🩺 Diagnosis:
💭 THINKING: {res['diagnosis']['thinking']}
ANSWER: {res['diagnosis']['output']}

❓ Follow-up Question:
💭 THINKING: {res['question']['thinking']}
ANSWER: {res['question']['output']}

💊 Treatment Plan:
{res['treatment']['output']}

💬 Conversation:
{res['conversation']}
"""

demo = gr.Interface(
    fn=ui_fn,
    inputs=gr.Textbox(label="Patient Response"),
    outputs=gr.Textbox(label="Doctor Simulation Output"),
    title="🧠 AI Doctor Multi-Agent Reasoning",
)
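
# Note (deployment assumption): share=True opens a temporary public
# gradio.live tunnel when this script is run locally; in a hosted Space the
# app is already served publicly, so the flag matters only for local runs.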

if __name__ == "__main__":
    demo.launch(share=True)