|
from models import SBW |
|
import torchaudio |
|
import torch |
|
import os |
|
|
|
def hamming_distance(s1, s2, base=2):
    """
    Sum the absolute differences between two bit strings interpreted in
    fixed-size chunks of ``log2(base)`` bits.

    NOTE: despite the name, this is NOT the classic Hamming distance
    (count of differing positions); it sums ``|int(chunk1) - int(chunk2)|``
    over aligned chunks.

    Args:
        s1 (str): The first binary string (characters '0'/'1').
        s2 (str): The second binary string, same length as ``s1``.
        base (int): Base used to group bits, e.g. 2 for 1-bit chunks,
            16 for 4-bit chunks. Must be >= 2.

    Returns:
        int: The sum of absolute differences for corresponding chunks.

    Raises:
        ValueError: If the strings are not of equal length, ``base`` < 2,
            or the length is not a multiple of the chunk size.
    """
    import math

    if len(s1) != len(s2):
        raise ValueError("Both strings must be of equal length")

    if base < 2:
        # Previously base < 2 produced chunk_size == 0 and a raw
        # ZeroDivisionError from the modulo below; raise the documented error.
        raise ValueError("base must be at least 2")

    # math.log2 is exact for powers of two; math.log(base, 2) computes
    # log(base)/log(2) and float rounding can truncate to chunk_size - 1.
    chunk_size = int(math.log2(base))

    # Lengths are already known equal, so one divisibility check suffices.
    if len(s1) % chunk_size != 0:
        raise ValueError(
            f"Binary strings must be a multiple of {chunk_size} bits for base {base}"
        )

    def to_chunks(binary_str):
        # Split into consecutive chunk_size-wide slices.
        return [
            binary_str[i : i + chunk_size]
            for i in range(0, len(binary_str), chunk_size)
        ]

    return sum(
        abs(int(c1, 2) - int(c2, 2))
        for c1, c2 in zip(to_chunks(s1), to_chunks(s2))
    )
|
|
|
def random_message(nbits: int, batch_size: int) -> torch.Tensor:
    """Return a uniformly random 0/1 message tensor.

    Args:
        nbits: Number of message bits per sample (may be 0).
        batch_size: Number of rows in the output.

    Returns:
        torch.Tensor: int64 tensor of shape (batch_size, nbits) with
        values in {0, 1}.
    """
    # Fix: the old nbits == 0 special case returned a float tensor of
    # shape (0,); torch.randint handles a zero-width dimension directly,
    # keeping shape and dtype consistent for all nbits.
    return torch.randint(0, 2, (batch_size, nbits))
|
|
|
|
|
def string_to_message(message_str: str, batch_size: int) -> torch.Tensor:
    """
    Convert a binary string to a message tensor.

    Args:
        message_str (str): A string of '0's and '1's.
        batch_size (int): The batch size for the output tensor.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, nbits) where nbits is
        the length of message_str; every row is a copy of the bits.

    Raises:
        ValueError: If message_str contains characters other than '0'/'1'.
    """
    # Fix: int(bit) silently accepted digits 2-9 (and raised a confusing
    # error for other characters); validate the alphabet up front.
    if any(ch not in "01" for ch in message_str):
        raise ValueError("message_str must contain only '0' and '1' characters")

    message_list = [int(bit) for bit in message_str]

    # One row of bits, repeated batch_size times.
    return torch.tensor(message_list).unsqueeze(0).repeat(batch_size, 1)
