Dhruv-Ty committed on
Commit 8b507f5 · verified · 1 Parent(s): 67192a5

Update app.py

Files changed (1):
app.py +17 -29
app.py CHANGED
@@ -1,11 +1,10 @@
 import os
 import warnings
 from typing import *
-from dotenv import load_dotenv
-from transformers import logging
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from langchain.llms import HuggingFacePipeline
 
 from langgraph.checkpoint.memory import MemorySaver
-from langchain_openai import ChatOpenAI
 
 from interface import create_demo
 from medrax.agent import *
@@ -13,21 +12,6 @@ from medrax.tools import *
 from medrax.utils import *
 
 warnings.filterwarnings("ignore")
-logging.set_verbosity_error()
-load_dotenv()
-
-# Set environment variables explicitly to ensure they're available
-api_key = os.getenv("OPENAI_API_KEY")
-base_url = os.getenv("OPENAI_BASE_URL")
-
-if not api_key:
-    raise ValueError("OPENAI_API_KEY not found in environment variables")
-if not base_url:
-    raise ValueError("OPENAI_BASE_URL not found in environment variables")
-
-# Set them in environment for libraries that might read directly from os.environ
-os.environ["OPENAI_API_KEY"] = api_key
-os.environ["OPENAI_BASE_URL"] = base_url
 
 def initialize_agent(
     prompt_file,
@@ -35,7 +19,6 @@ def initialize_agent(
     model_dir="./model-weights",
     temp_dir="temp",
     device="cuda",
-    model="qwen/qwen2.5-vl-3b-instruct:free",
     temperature=0.7,
     top_p=0.95
 ):
@@ -69,16 +52,24 @@ def initialize_agent(
         tools_dict[tool_name] = all_tools[tool_name]()
 
     checkpointer = MemorySaver()
-
-    # Explicitly pass the API key and base URL
-    model = ChatOpenAI(
-        model_name=model,
-        api_key=api_key,
-        base_url=base_url,
+
+    # Load local Hugging Face model
+    hf_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+    tokenizer = AutoTokenizer.from_pretrained(hf_model_id)
+    raw_model = AutoModelForCausalLM.from_pretrained(hf_model_id, device_map="auto")
+
+    pipe = pipeline(
+        "text-generation",
+        model=raw_model,
+        tokenizer=tokenizer,
+        max_new_tokens=512,
         temperature=temperature,
         top_p=top_p,
+        return_full_text=False,
     )
 
+    model = HuggingFacePipeline(pipeline=pipe)
+
     agent = Agent(
         model,
         tools=list(tools_dict.values()),
@@ -113,12 +104,9 @@ if __name__ == "__main__":
         model_dir="./model-weights",
         temp_dir="temp",
         device="cuda",
-        model="qwen/qwen2.5-vl-3b-instruct:free",
         temperature=0.7,
         top_p=0.95
     )
 
     demo = create_demo(agent, tools_dict)
-    # demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
-    # demo.launch(debug=True, queue=True, ssr_mode=False)
-    demo.launch(debug=True, ssr_mode=False)
+    demo.launch(debug=True, ssr_mode=False)
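
Review note: `from langchain.llms import HuggingFacePipeline` is the legacy import path; newer LangChain releases ship this class in the separate `langchain-huggingface` package, and `device_map="auto"` needs `accelerate` installed. A minimal sketch of an equivalent setup under those assumptions (parameter names per `HuggingFacePipeline.from_model_id`; illustrative, not the committed code):

    # Sketch only: same local-model setup via the langchain-huggingface package
    # (assumes langchain-huggingface and accelerate are installed).
    from langchain_huggingface import HuggingFacePipeline

    model = HuggingFacePipeline.from_model_id(
        model_id="mistralai/Mistral-7B-Instruct-v0.2",
        task="text-generation",
        model_kwargs={"device_map": "auto"},  # forwarded to from_pretrained
        pipeline_kwargs={
            "max_new_tokens": 512,
            "do_sample": True,  # temperature/top_p only take effect when sampling
            "temperature": 0.7,
            "top_p": 0.95,
            "return_full_text": False,
        },
    )

    # Quick smoke test before handing the model to the agent.
    print(model.invoke("List two common findings on a chest X-ray."))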
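
A second caveat: the replaced ChatOpenAI object was a chat model, so if medrax.agent.Agent binds tools or expects chat-style messages, a plain text-generation LLM may not be a drop-in replacement. One possible bridge, again assuming `langchain-huggingface` is available (hypothetical usage, not part of this commit):

    # Sketch: wrap the pipeline LLM in a chat-model interface so the model's
    # chat template is applied to message-style inputs (hypothetical usage).
    from langchain_huggingface import ChatHuggingFace

    chat_model = ChatHuggingFace(llm=model)
    # Then pass chat_model (instead of model) to Agent(...).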