# VoiceMark / infer.py
# Uploaded by ordinaryaccount — commit 26f400e ("Add file")
from models import SBW
import torchaudio
import torch
import os
def hamming_distance(s1, s2, base=2):
    """
    Sum the absolute differences between two binary strings, chunk by chunk.

    Both strings are split into chunks of ``log2(base)`` bits. Each chunk is
    read as an unsigned integer, and the absolute differences of corresponding
    chunks are summed. With the default ``base=2`` every chunk is a single
    bit, so the result is the ordinary Hamming distance.

    Args:
        s1 (str): The first binary string.
        s2 (str): The second binary string.
        base (int): Chunk radix. Must be a power of two >= 2
            (e.g. 2 for bitwise, 16 for 4-bit hexadecimal digits).

    Returns:
        int: The sum of absolute differences for corresponding chunks.

    Raises:
        ValueError: If the strings differ in length, ``base`` is not a power
            of two >= 2, the length is not a multiple of the chunk size, or a
            chunk is not a valid base-2 integer literal.
    """
    if len(s1) != len(s2):
        raise ValueError("Both strings must be of equal length")
    # Only powers of two map onto a whole number of bits per chunk. A float
    # log2 would silently truncate other bases (base=3 -> 1 bit) and would
    # divide by zero for base=1.
    if not isinstance(base, int) or base < 2 or base & (base - 1):
        raise ValueError(f"base must be a power of two >= 2, got {base!r}")
    chunk_size = base.bit_length() - 1
    # The lengths are equal, so checking one string is enough.
    if len(s1) % chunk_size != 0:
        raise ValueError(
            f"Binary strings must be a multiple of {chunk_size} bits for base {base}"
        )
    return sum(
        abs(int(s1[i : i + chunk_size], 2) - int(s2[i : i + chunk_size], 2))
        for i in range(0, len(s1), chunk_size)
    )
def random_message(nbits: int, batch_size: int) -> torch.Tensor:
    """
    Draw a uniformly random binary message for each batch element.

    Args:
        nbits (int): Bits per message. 0 yields ``torch.tensor([])``.
        batch_size (int): Number of messages (rows) to draw.

    Returns:
        torch.Tensor: A (batch_size, nbits) tensor of 0/1 integers, or the
        empty 1-D ``torch.tensor([])`` when ``nbits == 0``.
    """
    if nbits != 0:
        shape = (batch_size, nbits)
        # `high` is exclusive, so every entry is drawn from {0, 1}.
        return torch.randint(low=0, high=2, size=shape)
    return torch.tensor([])
