import os
import warnings
from typing import Any, Dict, List, Optional
from dotenv import load_dotenv
from transformers import logging
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from interface import create_demo
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *
warnings.filterwarnings("ignore")
logging.set_verbosity_error()
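# Load environment variables from a local .env file, if present; OPENAI_API_KEY
# (and optionally OPENAI_BASE_URL) are read from the environment further below.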
_ = load_dotenv()
def initialize_agent(
    prompt_file: str,
    tools_to_use: Optional[List[str]] = None,
    model_dir: str = "./model-weights",
    temp_dir: str = "temp",
    device: str = "cuda",
    model: str = "gpt-4o-mini",
    temperature: float = 0.7,
    top_p: float = 0.95,
    openai_kwargs: Optional[Dict[str, Any]] = None,
):
"""Initialize the MedRAX agent with specified tools and configuration.
Args:
prompt_file (str): Path to file containing system prompts
tools_to_use (List[str], optional): List of tool names to initialize. If None, all tools are initialized.
        model_dir (str, optional): Directory containing model weights. Defaults to "./model-weights".
temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
device (str, optional): Device to run models on. Defaults to "cuda".
        model (str, optional): OpenAI model to use. Defaults to "gpt-4o-mini".
temperature (float, optional): Temperature for the model. Defaults to 0.7.
top_p (float, optional): Top P for the model. Defaults to 0.95.
        openai_kwargs (dict, optional): Additional keyword arguments passed to the OpenAI client, such as an API key or base URL. Defaults to None.
Returns:
Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
"""
prompts = load_prompts_from_file(prompt_file)
prompt = prompts["MEDICAL_ASSISTANT"]
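    # Map tool names to lazy constructors so only the requested tools (and their weights) are loaded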
all_tools = {
"ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
"ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
"LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
"XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
"ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
cache_dir=model_dir, device=device
),
"XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
),
"ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
),
"ImageVisualizerTool": lambda: ImageVisualizerTool(),
"DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
}
# Initialize only selected tools or all if none specified
tools_dict = {}
tools_to_use = tools_to_use or all_tools.keys()
for tool_name in tools_to_use:
if tool_name in all_tools:
tools_dict[tool_name] = all_tools[tool_name]()
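    # In-memory checkpointer so the agent can keep conversation state across turns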
checkpointer = MemorySaver()
    model = ChatOpenAI(
        model=model, temperature=temperature, top_p=top_p, **(openai_kwargs or {})
    )
agent = Agent(
model,
tools=list(tools_dict.values()),
log_tools=True,
log_dir="logs",
system_prompt=prompt,
checkpointer=checkpointer,
)
print("Agent initialized")
return agent, tools_dict
if __name__ == "__main__":
"""
This is the main entry point for the MedRAX application.
It initializes the agent with the selected tools and creates the demo.
"""
print("Starting server...")
# Example: initialize with only specific tools
    # Three of the heavier tools are commented out; uncomment them to enable them
selected_tools = [
"ImageVisualizerTool",
"DicomProcessorTool",
"ChestXRayClassifierTool",
"ChestXRaySegmentationTool",
"ChestXRayReportGeneratorTool",
"XRayVQATool",
# "LlavaMedTool",
# "XRayPhraseGroundingTool",
# "ChestXRayGeneratorTool",
]
    # Collect OpenAI credentials and optional custom endpoint from environment variables
openai_kwargs = {}
if api_key := os.getenv("OPENAI_API_KEY"):
openai_kwargs["api_key"] = api_key
if base_url := os.getenv("OPENAI_BASE_URL"):
openai_kwargs["base_url"] = base_url
agent, tools_dict = initialize_agent(
"medrax/docs/system_prompts.txt",
tools_to_use=selected_tools,
model_dir="./model-weights", # Change this to the path of the model weights
temp_dir="temp", # Change this to the path of the temporary directory
device="cuda", # Change this to the device you want to use
model="gpt-4o-mini", # Change this to the model you want to use, e.g. gpt-4o-mini
temperature=0.7,
top_p=0.95,
openai_kwargs=openai_kwargs
)
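    # Build the chat interface around the agent and its tools, then launch it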
demo = create_demo(agent, tools_dict)
# demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
demo.launch(debug=True, ssr_mode=False)