|
import torch |
|
import torch.nn as nn |
|
from typing import Dict, Optional, Tuple |
|
|
|
import torch.nn.functional as F |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
class WMDetector(nn.Module):
    """
    Detect watermarks in a feature sequence using a Transformer encoder.

    The ``nbits``-bit message is split into ``nbits // nchunk_size`` chunks of
    ``nchunk_size`` bits each. One learnable token per chunk is prepended to
    the embedded input sequence. After the encoder, every chunk token is
    classified into one of ``2**nchunk_size`` values, and every remaining
    position yields one watermark-presence logit.
    """

    def __init__(
        self, input_channels: int, nbits: int, nchunk_size: int, d_model: int = 512
    ):
        """
        Args:
            input_channels (int): Number of channels of the input feature
                (second dimension of the tensor passed to ``forward``).
            nbits (int): Total number of bits in the watermark; must be a
                multiple of ``nchunk_size``.
            nchunk_size (int): Number of bits per chunk. Each chunk is decoded
                by its own ``2**nchunk_size``-way linear classifier.
            d_model (int): Embedding dimension for the Transformer.

        Raises:
            AssertionError: If ``nbits`` is not a multiple of ``nchunk_size``.
        """
        super().__init__()
        self.nchunk_size = nchunk_size
        assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
        self.nbits = nbits
        self.d_model = d_model
        self.nchunks = nbits // nchunk_size

        # Pointwise (kernel_size=1) projection of the input channels to d_model.
        self.embedding = nn.Conv1d(input_channels, d_model, kernel_size=1)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=1,
                dim_feedforward=d_model * 2,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=8,
        )

        # One logit per time step: is the watermark present at this position?
        self.watermark_head = nn.Linear(d_model, 1)

        # One classifier per chunk over all 2**nchunk_size possible chunk values.
        self.message_heads = nn.ModuleList(
            nn.Linear(d_model, 2**nchunk_size) for _ in range(self.nchunks)
        )

        # Learnable tokens (one per chunk) that are prepended to the sequence.
        self.nchunk_embeddings = nn.Parameter(torch.randn(self.nchunks, d_model))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the detector.

        Args:
            x (torch.Tensor): Input of shape [batch, input_channels, seq_len].

        Returns:
            logits (torch.Tensor): Watermark detection logits of shape
                [batch, seq_len].
            chunk_logits (torch.Tensor): Chunk-level classification logits of
                shape [batch, nchunks, 2**nchunk_size].
        """
        # Unpacking also enforces that the input is 3-dimensional.
        batch_size, _, _ = x.shape

        # [batch, channels, time] -> [batch, time, d_model]
        x = self.embedding(x).permute(0, 2, 1)

        # Prepend the chunk tokens: [batch, nchunks + time, d_model]
        chunk_tokens = self.nchunk_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        x = torch.cat([chunk_tokens, x], dim=1)

        x = self.transformer(x)

        # Positions after the chunk tokens correspond to the input time steps.
        logits = self.watermark_head(x[:, self.nchunks :]).squeeze(-1)

        # The i-th chunk token is decoded by the i-th head.
        chunk_logits = torch.stack(
            [head(x[:, i, :]) for i, head in enumerate(self.message_heads)], dim=1
        )

        return logits, chunk_logits

    def detect_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        threshold: float = 0.5,
    ) -> Tuple[float, torch.Tensor, torch.Tensor]:
        """
        A convenience function for inference.

        Args:
            x (torch.Tensor): Input of shape [batch, input_channels, seq_len].
            sample_rate (Optional[int]): Currently unused; kept so existing
                callers keep working.
            threshold (float): Currently unused; kept so existing callers
                keep working.

        Returns:
            detect_prob (float): Mean of the per-timestep sigmoid scores. For a
                batch with more than one item this is the average over the
                whole batch; use ``detected.mean(dim=-1)`` for per-item values.
            binary_message (torch.Tensor): The recovered message of shape
                [batch, nbits] (0/1 integers). Within each chunk the bits are
                ordered least-significant first.
            detected (torch.Tensor): The sigmoid values of the per-timestep
                watermark detection, shape [batch, seq_len].
        """
        logits, chunk_logits = self.forward(x)

        detected = torch.sigmoid(logits)
        # ``.item()`` only works on a single-element tensor, so reduce the
        # per-item means to one scalar first. For batch size 1 this is the
        # same value as the per-item mean.
        detect_prob = detected.mean(dim=-1).mean().cpu().item()

        # Softmax is monotonic, so the argmax of the logits is the predicted
        # chunk value; no need to normalise first.
        chunk_indices = torch.argmax(chunk_logits, dim=-1)  # [batch, nchunks]

        # Unpack every chunk value into its bits (LSB first) in one broadcasted
        # operation: [batch, nchunks, 1] >> [nchunk_size].
        shifts = torch.arange(self.nchunk_size, device=chunk_indices.device)
        bits = (chunk_indices.unsqueeze(-1) >> shifts) & 1

        # Flatten chunk-major / bit-minor into [batch, nbits].
        binary_message = bits.reshape(chunk_indices.shape[0], self.nbits)

        return detect_prob, binary_message, detected
|
|
|
|
|
|
|
class WMEmbedder(nn.Module):
    """
    Inject a secret message into a hidden feature sequence.

    The message is split into ``nbits // nchunk_size`` chunks; each chunk value
    is looked up in its own embedding table, giving a short sequence of
    watermark tokens. A TransformerDecoder then cross-attends from the
    projected hidden sequence (target) to those tokens (memory), and the result
    is projected back and added to the original hidden sequence.
    """

    def __init__(
        self,
        nbits: int,
        input_dim: int,
        nchunk_size: int,
        hidden_dim: int = 256,
        num_heads: int = 1,
        num_layers: int = 4,
    ):
        """
        Args:
            nbits (int): Total number of message bits; must be a multiple of
                ``nchunk_size``.
            input_dim (int): Channel dimension of the hidden sequence.
            nchunk_size (int): Number of bits per chunk.
            hidden_dim (int): Model dimension used inside the decoder.
            num_heads (int): Number of attention heads in each decoder layer.
            num_layers (int): Number of decoder layers.

        Raises:
            AssertionError: If ``nbits`` is not a multiple of ``nchunk_size``.
        """
        super().__init__()
        self.nchunk_size = nchunk_size
        assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
        self.nbits = nbits
        self.nchunks = nbits // nchunk_size

        # One embedding table per chunk position, indexed by the chunk value.
        self.msg_embeddings = nn.ModuleList(
            nn.Embedding(2**nchunk_size, hidden_dim) for _ in range(self.nchunks)
        )

        self.input_projection = nn.Linear(input_dim, hidden_dim)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=2 * hidden_dim,
            activation="gelu",
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_layers
        )

        self.output_projection = nn.Linear(hidden_dim, input_dim)

    def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: [batch, input_dim, seq_len]
            msg: [batch, nbits] tensor of 0/1 values. Any dtype is accepted
                and converted to int64 before packing; columns beyond
                ``nbits`` are ignored.

        Returns:
            A tensor [batch, input_dim, seq_len] with watermark injected.

        Raises:
            ValueError: If ``hidden`` is not 3-D, or ``msg`` is not 2-D or
                has fewer than ``nbits`` columns.
        """
        # Unpacking enforces a 3-D input (raises ValueError otherwise).
        _batch, _in_dim, _seq_len = hidden.shape

        if msg.dim() != 2 or msg.size(1) < self.nbits:
            raise ValueError(
                f"msg must have shape [batch, {self.nbits}], got {tuple(msg.shape)}"
            )

        # [batch, input_dim, seq_len] -> [batch, seq_len, hidden_dim]
        hidden_projected = self.input_projection(hidden.permute(0, 2, 1))

        # Pack the bits of every chunk into an integer, least-significant bit
        # first. Casting to int64 up front makes float/bool/narrow-int
        # messages work and rules out overflow in small integer dtypes.
        bits = (
            msg[:, : self.nbits]
            .long()
            .reshape(msg.size(0), self.nchunks, self.nchunk_size)
        )
        bit_weights = torch.tensor(
            [1 << k for k in range(self.nchunk_size)], device=msg.device
        )
        chunk_vals = (bits * bit_weights).sum(dim=-1)  # [batch, nchunks]

        # Look up chunk i in table i -> memory of shape [batch, nchunks, hidden_dim]
        chunk_emb_seq = torch.stack(
            [emb(chunk_vals[:, i]) for i, emb in enumerate(self.msg_embeddings)],
            dim=1,
        )

        # Cross-attend from the hidden sequence to the watermark tokens.
        x_decoded = self.transformer_decoder(
            tgt=hidden_projected,
            memory=chunk_emb_seq,
        )

        # Back to [batch, input_dim, seq_len], then add as a residual.
        x_output = self.output_projection(x_decoded).permute(0, 2, 1)
        return x_output + hidden
