venkatviswa committed
Commit 1e29a5d · verified · 1 Parent(s): ed8fcd4

fixes to incorporate new metrics

Files changed (1)
  1. healthcare_standards_raft.py +81 -48

healthcare_standards_raft.py CHANGED
@@ -9,67 +9,55 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from llama_index.core import VectorStoreIndex, load_index_from_storage
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from peft import PeftModel, PeftConfig
 
 class HealthcareStandardsRAFT:
     """
     Healthcare Standards RAFT system that combines RAG and LoRA fine-tuning.
     """
-
+
     def __init__(self, model_path=None, device="cuda" if torch.cuda.is_available() else "cpu"):
         """
         Initialize the Healthcare Standards RAFT system.
-
+
         Args:
             model_path: Path to model directory or Hugging Face repo name
             device: Device to use for inference (cuda/cpu)
         """
-        # Use this repo if no model path provided
+        # Handle local fallback if no path is provided
         if model_path is None:
-            model_path = "venkatviswa/healthcare-standards-raft"
-
-        self.device = device
+            model_path = "./healthcare-standards-raft"
+
         self.model_path = model_path
-
+        self.adapter_dir = os.path.join(self.model_path, "model")
+        self.device = device
+
         # Load tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-
-        # Load base model and apply weights
+        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-4-mini-instruct")
+
+        # Load base model and apply LoRA adapter
         self._load_model()
-
+
         # Load vector index for RAG
         self._load_vector_index()
-
+
     def _load_model(self):
-        """Load base model and apply LoRA weights."""
-        # Load base model
+        """Load base model and apply LoRA weights using PEFT."""
+
+        print("Loading base Phi-4-mini model...")
         self.model = AutoModelForCausalLM.from_pretrained(
             "microsoft/phi-4-mini-instruct",
             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
             device_map="auto" if self.device == "cuda" else None
         )
+
+
+        adapter_dir = os.path.join(self.model_path, "model")
+        print(f"Applying LoRA adapter from: {adapter_dir}")
+        self.model = PeftModel.from_pretrained(self.model, adapter_dir)
+
 
-        # Check if model_path is local directory or Hugging Face repo
-        if os.path.isdir(self.model_path):
-            # Local directory
-            adapter_path = os.path.join(self.model_path, "model", "adapter_model.bin")
-        else:
-            # Download adapter weights from Hugging Face
-            from huggingface_hub import hf_hub_download
-            adapter_path = hf_hub_download(
-                repo_id=self.model_path,
-                filename="model/adapter_model.bin"
-            )
-
-        # Load LoRA weights
-        if os.path.exists(adapter_path):
-            # Load the weights using PEFT or direct state dict loading
-            # Implementation depends on your specific LoRA setup
-            weights = torch.load(adapter_path, map_location="cpu")
-            # This is a simplified example - you'll need to adapt this
-            # to your specific LoRA implementation
-            self.model.load_state_dict(weights, strict=False)
-        else:
-            print(f"Warning: Adapter weights not found at {adapter_path}")
+
 
     def _load_vector_index(self):
         """Load the vector index for RAG."""
@@ -78,10 +66,8 @@ class HealthcareStandardsRAFT:
         if os.path.isdir(self.model_path):
             index_path = os.path.join(self.model_path, "vector_index")
         else:
-            # Download index from Hugging Face if needed
-            # This is a simplified implementation
-            index_path = "vector_index"  # Local cached path
-            # You would need to download the index files here
+            index_path = "healthcare-standards-raft/vector_index"  # Local cached path
+
 
         # Create embedding model
         embed_model = HuggingFaceEmbedding(
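
Note that the fallback branch now assumes a local clone of the repo rather than downloading the index. The rest of _load_vector_index is unchanged and not shown here; for reference, reloading a persisted llama_index vector store typically looks like the sketch below. The embedding model name is an assumption (the real one is set in the unchanged code further down), and persist_dir must match where the index was built:

    from llama_index.core import StorageContext, load_index_from_storage, Settings
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    # The query-time embedding model must match the one used at build time;
    # "BAAI/bge-small-en-v1.5" is a placeholder, not taken from this diff.
    Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-en-v1.5")

    storage_context = StorageContext.from_defaults(
        persist_dir="healthcare-standards-raft/vector_index"
    )
    index = load_index_from_storage(storage_context)
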
@@ -163,24 +149,71 @@ Question: {question}
 Answer:"""
 
         # Generate response
-        inputs = self.tokenizer(prompt, return_tensors="pt")
-        if self.device == "cuda":
-            inputs = inputs.to("cuda")
-
+        # Ensure pad token is set
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Tokenize
+        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
+
+        # Create attention mask
+        inputs["attention_mask"] = (inputs["input_ids"] != self.tokenizer.pad_token_id).long()
+
+        # Move to device
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
         # Generate
         with torch.no_grad():
             outputs = self.model.generate(
-                inputs.input_ids,
+                inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
                 max_new_tokens=max_tokens,
                 temperature=temperature,
                 top_p=0.9,
-                do_sample=temperature > 0
+                do_sample=temperature > 0,
+                pad_token_id=self.tokenizer.pad_token_id  # good practice
             )
 
         # Decode response
         response = self.tokenizer.decode(
-            outputs[0][inputs.input_ids.shape[1]:],
+            outputs[0][inputs["input_ids"].shape[1]:],
             skip_special_tokens=True
         )
 
-        return response
+        return response
+
+
+    def get_retrieved_contexts(self, question):
+        """
+        Get the contexts retrieved for a specific question.
+
+        Args:
+            question (str): The question to retrieve contexts for
+
+        Returns:
+            list: List of retrieved context strings
+        """
+        try:
+            if hasattr(self, 'index') and self.index is not None:
+                # Use the retriever to get relevant documents
+                retriever = self.index.as_retriever(similarity_top_k=3)
+                nodes = retriever.retrieve(question)
+
+                # Extract text from the retrieved nodes
+                contexts = []
+                for node in nodes:
+                    if hasattr(node, 'node') and hasattr(node.node, 'get_content'):
+                        contexts.append(node.node.get_content())
+                    elif hasattr(node, 'get_content'):
+                        contexts.append(node.get_content())
+                    elif hasattr(node, 'text'):
+                        contexts.append(node.text)
+                    else:
+                        contexts.append(str(node))
+
+                return contexts
+            else:
+                return ["Vector index not available"]
+        except Exception as e:
+            print(f"Error retrieving contexts: {e}")
+            return [f"Error retrieving contexts: {e}"]
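
The generation changes in the final hunk boil down to three fixes: set a pad token (Phi-style tokenizers often ship without one), pass an explicit attention_mask, and only enable sampling when temperature is positive. The same pattern outside the class, assuming model and tokenizer were loaded as in the earlier sketch; the prompt text is illustrative:

    prompt = "Question: What does HIPAA require for audit logs?\nAnswer:"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,                       # temperature/top_p only apply when sampling
        pad_token_id=tokenizer.pad_token_id,  # silences the missing-pad-token warning
    )
    # Slice off the prompt tokens so only the newly generated answer is decoded
    answer = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
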