|
|
|
|
|
class WatermarkSolver:
    """Embed and extract binary watermarks in audio with an SBW model.

    Runs on CPU and works at 16 kHz mono; inputs at other sample rates are
    resampled, and only the first channel of multi-channel audio is used.
    """

    def __init__(self):
        self.device = 'cpu'  # CPU-only inference
        self.sample_rate = 16000  # sample rate the model expects
        self.nbits = 16  # watermark payload length in bits
        self.model = SBW()

    def load_model(
        self, checkpoint_dir, checkpoint_name, strict=False
    ):
        """
        Load model weights from ``checkpoint_dir/checkpoint_name``.

        Args:
            checkpoint_dir (str): Directory containing checkpoints.
            checkpoint_name (str): Name of the checkpoint file to load.
            strict (bool): If True, require an exact state-dict match.
                If False, silently drop checkpoint entries that are absent
                from the model or whose shapes do not match.

        Side effects:
            Updates ``self.model`` in place and sets ``self.epoch`` from
            the checkpoint. Prints progress; returns early (without
            raising) if the checkpoint file does not exist.
        """
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        # Fix: verify the requested file itself exists. The old code only
        # checked that the directory contained *some* .pth file, so a
        # missing checkpoint_name crashed inside torch.load.
        if not os.path.isfile(checkpoint_path):
            print(f"Checkpoint file {checkpoint_path} not found.")
            return

        print(f"Loading model weights from {checkpoint_path}...")

        # NOTE(review): weights_only=False unpickles arbitrary objects and
        # can execute code — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)

        model_state_dict = checkpoint["model_state_dict"]

        # Strip any DataParallel/DistributedDataParallel "module." prefix.
        new_state_dict = {
            k.replace("module.", ""): v for k, v in model_state_dict.items()
        }
        if not strict:
            own_state = self.model.state_dict()
            # Keep only entries the model actually has, with matching shapes.
            new_state_dict = {
                k: v
                for k, v in new_state_dict.items()
                if k in own_state and own_state[k].shape == v.shape
            }
        # Fix: honor the `strict` argument (it was hard-coded to False,
        # so strict=True silently behaved like strict=False).
        self.model.load_state_dict(new_state_dict, strict=strict)

        print("Model state dict loaded successfully.")
        self.epoch = checkpoint["epoch"]
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}.")

    def infer_for_ui(self, input_audio_path, watermark, output_audio_path="tmp"):
        """Embed `watermark` (a '0'/'1' string) into an audio file.

        Args:
            input_audio_path (str): Path of the audio file to watermark.
            watermark (str): Bit string to embed.
            output_audio_path (str): Directory for the output file
                (created if missing).

        Returns:
            str: Path of the watermarked .wav file that was written.
        """
        message_str = watermark
        self.model.eval()

        waveform, sample_rate = torchaudio.load(input_audio_path)

        # Resample to the model's expected rate if necessary.
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        # Keep only the first channel and add a channel dim -> (1, 1, T).
        waveform = waveform[:1].to(self.device).unsqueeze(1)
        message = string_to_message(message_str=message_str, batch_size=1).to(
            self.device
        )

        with torch.no_grad():
            output = self.model(
                waveform,
                message=message,
            )
            y_wm = output["recon_wm"]

        os.makedirs(output_audio_path, exist_ok=True)
        watermarked_audio_path = os.path.join(
            output_audio_path,
            f"{os.path.splitext(os.path.basename(input_audio_path))[0]}_watermarked.wav",
        )
        torchaudio.save(watermarked_audio_path, y_wm.squeeze(1).cpu(), self.sample_rate)
        return watermarked_audio_path

    def decode_for_ui(self, input_audio_path):
        """Detect and decode a watermark from an audio file.

        Args:
            input_audio_path (str): Path of the audio file to inspect.

        Returns:
            tuple: (detect_prob, detected_id) where detected_id is the
            recovered message as a string of '0'/'1' characters.
        """
        self.model.eval()

        waveform, sample_rate = torchaudio.load(input_audio_path)

        # Resample to the model's expected rate if necessary.
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)

        # Keep only the first channel and add a channel dim -> (1, 1, T).
        waveform = waveform[:1].to(self.device).unsqueeze(1)

        with torch.no_grad():
            detect_prob, detected_message_tensor, _ = self.model.detect_watermark(waveform)
            detected_id = "".join(
                map(str, detected_message_tensor.squeeze(0).cpu().numpy().tolist())
            )

        return detect_prob,detected_id
|
|