|
|
|
|
|
from speechtokenizer import SpeechTokenizer |
|
|
|
|
|
class SBW(nn.Module):
    """
    Speech watermarking model.

    Combines a pretrained ``SpeechTokenizer`` with a ``WMEmbedder`` (passed to
    the tokenizer as ``msg_processor``) and a ``WMDetector`` that runs on the
    tokenizer's features.
    """

    def __init__(
        self,
        nbits: int = 16,
        nchunk_size: int = 4,
        feature_dim: int = 1024,
        config_path: str = (
            "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json"
        ),
        ckpt_path: str = "speechtokenizer/pretrained_model/SpeechTokenizer.pt",
    ):
        """
        Args:
            nbits (int): Number of message bits; shared by embedder and detector.
            nchunk_size (int): Bits per message chunk; shared by embedder and
                detector.
            feature_dim (int): Channel dimension given to the embedder
                (``input_dim``) and the detector (``input_channels``).
                NOTE(review): presumably this must match the feature size
                produced by the SpeechTokenizer checkpoint - confirm.
            config_path (str): Path to the SpeechTokenizer config file.
            ckpt_path (str): Path to the SpeechTokenizer checkpoint.
        """
        super().__init__()
        self.nbits = nbits
        self.st_model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
        self.msg_processor = WMEmbedder(
            nbits=nbits,
            input_dim=feature_dim,
            nchunk_size=nchunk_size,
        )
        self.detector = WMDetector(
            feature_dim,
            nbits,
            nchunk_size=nchunk_size,
        )

    def detect_watermark(self, x: torch.Tensor, return_logits: bool = False) -> Tuple:
        """
        Run the detector on the tokenizer features of ``x``.

        Args:
            x (torch.Tensor): Input passed to ``st_model.forward_feature``.
            return_logits (bool): Selects which detector output is returned.

        Returns:
            If ``return_logits`` is true, the raw detector output
            ``(logits, chunk_logits)``. Otherwise the result of
            ``WMDetector.detect_watermark``:
            ``(detect_prob, binary_message, detected)``.
        """
        embedding = self.st_model.forward_feature(x)
        if return_logits:
            return self.detector(embedding)
        return self.detector.detect_watermark(embedding)

    def forward(
        self,
        speech_input: torch.Tensor,
        message: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Reconstruct ``speech_input`` with and without the watermark.

        Args:
            speech_input (torch.Tensor): Input forwarded to the tokenizer.
            message (Optional[torch.Tensor]): Message forwarded to the
                tokenizer together with the message embedder.

        Returns:
            A dict with keys ``"recon"`` and ``"recon_wm"``, both cut along the
            last dimension to the shorter of the input length and the
            watermarked reconstruction's length.
        """
        # The tokenizer returns four values; the last two are not used here.
        recon, recon_wm, _acoustic, _acoustic_wm = self.st_model(
            speech_input, msg_processor=self.msg_processor, message=message
        )
        wav_length = min(speech_input.size(-1), recon_wm.size(-1))
        return {
            "recon": recon[..., :wav_length],
            "recon_wm": recon_wm[..., :wav_length],
        }
|
|