import math
import os

import torch
import torchaudio

from models import SBW


def hamming_distance(s1, s2, base=2):
    """
    Calculate the distance between two binary strings interpreted in chunks of a
    specified base, as the sum of absolute differences between corresponding
    chunks. For base 2 this reduces to the classic Hamming distance.

    Args:
        s1 (str): The first binary string.
        s2 (str): The second binary string.
        base (int): The base to interpret the binary strings in (e.g., 2 for
            binary, 16 for hexadecimal).

    Returns:
        int: The sum of absolute differences for corresponding chunks.

    Raises:
        ValueError: If the strings are not of equal length or are invalid for
            the given base.
    """
    if len(s1) != len(s2):
        raise ValueError("Both strings must be of equal length")

    # Each chunk holds log2(base) bits (1 bit for base 2, 4 bits for base 16).
    chunk_size = int(math.log2(base))
    if len(s1) % chunk_size != 0 or len(s2) % chunk_size != 0:
        raise ValueError(
            f"Binary strings must be a multiple of {chunk_size} bits for base {base}"
        )

    # Split the binary strings into fixed-size chunks.
    def to_chunks(binary_str):
        return [
            binary_str[i : i + chunk_size]
            for i in range(0, len(binary_str), chunk_size)
        ]

    chunks_s1 = to_chunks(s1)
    chunks_s2 = to_chunks(s2)

    # Convert each pair of chunks to integers and take the absolute difference.
    def absolute_difference_chunks(chunk1, chunk2):
        int1 = int(chunk1, 2)
        int2 = int(chunk2, 2)
        return abs(int1 - int2)

    return sum(
        absolute_difference_chunks(c1, c2) for c1, c2 in zip(chunks_s1, chunks_s2)
    )
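
# Example (illustrative): with base=2 each chunk is a single bit, so
#   hamming_distance("0110", "0011", base=2) == 2
# while with base=16 chunks are 4-bit nibbles, so
#   hamming_distance("00010100", "01000001", base=16) == abs(1 - 4) + abs(4 - 1) == 6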


def random_message(nbits: int, batch_size: int) -> torch.Tensor:
    """Return a random 0/1 message tensor of shape (batch_size, nbits)."""
    if nbits == 0:
        # No payload requested: return an empty tensor.
        return torch.tensor([])
    return torch.randint(0, 2, (batch_size, nbits))
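
# Example (illustrative): random_message(16, 2) returns an int64 tensor of
# shape (2, 16) with entries drawn uniformly from {0, 1}.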


def string_to_message(message_str: str, batch_size: int) -> torch.Tensor:
    """
    Convert a binary string to a message tensor.

    Args:
        message_str (str): A string of '0's and '1's.
        batch_size (int): The batch size for the output tensor.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, nbits), where nbits is the
            length of message_str.
    """
    # Convert the string to a list of integers (0 and 1)
    message_list = [int(bit) for bit in message_str]
    nbits = len(message_list)
    # Convert the list to a tensor and repeat it across the batch dimension
    message_tensor = torch.tensor(message_list).unsqueeze(0).repeat(batch_size, 1)
    return message_tensor
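
# Example (illustrative): string_to_message("1010", batch_size=2) returns
#   tensor([[1, 0, 1, 0],
#           [1, 0, 1, 0]])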


class WatermarkSolver:
    def __init__(self):
        self.device = "cpu"
        self.sample_rate = 16000
        self.nbits = 16
        self.model = SBW()

    def load_model(self, checkpoint_dir, checkpoint_name, strict=False):
        """
        Load model weights from the named checkpoint in the checkpoint directory.
        """
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
        if not checkpoint_files:
            print("No checkpoint files found in the directory.")
            return
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        print(f"Loading model weights from {checkpoint_path}...")

        # Load the checkpoint on CPU; weights_only=False because the file also
        # stores metadata such as the epoch number.
        checkpoint = torch.load(
            checkpoint_path, map_location=torch.device("cpu"), weights_only=False
        )

        # Strip the "module." prefix that DataParallel/DistributedDataParallel
        # prepends to parameter names.
        model_state_dict = checkpoint["model_state_dict"]
        new_state_dict = {}
        for k, v in model_state_dict.items():
            new_key = k.replace("module.", "")
            new_state_dict[new_key] = v

        # Unless strict loading is requested, keep only keys that exist in the
        # model with matching shapes (allows missing keys such as semantic_encoder).
        if not strict:
            new_state_dict = {
                k: v
                for k, v in new_state_dict.items()
                if k in self.model.state_dict()
                and self.model.state_dict()[k].shape == v.shape
            }
        self.model.load_state_dict(new_state_dict, strict=False)
        print("Model state dict loaded successfully.")
        self.epoch = checkpoint["epoch"]
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}.")

    def infer_for_ui(self, input_audio_path, watermark, output_audio_path="tmp"):
        message_str = watermark
        self.model.eval()

        # Load the audio and resample to the model's expected rate if needed.
        waveform, sample_rate = torchaudio.load(input_audio_path)
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        # Keep the first channel only and add a channel dimension: (1, 1, T).
        waveform = waveform[:1].to(self.device).unsqueeze(1)
        message = string_to_message(message_str=message_str, batch_size=1).to(
            self.device
        )

        # Embed the watermark message into the waveform.
        with torch.no_grad():
            output = self.model(
                waveform,
                message=message,
            )
            y_wm = output["recon_wm"]

        # Save the watermarked audio under a name derived from the input file.
        os.makedirs(output_audio_path, exist_ok=True)
        watermarked_audio_path = os.path.join(
            output_audio_path,
            f"{os.path.splitext(os.path.basename(input_audio_path))[0]}_watermarked.wav",
        )
        torchaudio.save(watermarked_audio_path, y_wm.squeeze(1).cpu(), self.sample_rate)
        return watermarked_audio_path
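
    # Example (illustrative): infer_for_ui("speech.wav", "1" * 16) writes
    # "tmp/speech_watermarked.wav" and returns that path.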

    def decode_for_ui(self, input_audio_path):
        self.model.eval()

        # Load the audio and resample to the model's expected rate if needed.
        waveform, sample_rate = torchaudio.load(input_audio_path)
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        # Keep the first channel only and add a channel dimension: (1, 1, T).
        waveform = waveform[:1].to(self.device).unsqueeze(1)

        # Detect the watermark and decode the embedded message bits.
        with torch.no_grad():
            detect_prob, detected_message_tensor, _ = self.model.detect_watermark(waveform)

        # Render the decoded bits as a '0'/'1' string.
        detected_id = "".join(
            map(str, detected_message_tensor.squeeze(0).cpu().numpy().tolist())
        )
        return detect_prob, detected_id
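

# Minimal end-to-end sketch (assumptions: a trained checkpoint and an input WAV
# file exist; "checkpoints", "checkpoint.pth", and "input.wav" are placeholder
# names, not part of this module).
if __name__ == "__main__":
    solver = WatermarkSolver()
    solver.load_model("checkpoints", "checkpoint.pth")  # placeholder paths
    wm_path = solver.infer_for_ui("input.wav", "1010101010101010")  # 16-bit message
    prob, decoded = solver.decode_for_ui(wm_path)
    print(f"detect_prob={prob}, decoded message={decoded}")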