def string_to_message(message_str: str, batch_size: int) -> torch.Tensor:
    """
    Convert a binary string to a message tensor.

    Args:
        message_str (str): A string of '0's and '1's.
        batch_size (int): Number of rows. The same message is repeated in each.

    Returns:
        torch.Tensor: An int64 tensor of shape (batch_size, nbits), where
        nbits is ``len(message_str)``.

    Raises:
        ValueError: If ``message_str`` contains anything other than '0'/'1'.
    """
    # int() alone would accept '2'..'9' and silently embed non-binary values.
    invalid = set(message_str) - {"0", "1"}
    if invalid:
        raise ValueError(
            f"message_str must contain only '0' and '1', got {sorted(invalid)}"
        )
    bits = [int(bit) for bit in message_str]
    # An explicit dtype keeps an empty message int64 instead of float32.
    return torch.tensor(bits, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
class WatermarkSolver:
    """Embed and detect binary watermarks in audio files with an ``SBW`` model.

    Audio is processed on ``self.device`` ('cpu') at ``self.sample_rate``
    (16 kHz). Inputs at other sample rates are resampled first.
    """

    def __init__(self):
        self.device = 'cpu'
        self.sample_rate = 16000
        # Presumably the message length the model was trained for. It is not
        # enforced anywhere in this class — TODO confirm against SBW.
        self.nbits = 16
        self.model = SBW()

    def load_model(
        self, checkpoint_dir, checkpoint_name, strict=False
    ):
        """
        Load model weights from ``checkpoint_dir/checkpoint_name``.

        The checkpoint must be a dict with ``"model_state_dict"`` and
        ``"epoch"`` entries. A single leading ``"module."`` (as added by
        DataParallel-style wrappers) is stripped from each parameter name.

        Args:
            checkpoint_dir (str): Directory holding ``.pth`` checkpoints.
            checkpoint_name (str): File name of the checkpoint to load.
            strict (bool): If False, drop checkpoint entries whose name or
                shape does not match the model, and tolerate missing keys
                (e.g. a semantic_encoder that is not stored in the
                checkpoint). If True, load the state dict as-is; any mismatch
                raises.

        Side effects:
            Sets ``self.epoch`` from the checkpoint and prints progress.
            Returns early without loading if the directory has no ``.pth``
            file.
        """
        checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pth")]
        if not checkpoint_files:
            print("No checkpoint files found in the directory.")
            return
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)
        print(f"Loading model weights from {checkpoint_path}...")
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(
            checkpoint_path, map_location=torch.device('cpu'), weights_only=False
        )
        # Strip only a *leading* "module." prefix. A plain str.replace would
        # also mangle names such as "submodule.weight" -> "subweight".
        prefix = "module."
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint["model_state_dict"].items()
        }
        if not strict:
            # Build the reference state dict once, not twice per key.
            model_state = self.model.state_dict()
            new_state_dict = {
                k: v
                for k, v in new_state_dict.items()
                if k in model_state and model_state[k].shape == v.shape
            }
        # Honor the caller's `strict` flag; it was previously always False.
        self.model.load_state_dict(new_state_dict, strict=strict)
        print("Model state dict loaded successfully.")
        self.epoch = checkpoint["epoch"]
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}.")

    def _load_waveform(self, input_audio_path):
        """Load audio, resample to ``self.sample_rate``, and keep the first channel.

        Returns a tensor of shape (1, 1, T) on ``self.device``.
        """
        waveform, sample_rate = torchaudio.load(input_audio_path)
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.sample_rate
            )
            waveform = resampler(waveform)
        # torchaudio returns (channels, frames). Keep channel 0 and add the
        # channel axis the model consumes.
        return waveform[:1].to(self.device).unsqueeze(1)

    def infer_for_ui(self, input_audio_path, watermark, output_audio_path="tmp"):
        """
        Embed ``watermark`` into an audio file and save the result as WAV.

        Args:
            input_audio_path (str): Path of the audio to watermark.
            watermark (str): Binary string of '0'/'1' characters to embed.
            output_audio_path (str): Output *directory*, created if missing.

        Returns:
            str: Path ``<output_audio_path>/<input stem>_watermarked.wav``.
        """
        self.model.eval()
        waveform = self._load_waveform(input_audio_path)
        message = string_to_message(message_str=watermark, batch_size=1).to(
            self.device
        )
        with torch.no_grad():
            output = self.model(
                waveform,
                message=message,
            )
        y_wm = output["recon_wm"]
        os.makedirs(output_audio_path, exist_ok=True)
        stem = os.path.splitext(os.path.basename(input_audio_path))[0]
        watermarked_audio_path = os.path.join(
            output_audio_path,
            f"{stem}_watermarked.wav",
        )
        # Drop the channel axis added for the model before writing to disk.
        torchaudio.save(watermarked_audio_path, y_wm.squeeze(1).cpu(), self.sample_rate)
        return watermarked_audio_path

    def decode_for_ui(self, input_audio_path):
        """
        Detect a watermark in an audio file.

        Args:
            input_audio_path (str): Path of the audio to inspect.

        Returns:
            tuple: ``(detect_prob, detected_id)``. ``detect_prob`` is the first
            value returned by ``self.model.detect_watermark``. ``detected_id``
            is the detected message's elements joined into one string.
        """
        self.model.eval()
        waveform = self._load_waveform(input_audio_path)
        with torch.no_grad():
            detect_prob, detected_message_tensor, _ = self.model.detect_watermark(waveform)
        detected_id = "".join(
            map(str, detected_message_tensor.squeeze(0).cpu().numpy().tolist())
        )
        return detect_prob, detected_id