"""
Monkey patch for Finetrainers to support loading existing LoRA weights as training initialization.
This patch extends the SFTTrainer to accept a --pretrained_lora_path argument that allows
starting training from existing LoRA weights instead of random initialization.
"""
import json
import logging
from pathlib import Path
from typing import Callable, Optional

import safetensors.torch
from peft import set_peft_model_state_dict

logger = logging.getLogger(__name__)
# Global flag to track if the patch has been applied
_PATCH_APPLIED = False

# Reference to the original method; set when apply_lora_loading_patch() runs
original_prepare_trainable_parameters: Optional[Callable] = None

def _load_pretrained_lora_weights(self, lora_path: str) -> None:
    """Load existing LoRA weights as training initialization.

    Args:
        lora_path: Path to a directory containing pytorch_lora_weights.safetensors.
    """
    lora_path = Path(lora_path)

    # Find the safetensors file
    safetensors_file = lora_path / "pytorch_lora_weights.safetensors"
    if not safetensors_file.exists():
        raise FileNotFoundError(f"LoRA weights file not found: {safetensors_file}")

    logger.info(f"Loading pretrained LoRA weights from: {safetensors_file}")
    try:
        # Load the LoRA weights
        lora_state_dict = safetensors.torch.load_file(str(safetensors_file))

        # Extract metadata if available: a safetensors file begins with an
        # 8-byte little-endian header length, followed by a JSON header that
        # may contain a "__metadata__" entry
        metadata = {}
        try:
            with open(safetensors_file, 'rb') as f:
                header_size = int.from_bytes(f.read(8), 'little')
                header = json.loads(f.read(header_size).decode('utf-8'))
                metadata = header.get('__metadata__', {})
        except Exception as e:
            logger.debug(f"Could not read metadata from safetensors: {e}")

        # Log metadata info if available
        if metadata:
            logger.info(f"LoRA metadata: rank={metadata.get('rank', 'unknown')}, "
                        f"alpha={metadata.get('lora_alpha', 'unknown')}")

        # Apply the LoRA weights to the model
        set_peft_model_state_dict(self.transformer, lora_state_dict)
        logger.info(f"Successfully loaded LoRA weights from {safetensors_file}")

        # Log the loaded keys for debugging
        logger.debug(f"Loaded LoRA keys: {list(lora_state_dict.keys())}")
    except Exception as e:
        logger.error(f"Failed to load LoRA weights from {safetensors_file}: {e}")
        raise RuntimeError(f"Failed to load LoRA weights: {e}") from e

def patched_prepare_trainable_parameters(self) -> None:
    """Patched _prepare_trainable_parameters that supports pretrained LoRA loading."""
    # Call the original method first so the LoRA adapters exist on the model
    # before any weights are loaded into them
    original_prepare_trainable_parameters(self)

    # Check whether a pretrained LoRA path was provided
    if hasattr(self.args, 'pretrained_lora_path') and self.args.pretrained_lora_path:
        logger.info(f"Pretrained LoRA path specified: {self.args.pretrained_lora_path}")
        # Only load for LoRA training; compare via str() so this module does
        # not need to import finetrainers' TrainingType enum
        if hasattr(self.args, 'training_type') and str(self.args.training_type) == 'TrainingType.LORA':
            self._load_pretrained_lora_weights(self.args.pretrained_lora_path)
        else:
            logger.warning("pretrained_lora_path specified but training_type is not LORA")

def apply_lora_loading_patch() -> None:
    """Apply the monkey patch to enable LoRA weight loading in Finetrainers."""
    global _PATCH_APPLIED
    if _PATCH_APPLIED:
        logger.debug("Finetrainers LoRA loading patch already applied")
        return

    try:
        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer

        # Store a reference to the original method so the patched version can
        # delegate to it and so it can be restored later
        global original_prepare_trainable_parameters
        original_prepare_trainable_parameters = SFTTrainer._prepare_trainable_parameters

        # Apply patches
        SFTTrainer._prepare_trainable_parameters = patched_prepare_trainable_parameters
        SFTTrainer._load_pretrained_lora_weights = _load_pretrained_lora_weights

        _PATCH_APPLIED = True
        logger.info("Successfully applied Finetrainers LoRA loading patch")
    except ImportError as e:
        logger.error(f"Failed to import Finetrainers classes for patching: {e}")
        raise
    except Exception as e:
        logger.error(f"Failed to apply Finetrainers LoRA loading patch: {e}")
        raise

def remove_lora_loading_patch() -> None:
    """Remove the monkey patch (for testing purposes)."""
    global _PATCH_APPLIED
    if not _PATCH_APPLIED:
        return

    try:
        from finetrainers.trainer.sft_trainer.trainer import SFTTrainer

        # Restore the original method
        SFTTrainer._prepare_trainable_parameters = original_prepare_trainable_parameters

        # Remove the added method
        if hasattr(SFTTrainer, '_load_pretrained_lora_weights'):
            delattr(SFTTrainer, '_load_pretrained_lora_weights')

        _PATCH_APPLIED = False
        logger.info("Removed Finetrainers LoRA loading patch")
    except Exception as e:
        logger.error(f"Failed to remove Finetrainers LoRA loading patch: {e}")
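
if __name__ == "__main__":
    # Minimal smoke-test sketch: requires finetrainers to be installed, and
    # only verifies that the patch can be applied and removed cleanly.
    logging.basicConfig(level=logging.DEBUG)
    apply_lora_loading_patch()
    remove_lora_loading_patch()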