techindia2025 committed
Commit c66e1bd · verified · 1 Parent(s): 7dd1c93

Update app.py

Files changed (1)
app.py +88 -100
app.py CHANGED
@@ -1,12 +1,7 @@
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  import spaces
- from langchain_community.llms import HuggingFacePipeline
- from langchain_core.prompts import PromptTemplate
- from langchain.chains import LLMChain
- from langchain_core.runnables import RunnableWithMessageHistory
- from langchain.memory import ConversationBufferMemory

  # Model configuration
  LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
@@ -38,105 +33,110 @@ Patient information: {patient_info}
  <|im_start|>assistant
  """

- # Track conversation turns
+ # Global variables to store models (will be loaded lazily)
+ llama_model = None
+ llama_tokenizer = None
+ meditron_model = None
+ meditron_tokenizer = None
  conversation_turns = 0
  patient_data = []

- # Create a GPU-decorated function for model loading
- @spaces.GPU
- def load_models():
-     print("Loading Llama-2 model...")
-     llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
-     llama_model = AutoModelForCausalLM.from_pretrained(
-         LLAMA_MODEL,
-         torch_dtype=torch.float16,
-         device_map="auto"
-     )
-
-     # Create a pipeline for LangChain
-     llama_pipeline = pipeline(
-         "text-generation",
-         model=llama_model,
-         tokenizer=llama_tokenizer,
-         max_new_tokens=512,
-         temperature=0.7,
-         top_p=0.9,
-         do_sample=True
-     )
-     llama_llm = HuggingFacePipeline(pipeline=llama_pipeline)
-     print("Llama-2 model loaded successfully!")
-
-     print("Loading Meditron model...")
-     meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
-     meditron_model = AutoModelForCausalLM.from_pretrained(
-         MEDITRON_MODEL,
-         torch_dtype=torch.float16,
-         device_map="auto"
-     )
-     # Create a pipeline for Meditron
-     meditron_pipeline = pipeline(
-         "text-generation",
-         model=meditron_model,
-         tokenizer=meditron_tokenizer,
-         max_new_tokens=256,
-         temperature=0.7,
-         top_p=0.9,
-         do_sample=True
-     )
-     meditron_llm = HuggingFacePipeline(pipeline=meditron_pipeline)
-     print("Meditron model loaded successfully!")
-
-     return llama_llm, meditron_llm, llama_tokenizer, meditron_tokenizer
-
- # Load models
- llama_llm, meditron_llm, llama_tokenizer, meditron_tokenizer = load_models()
+ def build_llama2_prompt(system_prompt, history, user_input):
+     """Format the conversation history and user input for Llama-2 chat models."""
+     prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+
+     # Add conversation history
+     for user_msg, assistant_msg in history:
+         prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
+
+     # Add the current user input
+     prompt += f"{user_input} [/INST] "
+
+     return prompt

- # Create LangChain conversation with memory
- memory = ConversationBufferMemory(return_messages=True)
+ @spaces.GPU
+ def load_models_if_needed():
+     """Load models only when GPU is available and only if not already loaded."""
+     global llama_model, llama_tokenizer, meditron_model, meditron_tokenizer
+
+     if llama_model is None:
+         print("Loading Llama-2 model...")
+         llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL)
+         llama_model = AutoModelForCausalLM.from_pretrained(
+             LLAMA_MODEL,
+             torch_dtype=torch.float16,
+             device_map="auto"
+         )
+         print("Llama-2 model loaded successfully!")
+
+     if meditron_model is None:
+         print("Loading Meditron model...")
+         meditron_tokenizer = AutoTokenizer.from_pretrained(MEDITRON_MODEL)
+         meditron_model = AutoModelForCausalLM.from_pretrained(
+             MEDITRON_MODEL,
+             torch_dtype=torch.float16,
+             device_map="auto"
+         )
+         print("Meditron model loaded successfully!")

- # Create a template for the Meditron model
- meditron_template = PromptTemplate(
-     input_variables=["patient_info"],
-     template=MEDITRON_PROMPT
- )
- meditron_chain = LLMChain(
-     llm=meditron_llm,
-     prompt=meditron_template,
-     verbose=True
- )
+ @spaces.GPU
+ def get_meditron_suggestions(patient_info):
+     """Use Meditron model to generate medicine and remedy suggestions."""
+     load_models_if_needed() # Ensure models are loaded
+
+     prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
+     inputs = meditron_tokenizer(prompt, return_tensors="pt")
+
+     # Move inputs to the same device as the model
+     if torch.cuda.is_available():
+         inputs = {k: v.to(meditron_model.device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         outputs = meditron_model.generate(
+             inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
+             max_new_tokens=256,
+             temperature=0.7,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=meditron_tokenizer.eos_token_id
+         )
+
+     suggestion = meditron_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+     return suggestion

  @spaces.GPU
  def generate_response(message, history):
+     """Generate a response using both models."""
      global conversation_turns, patient_data
+
+     # Load models if needed
+     load_models_if_needed()
+
+     # Track conversation turns
      conversation_turns += 1

-     # Store patient message
+     # Store the entire conversation for reference
      patient_data.append(message)

-     # Format the prompt with system instructions
-     if conversation_turns >= 4:
-         # Add summarization instruction after 4 turns
-         prompt = f"{SYSTEM_PROMPT}\n\nNow summarize what you've learned and suggest when professional care may be needed.\n\n{message}"
-     else:
-         prompt = f"{SYSTEM_PROMPT}\n\n{message}"
-
      # Build the prompt with proper Llama-2 formatting
-     formatted_prompt = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
+     prompt = build_llama2_prompt(SYSTEM_PROMPT, history, message)

-     # Add conversation history
-     for user_msg, assistant_msg in history:
-         formatted_prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
+     # Add summarization instruction after 4 turns
+     if conversation_turns >= 4:
+         prompt = prompt.replace("[/INST] ", "[/INST] Now summarize what you've learned and suggest when professional care may be needed. ")

-     # Add the current user input
-     formatted_prompt += f"{message} [/INST] "
+     inputs = llama_tokenizer(prompt, return_tensors="pt")

-     # Generate response using Llama model
-     inputs = llama_tokenizer(formatted_prompt, return_tensors="pt").to("cuda")
+     # Move inputs to the same device as the model
+     if torch.cuda.is_available():
+         inputs = {k: v.to(llama_model.device) for k, v in inputs.items()}

+     # Generate the Llama-2 response
      with torch.no_grad():
-         outputs = llama_llm.pipeline.model.generate(
-             inputs.input_ids,
-             attention_mask=inputs.attention_mask,
+         outputs = llama_model.generate(
+             inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
              max_new_tokens=512,
              temperature=0.7,
              top_p=0.9,
@@ -153,20 +153,8 @@ def generate_response(message, history):
      # Collect full patient conversation
      full_patient_info = "\n".join(patient_data) + "\n\nSummary: " + llama_response

-     # Get medicine suggestions using Meditron
-     inputs = meditron_tokenizer(MEDITRON_PROMPT.format(patient_info=full_patient_info), return_tensors="pt").to("cuda")
-
-     with torch.no_grad():
-         outputs = meditron_llm.pipeline.model.generate(
-             inputs.input_ids,
-             attention_mask=inputs.attention_mask,
-             max_new_tokens=256,
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True
-         )
-
-     medicine_suggestions = meditron_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+     # Get medicine suggestions
+     medicine_suggestions = get_meditron_suggestions(full_patient_info)

      # Format final response
      final_response = (
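
The displayed hunks stop before the bottom of app.py, so the Gradio UI wiring is not visible in this diff. As a rough sketch only (not part of the commit; the `demo` name, title, and description text below are assumptions), the updated generate_response function would typically be hooked up to a chat interface like this:

# Hypothetical wiring, assuming the standard gradio.ChatInterface API; the history it
# passes as (user, assistant) pairs matches how build_llama2_prompt iterates over it.
demo = gr.ChatInterface(
    fn=generate_response,
    title="Medical Assistant",
    description="Describe your symptoms; after a few turns the assistant summarizes and suggests next steps.",
)

if __name__ == "__main__":
    demo.launch()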