fixes to incorporate new metrics
healthcare_standards_raft.py  (+81 -48)
CHANGED
@@ -9,67 +9,55 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from llama_index.core import VectorStoreIndex, load_index_from_storage
 from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from peft import PeftModel, PeftConfig
 
 class HealthcareStandardsRAFT:
     """
     Healthcare Standards RAFT system that combines RAG and LoRA fine-tuning.
     """
 
     def __init__(self, model_path=None, device="cuda" if torch.cuda.is_available() else "cpu"):
         """
         Initialize the Healthcare Standards RAFT system.
 
         Args:
             model_path: Path to model directory or Hugging Face repo name
             device: Device to use for inference (cuda/cpu)
         """
-        #
+        # Handle local fallback if no path is provided
         if model_path is None:
-            model_path = "
-
-        self.device = device
+            model_path = "./healthcare-standards-raft"
+
         self.model_path = model_path
-
+        self.adapter_dir = os.path.join(self.model_path, "model")
+        self.device = device
+
         # Load tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(
-
-        # Load base model and apply
+        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-4-mini-instruct")
+
+        # Load base model and apply LoRA adapter
         self._load_model()
-
+
         # Load vector index for RAG
         self._load_vector_index()
-
+
     def _load_model(self):
-        """Load base model and apply LoRA weights."""
-
+        """Load base model and apply LoRA weights using PEFT."""
+
+        print("Loading base Phi-4-mini model...")
         self.model = AutoModelForCausalLM.from_pretrained(
             "microsoft/phi-4-mini-instruct",
             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
             device_map="auto" if self.device == "cuda" else None
         )
 
-        if os.path.isdir(self.model_path):
-            # Local directory
-            adapter_path = os.path.join(self.model_path, "model", "adapter_model.bin")
-        else:
-            # Download adapter weights from Hugging Face
-            from huggingface_hub import hf_hub_download
-            adapter_path = hf_hub_download(
-                repo_id=self.model_path,
-                filename="model/adapter_model.bin"
-            )
-
-        # Load LoRA weights
-        if os.path.exists(adapter_path):
-            # Load the weights using PEFT or direct state dict loading
-            # Implementation depends on your specific LoRA setup
-            weights = torch.load(adapter_path, map_location="cpu")
-            # This is a simplified example - you'll need to adapt this
-            # to your specific LoRA implementation
-            self.model.load_state_dict(weights, strict=False)
-        else:
-            print(f"Warning: Adapter weights not found at {adapter_path}")
+        adapter_dir = os.path.join(self.model_path, "model")
+        print(f"Applying LoRA adapter from: {adapter_dir}")
+        self.model = PeftModel.from_pretrained(self.model, adapter_dir)
 
     def _load_vector_index(self):
         """Load the vector index for RAG."""
@@ -78,10 +66,8 @@ class HealthcareStandardsRAFT:
         if os.path.isdir(self.model_path):
             index_path = os.path.join(self.model_path, "vector_index")
         else:
-            index_path = "vector_index"  # Local cached path
-            # You would need to download the index files here
+            index_path = "healthcare-standards-raft/vector_index"  # Local cached path
 
         # Create embedding model
         embed_model = HuggingFaceEmbedding(
@@ -163,24 +149,71 @@ Question: {question}
 Answer:"""
 
         # Generate response
-        if self.
-
+        # Ensure pad token is set
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        # Tokenize
+        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
+
+        # Create attention mask
+        inputs["attention_mask"] = (inputs["input_ids"] != self.tokenizer.pad_token_id).long()
+
+        # Move to device
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
         # Generate
         with torch.no_grad():
             outputs = self.model.generate(
-                inputs
+                inputs["input_ids"],
+                attention_mask=inputs["attention_mask"],
                 max_new_tokens=max_tokens,
                 temperature=temperature,
                 top_p=0.9,
-                do_sample=temperature > 0
+                do_sample=temperature > 0,
+                pad_token_id=self.tokenizer.pad_token_id  # good practice
             )
 
         # Decode response
         response = self.tokenizer.decode(
-            outputs[0][inputs
+            outputs[0][inputs["input_ids"].shape[1]:],
             skip_special_tokens=True
         )
 
-        return response
+        return response
+
+    def get_retrieved_contexts(self, question):
+        """
+        Get the contexts retrieved for a specific question.
+
+        Args:
+            question (str): The question to retrieve contexts for
+
+        Returns:
+            list: List of retrieved context strings
+        """
+        try:
+            if hasattr(self, 'index') and self.index is not None:
+                # Use the retriever to get relevant documents
+                retriever = self.index.as_retriever(similarity_top_k=3)
+                nodes = retriever.retrieve(question)
+
+                # Extract text from the retrieved nodes
+                contexts = []
+                for node in nodes:
+                    if hasattr(node, 'node') and hasattr(node.node, 'get_content'):
+                        contexts.append(node.node.get_content())
+                    elif hasattr(node, 'get_content'):
+                        contexts.append(node.get_content())
+                    elif hasattr(node, 'text'):
+                        contexts.append(node.text)
+                    else:
+                        contexts.append(str(node))
+
+                return contexts
+            else:
+                return ["Vector index not available"]
+        except Exception as e:
+            print(f"Error retrieving contexts: {e}")
+            return [f"Error retrieving contexts: {e}"]
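
For orientation, a minimal usage sketch of the class after this change (not part of the commit). It assumes the file is importable as healthcare_standards_raft and that a local ./healthcare-standards-raft directory holds the model/ LoRA adapter and vector_index/ store that __init__ and _load_vector_index expect; the question string is a made-up example.

from healthcare_standards_raft import HealthcareStandardsRAFT  # assumed module name

# With model_path=None the constructor now falls back to "./healthcare-standards-raft".
raft = HealthcareStandardsRAFT(model_path="./healthcare-standards-raft")

# New helper added in this commit: inspect what the retriever returns for a question
# before any answer is generated.
question = "Which HL7 FHIR resource carries patient demographics?"  # hypothetical example
for i, context in enumerate(raft.get_retrieved_contexts(question), start=1):
    print(f"--- context {i} ---")
    print(context[:300])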