buttercrab committed on
Commit e43723b · unverified · 1 Parent(s): 0f25f79
Files changed (8)
  1. app.py +2 -2
  2. dia/__init__.py +6 -0
  3. dia/audio.py +0 -22
  4. dia/config.py +96 -113
  5. dia/layers.py +411 -127
  6. dia/model.py +559 -168
  7. dia/state.py +82 -69
  8. requirements.txt +3 -2
app.py CHANGED
@@ -16,7 +16,7 @@ from dia.model import Dia
16
  print("Loading Nari model...")
17
  try:
18
  # Use the function from inference.py
19
- model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
20
  except Exception as e:
21
  print(f"Error loading Nari model: {e}")
22
  raise
@@ -375,4 +375,4 @@ if __name__ == "__main__":
375
 
376
  # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
377
  # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
378
- demo.launch()
 
16
  print("Loading Nari model...")
17
  try:
18
  # Use the function from inference.py
19
+ model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")
20
  except Exception as e:
21
  print(f"Error loading Nari model: {e}")
22
  raise
 
375
 
376
  # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
377
  # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
378
+ demo.launch()
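The only functional change in app.py is the checkpoint and precision used at startup: the app now loads nari-labs/Dia-1.6B-0626 in float16 instead of nari-labs/Dia-1.6B in float32. As a minimal sketch, the same loading call can be exercised outside Gradio; the generate() call and its prompt below are illustrative assumptions, not part of this diff (see dia/model.py in this commit for the actual API):

from dia.model import Dia

# Same loading call as the updated app.py: new 0626 checkpoint, half precision.
model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")

# Hypothetical usage; the [S1]/[S2] speaker tags match the text handling in dia/model.py.
audio = model.generate("[S1] Dia now loads in float16. [S2] Nice.")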
dia/__init__.py CHANGED
@@ -0,0 +1,6 @@
1
+ from .model import Dia
2
+
3
+
4
+ __all__ = [
5
+ "Dia",
6
+ ]
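dia/__init__.py previously had no exports; this change re-exports the model class at the package root, so the shorter import now works alongside the existing one:

from dia import Dia          # new, enabled by this __init__.py
# from dia.model import Dia  # the pre-existing import path still works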
dia/audio.py CHANGED
@@ -179,25 +179,3 @@ def revert_audio_delay(
179
  ) # Changed np.where to torch.where
180
 
181
  return result_BxTxC
182
-
183
-
184
- @torch.no_grad()
185
- @torch.inference_mode()
186
- def decode(
187
- model,
188
- audio_codes,
189
- ):
190
- """
191
- Decodes the given frames into an output audio waveform
192
- """
193
- if len(audio_codes) != 1:
194
- raise ValueError(f"Expected one frame, got {len(audio_codes)}")
195
-
196
- try:
197
- audio_values = model.quantizer.from_codes(audio_codes)
198
- audio_values = model.decode(audio_values[0])
199
-
200
- return audio_values
201
- except Exception as e:
202
- print(f"Error in decode method: {str(e)}")
203
- raise
 
179
  ) # Changed np.where to torch.where
180
 
181
  return result_BxTxC
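The standalone decode() helper at the end of dia/audio.py is removed by this commit, and dia/model.py correspondingly drops its import of it (see that file's diff below). For reference, the deleted helper condensed to the following, where model is a loaded dac.DAC instance; this is a restatement of the removed code, not new behavior:

import torch

@torch.inference_mode()
def decode(model, audio_codes):
    """Decode one frame of DAC codebook indices into a waveform (as removed above)."""
    if len(audio_codes) != 1:
        raise ValueError(f"Expected one frame, got {len(audio_codes)}")
    audio_values = model.quantizer.from_codes(audio_codes)
    return model.decode(audio_values[0])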
dia/config.py CHANGED
@@ -14,149 +14,132 @@ Key components:
14
  """
15
 
16
  import os
17
- from typing import Annotated
18
 
19
- from pydantic import BaseModel, BeforeValidator, Field
20
-
21
-
22
- class DataConfig(BaseModel, frozen=True):
23
- """Configuration for data loading and preprocessing.
24
-
25
- Attributes:
26
- text_length: Maximum length of text sequences (must be multiple of 128).
27
- audio_length: Maximum length of audio sequences (must be multiple of 128).
28
- channels: Number of audio channels.
29
- text_pad_value: Value used for padding text sequences.
30
- audio_eos_value: Value representing the end of audio sequences.
31
- audio_bos_value: Value representing the beginning of audio sequences.
32
- audio_pad_value: Value used for padding audio sequences.
33
- delay_pattern: List of delay values for each audio channel.
34
- """
35
-
36
- text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
37
- Field(gt=0, multiple_of=128)
38
- )
39
- audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
40
- Field(gt=0, multiple_of=128)
41
- )
42
- channels: int = Field(default=9, gt=0, multiple_of=1)
43
- text_pad_value: int = Field(default=0)
44
- audio_eos_value: int = Field(default=1024)
45
- audio_pad_value: int = Field(default=1025)
46
- audio_bos_value: int = Field(default=1026)
47
- delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
48
- default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
49
- )
50
-
51
- def __hash__(self) -> int:
52
- """Generate a hash based on all fields of the config."""
53
- return hash(
54
- (
55
- self.text_length,
56
- self.audio_length,
57
- self.channels,
58
- self.text_pad_value,
59
- self.audio_pad_value,
60
- self.audio_bos_value,
61
- self.audio_eos_value,
62
- tuple(self.delay_pattern),
63
- )
64
- )
65
 
66
 
67
  class EncoderConfig(BaseModel, frozen=True):
68
  """Configuration for the encoder component of the Dia model.
69
 
70
  Attributes:
71
- n_layer: Number of transformer layers.
72
- n_embd: Embedding dimension.
73
- n_hidden: Hidden dimension size in the MLP layers.
74
- n_head: Number of attention heads.
75
- head_dim: Dimension per attention head.
 
 
 
 
 
 
 
 
 
76
  """
77
 
78
- n_layer: int = Field(gt=0)
79
- n_embd: int = Field(gt=0)
80
- n_hidden: int = Field(gt=0)
81
- n_head: int = Field(gt=0)
82
- head_dim: int = Field(gt=0)
 
 
 
 
 
 
 
 
 
83
 
84
 
85
  class DecoderConfig(BaseModel, frozen=True):
86
  """Configuration for the decoder component of the Dia model.
87
 
88
  Attributes:
89
- n_layer: Number of transformer layers.
90
- n_embd: Embedding dimension.
91
- n_hidden: Hidden dimension size in the MLP layers.
92
- gqa_query_heads: Number of query heads for grouped-query self-attention.
93
- kv_heads: Number of key/value heads for grouped-query self-attention.
94
- gqa_head_dim: Dimension per query head for grouped-query self-attention.
95
- cross_query_heads: Number of query heads for cross-attention.
96
- cross_head_dim: Dimension per cross-attention head.
 
 
 
 
 
 
 
 
 
 
 
97
  """
98
 
99
- n_layer: int = Field(gt=0)
100
- n_embd: int = Field(gt=0)
101
- n_hidden: int = Field(gt=0)
102
- gqa_query_heads: int = Field(gt=0)
103
- kv_heads: int = Field(gt=0)
104
- gqa_head_dim: int = Field(gt=0)
105
- cross_query_heads: int = Field(gt=0)
106
- cross_head_dim: int = Field(gt=0)
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
109
- class ModelConfig(BaseModel, frozen=True):
110
  """Main configuration container for the Dia model architecture.
111
 
112
  Attributes:
 
 
113
  encoder: Configuration for the encoder component.
114
  decoder: Configuration for the decoder component.
115
  src_vocab_size: Size of the source (text) vocabulary.
116
  tgt_vocab_size: Size of the target (audio code) vocabulary.
117
- dropout: Dropout probability applied within the model.
118
- normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
119
- weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
120
- rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
121
- rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
 
 
 
 
 
 
122
  """
123
 
124
- encoder: EncoderConfig
125
- decoder: DecoderConfig
126
- src_vocab_size: int = Field(default=128, gt=0)
127
- tgt_vocab_size: int = Field(default=1028, gt=0)
128
- dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
129
- normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
130
- weight_dtype: str = Field(default="float32", description="Weight precision")
131
- rope_min_timescale: int = Field(
132
- default=1, description="Timescale For global Attention"
133
  )
134
- rope_max_timescale: int = Field(
135
- default=10_000, description="Timescale For global Attention"
 
 
136
  )
137
-
138
-
139
- class TrainingConfig(BaseModel, frozen=True):
140
- pass
141
-
142
-
143
- class DiaConfig(BaseModel, frozen=True):
144
- """Master configuration for the Dia model.
145
-
146
- Combines all sub-configurations into a single validated object.
147
-
148
- Attributes:
149
- version: Configuration version string.
150
- model: Model architecture configuration.
151
- training: Training process configuration (precision settings).
152
- data: Data loading and processing configuration.
153
- """
154
-
155
- version: str = Field(default="1.0")
156
- model: ModelConfig
157
- # TODO: remove training. this is just for backwards-compatability
158
- training: TrainingConfig
159
- data: DataConfig
160
 
161
  def save(self, path: str) -> None:
162
  """Save the current configuration instance to a JSON file.
 
14
  """
15
 
16
  import os
 
17
 
18
+ from pydantic import BaseModel, Field
 
19
 
20
 
21
  class EncoderConfig(BaseModel, frozen=True):
22
  """Configuration for the encoder component of the Dia model.
23
 
24
  Attributes:
25
+ model_type: Type of the model, defaults to "dia_encoder".
26
+ hidden_size: Size of the encoder layers, defaults to 1024.
27
+ intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the encoder, defaults to 4096.
28
+ num_hidden_layers: Number of hidden layers in the encoder, defaults to 12.
29
+ num_attention_heads: Number of attention heads in the encoder, defaults to 16.
30
+ num_key_value_heads: Number of key-value heads in the encoder, defaults to 16.
31
+ head_dim: Dimension of each attention head, defaults to 128.
32
+ hidden_act: Activation function in the encoder, defaults to "silu".
33
+ max_position_embeddings: Maximum number of position embeddings, defaults to 1024.
34
+ initializer_range: Range for initializing weights, defaults to 0.02.
35
+ norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
36
+ rope_theta: Theta value for RoPE, defaults to 10000.0.
37
+ rope_scaling: Optional scaling factor for RoPE.
38
+ vocab_size: Vocabulary size, defaults to 256.
39
  """
40
 
41
+ head_dim: int = Field(default=128, gt=0)
42
+ hidden_act: str = Field(default="silu")
43
+ hidden_size: int = Field(default=1024, gt=0)
44
+ initializer_range: float = Field(default=0.02)
45
+ intermediate_size: int = Field(default=4096, gt=0)
46
+ max_position_embeddings: int = Field(default=1024, gt=0)
47
+ model_type: str = Field(default="dia_encoder")
48
+ norm_eps: float = Field(default=1e-5)
49
+ num_attention_heads: int = Field(default=16, gt=0)
50
+ num_hidden_layers: int = Field(default=12, gt=0)
51
+ num_key_value_heads: int = Field(default=16, gt=0)
52
+ rope_scaling: float | None = Field(default=None)
53
+ rope_theta: float = Field(default=10000.0)
54
+ vocab_size: int = Field(default=256, gt=0)
55
 
56
 
57
  class DecoderConfig(BaseModel, frozen=True):
58
  """Configuration for the decoder component of the Dia model.
59
 
60
  Attributes:
61
+ model_type: Type of the model, defaults to "dia_decoder".
62
+ hidden_size: Size of the decoder layers, defaults to 2048.
63
+ intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the decoder, defaults to 8192.
64
+ num_hidden_layers: Number of hidden layers in the decoder, defaults to 18.
65
+ num_attention_heads: Number of attention heads in the decoder, defaults to 16.
66
+ num_key_value_heads: Number of key-value heads in the decoder, defaults to 4.
67
+ head_dim: Dimension of each attention head, defaults to 128.
68
+ cross_hidden_size: Size of the cross-attention layers, defaults to 1024.
69
+ cross_num_attention_heads: Number of attention heads in the cross-attention mechanism, defaults to 16.
70
+ cross_num_key_value_heads: Number of key-value heads in the cross-attention mechanism, defaults to 16.
71
+ cross_head_dim: Dimension of each cross-attention head, defaults to 128.
72
+ hidden_act: Activation function in the decoder, defaults to "silu".
73
+ max_position_embeddings: Maximum number of position embeddings in the decoder, defaults to 3072.
74
+ initializer_range: Range for initializing weights in the decoder, defaults to 0.02.
75
+ norm_eps: Epsilon value for normalization layers in the decoder, defaults to 1e-5.
76
+ rope_theta: Theta value for RoPE in the decoder, defaults to 10000.0.
77
+ rope_scaling: Optional scaling factor for RoPE in the decoder.
78
+ vocab_size: Vocabulary size for the decoder, defaults to 1028.
79
+ num_channels: Number of channels in the decoder, defaults to 9.
80
  """
81
 
82
+ cross_head_dim: int = Field(default=128, gt=0)
83
+ cross_hidden_size: int = Field(default=1024, gt=0)
84
+ cross_num_attention_heads: int = Field(default=16, gt=0)
85
+ cross_num_key_value_heads: int = Field(default=16, gt=0)
86
+ head_dim: int = Field(default=128, gt=0)
87
+ hidden_act: str = Field(default="silu")
88
+ hidden_size: int = Field(default=2048, gt=0)
89
+ initializer_range: float = Field(default=0.02)
90
+ intermediate_size: int = Field(default=8192, gt=0)
91
+ max_position_embeddings: int = Field(default=3072, gt=0)
92
+ model_type: str = Field(default="dia_decoder")
93
+ norm_eps: float = Field(default=1e-5)
94
+ num_attention_heads: int = Field(default=16, gt=0)
95
+ num_channels: int = Field(default=9, gt=0)
96
+ num_hidden_layers: int = Field(default=18, gt=0)
97
+ num_key_value_heads: int = Field(default=4, gt=0)
98
+ rope_scaling: float | None = Field(default=None)
99
+ rope_theta: float = Field(default=10000.0)
100
+ vocab_size: int = Field(default=1028, gt=0)
101
 
102
 
103
+ class DiaConfig(BaseModel, frozen=True):
104
  """Main configuration container for the Dia model architecture.
105
 
106
  Attributes:
107
+ model_type: Type of the model, defaults to "dia".
108
+ is_encoder_decoder: Flag indicating if the model is an encoder-decoder type, defaults to True.
109
  encoder: Configuration for the encoder component.
110
  decoder: Configuration for the decoder component.
111
  src_vocab_size: Size of the source (text) vocabulary.
112
  tgt_vocab_size: Size of the target (audio code) vocabulary.
113
+ initializer_range: Range for initializing weights, defaults to 0.02.
114
+ norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
115
+ torch_dtype: Data type for model weights in PyTorch, defaults to "float32".
116
+ bos_token_id: Beginning-of-sequence token ID, defaults to 1026.
117
+ eos_token_id: End-of-sequence token ID, defaults to 1024.
118
+ pad_token_id: Padding token ID, defaults to 1025.
119
+ rope_theta: Theta value for RoPE, defaults to 10000.0.
120
+ rope_scaling: Optional scaling factor for RoPE.
121
+ transformers_version: Version of the transformers library, defaults to "4.53.0.dev0".
122
+ architectures: List of model architectures, defaults to ["DiaForConditionalGeneration"].
123
+ delay_pattern: List of delay values for each audio channel, defaults to [0,8,9,10,11,12,13,14,15].
124
  """
125
 
126
+ architectures: list[str] = Field(
127
+ default_factory=lambda: ["DiaForConditionalGeneration"]
 
 
 
 
 
 
 
128
  )
129
+ bos_token_id: int = Field(default=1026)
130
+ decoder_config: DecoderConfig
131
+ delay_pattern: list[int] = Field(
132
+ default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
133
  )
134
+ encoder_config: EncoderConfig
135
+ eos_token_id: int = Field(default=1024)
136
+ initializer_range: float = Field(default=0.02)
137
+ is_encoder_decoder: bool = Field(default=True)
138
+ model_type: str = Field(default="dia")
139
+ norm_eps: float = Field(default=1e-5)
140
+ pad_token_id: int = Field(default=1025)
141
+ torch_dtype: str = Field(default="float32")
142
+ transformers_version: str = Field(default="4.53.0.dev0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  def save(self, path: str) -> None:
145
  """Save the current configuration instance to a JSON file.
dia/layers.py CHANGED
@@ -1,10 +1,11 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
4
  from torch import Tensor
5
  from torch.nn import RMSNorm
6
 
7
- from .config import DiaConfig
8
  from .state import DecoderInferenceState, EncoderInferenceState, KVCache
9
 
10
 
@@ -15,12 +16,10 @@ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
15
  class DenseGeneral(nn.Module):
16
  """
17
  PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
18
-
19
  Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
20
  for the generalized matrix multiplication. Weight/bias shapes are calculated
21
  and parameters created during initialization based on config.
22
  `load_weights` validates shapes and copies data.
23
-
24
  Attributes:
25
  axis (Tuple[int, ...]): Input axis or axes to contract.
26
  in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
@@ -46,7 +45,6 @@ class DenseGeneral(nn.Module):
46
 
47
  factory_kwargs = {"device": device, "dtype": weight_dtype}
48
  self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
49
- self.register_parameter("bias", None)
50
 
51
  def forward(self, inputs: Tensor) -> Tensor:
52
  norm_axis = _normalize_axes(self.axis, inputs.ndim)
@@ -112,53 +110,112 @@ class RotaryEmbedding(nn.Module):
112
  self.embedding_dims = embedding_dims
113
  self.min_timescale = min_timescale
114
  self.max_timescale = max_timescale
115
- self.dtype = dtype
116
 
117
  half_embedding_dim = embedding_dims // 2
118
  fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
119
- self.register_buffer(
120
- "timescale",
121
- self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
122
- persistent=False,
123
- )
124
-
125
- def extra_repr(self) -> str:
126
- s = f"{self.timescale.shape}"
127
- return s
128
 
129
  def forward(self, inputs: torch.Tensor, position: torch.Tensor):
130
  """Applies RoPE."""
131
  position = position.unsqueeze(-1).unsqueeze(-1)
132
- timescale = self.timescale.to(inputs.device)
133
- sinusoid_inp = position / timescale
134
- sin = torch.sin(sinusoid_inp).to(inputs.dtype)
135
- cos = torch.cos(sinusoid_inp).to(inputs.dtype)
136
- first_half, second_half = torch.chunk(inputs, 2, dim=-1)
 
 
 
 
 
 
 
 
137
  first_part = first_half * cos - second_half * sin
138
  second_part = second_half * cos + first_half * sin
139
- return torch.cat((first_part, second_part), dim=-1)
 
 
 
140
 
141
 
142
- class Attention(nn.Module):
143
- """Attention using DenseGeneral."""
 
 
144
 
145
  def __init__(
146
  self,
147
- config: DiaConfig,
148
  q_embed_dim: int,
149
  kv_embed_dim: int,
150
  num_query_heads: int,
151
  num_kv_heads: int,
152
  head_dim: int,
153
  compute_dtype: torch.dtype,
154
- is_cross_attn: bool = False,
155
  out_embed_dim: int | None = None,
156
  ):
157
  super().__init__()
158
  self.num_query_heads = num_query_heads
159
  self.num_kv_heads = num_kv_heads
160
  self.head_dim = head_dim
161
- self.is_cross_attn = is_cross_attn
162
  self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
163
  self.projected_query_dim = num_query_heads * head_dim
164
  if num_query_heads % num_kv_heads != 0:
@@ -196,21 +253,18 @@ class Attention(nn.Module):
196
  # --- Rotary Embedding ---
197
  self.rotary_emb = RotaryEmbedding(
198
  embedding_dims=self.head_dim,
199
- min_timescale=config.model.rope_min_timescale,
200
- max_timescale=config.model.rope_max_timescale,
201
  dtype=compute_dtype,
202
  )
203
 
204
  def forward(
205
  self,
206
  Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
207
- Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
208
  q_positions: torch.Tensor, # (B, T)
209
  kv_positions: torch.Tensor | None = None, # (B, S)
210
  attn_mask: torch.Tensor
211
  | None = None, # None in Decoder Self Attention, Valid mask in Others
212
  cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
213
- prefill: bool = False,
214
  is_causal: bool = False,
215
  ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
216
  """
@@ -223,7 +277,6 @@ class Attention(nn.Module):
223
  kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
224
  attn_mask: Attention mask.
225
  cache: KVCache.
226
- prefill: If True, use prefill mode.
227
 
228
  Returns:
229
  A tuple containing:
@@ -235,44 +288,266 @@ class Attention(nn.Module):
235
  original_dtype = Xq.dtype
236
 
237
  Xq_BxTxNxH = self.q_proj(Xq)
238
- Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
239
  Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
240
 
241
  attn_k: torch.Tensor | None = None
242
  attn_v: torch.Tensor | None = None
243
 
244
- if self.is_cross_attn:
245
- attn_k, attn_v = cache.k, cache.v
 
 
 
 
 
 
 
 
 
 
 
 
246
  else:
247
- Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
248
- Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
249
- Xk_BxSxKxH = self.rotary_emb(
250
- Xk_BxSxKxH, position=kv_positions
251
- ) # (B, S, K, H)
252
-
253
- Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
254
- Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
255
-
256
- if cache is None:
257
- attn_k = Xk_BxKxSxH
258
- attn_v = Xv_BxKxSxH
259
- else:
260
- if prefill:
261
- attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
262
- cache.prefill(attn_k, attn_v)
263
- else:
264
- attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)
265
-
266
- attn_output = F.scaled_dot_product_attention(
267
- Xq_BxNxTxH,
268
- attn_k,
269
- attn_v,
270
- attn_mask=attn_mask,
271
- scale=1.0,
272
- enable_gqa=self.num_gqa_groups > 1,
273
- is_causal=is_causal,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  )
275
 
 
 
 
 
 
 
 
 
276
  attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
277
  output = self.o_proj(attn_output)
278
 
@@ -285,34 +560,33 @@ class EncoderLayer(nn.Module):
285
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
286
  super().__init__()
287
  self.config = config
288
- model_config = config.model
289
- enc_config = config.model.encoder
290
- embed_dim = enc_config.n_embd
291
 
292
  self.pre_sa_norm = RMSNorm(
293
  embed_dim,
294
- eps=model_config.normalization_layer_epsilon,
295
  dtype=torch.float32,
296
  )
297
- self.self_attention = Attention(
298
- config,
299
  q_embed_dim=embed_dim,
300
  kv_embed_dim=embed_dim,
301
- num_query_heads=enc_config.n_head,
302
- num_kv_heads=enc_config.n_head,
303
  head_dim=enc_config.head_dim,
304
  compute_dtype=compute_dtype,
305
- is_cross_attn=False,
306
  out_embed_dim=embed_dim,
307
  )
308
  self.post_sa_norm = RMSNorm(
309
  embed_dim,
310
- eps=model_config.normalization_layer_epsilon,
311
  dtype=torch.float32,
312
  )
313
  self.mlp = MlpBlock(
314
  embed_dim=embed_dim,
315
- intermediate_dim=enc_config.n_hidden,
316
  compute_dtype=compute_dtype,
317
  )
318
 
@@ -322,10 +596,10 @@ class EncoderLayer(nn.Module):
322
  state: EncoderInferenceState,
323
  ) -> torch.Tensor:
324
  residual = x
325
- x_norm = self.pre_sa_norm(x)
 
326
  sa_out = self.self_attention(
327
- Xq=x_norm,
328
- Xkv=x_norm,
329
  q_positions=state.positions,
330
  kv_positions=state.positions,
331
  attn_mask=state.attn_mask,
@@ -333,7 +607,7 @@ class EncoderLayer(nn.Module):
333
  x = residual + sa_out
334
 
335
  residual = x
336
- x_norm = self.post_sa_norm(x)
337
  mlp_out = self.mlp(x_norm)
338
  x = residual + mlp_out
339
 
@@ -346,20 +620,23 @@ class Encoder(nn.Module):
346
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
347
  super().__init__()
348
  self.config = config
349
- model_config = config.model
350
- enc_config = config.model.encoder
351
 
352
  self.embedding = nn.Embedding(
353
- model_config.src_vocab_size,
354
- enc_config.n_embd,
355
  dtype=compute_dtype,
356
  )
357
  self.layers = nn.ModuleList(
358
- [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
 
 
 
359
  )
360
  self.norm = RMSNorm(
361
- enc_config.n_embd,
362
- eps=model_config.normalization_layer_epsilon,
363
  dtype=torch.float32,
364
  )
365
 
@@ -373,7 +650,7 @@ class Encoder(nn.Module):
373
  for layer in self.layers:
374
  x = layer(x, state)
375
 
376
- x = self.norm(x)
377
  return x
378
 
379
 
@@ -383,57 +660,55 @@ class DecoderLayer(nn.Module):
383
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
384
  super().__init__()
385
  self.config = config
386
- model_config = config.model
387
- dec_config = config.model.decoder
388
- enc_config = config.model.encoder
389
- dec_embed_dim = dec_config.n_embd
390
- enc_embed_dim = enc_config.n_embd
391
 
392
  # Norms
393
  self.pre_sa_norm = RMSNorm(
394
  dec_embed_dim,
395
- eps=model_config.normalization_layer_epsilon,
396
  dtype=torch.float32,
397
  )
398
  self.pre_ca_norm = RMSNorm(
399
  dec_embed_dim,
400
- eps=model_config.normalization_layer_epsilon,
401
  dtype=torch.float32,
402
  )
403
  self.pre_mlp_norm = RMSNorm(
404
  dec_embed_dim,
405
- eps=model_config.normalization_layer_epsilon,
406
  dtype=torch.float32,
407
  )
408
 
409
  # Self-Attention (GQA) with Causal Masking
410
- self.self_attention = Attention(
411
- config,
412
  q_embed_dim=dec_embed_dim,
413
  kv_embed_dim=dec_embed_dim,
414
- num_query_heads=dec_config.gqa_query_heads,
415
- num_kv_heads=dec_config.kv_heads,
416
- head_dim=dec_config.gqa_head_dim,
417
  compute_dtype=compute_dtype,
418
- is_cross_attn=False,
419
  out_embed_dim=dec_embed_dim,
420
  )
421
  # Cross-Attention (MHA)
422
- self.cross_attention = Attention(
423
- config=config,
424
  q_embed_dim=dec_embed_dim,
425
  kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
426
- num_query_heads=dec_config.cross_query_heads,
427
- num_kv_heads=dec_config.cross_query_heads,
428
  head_dim=dec_config.cross_head_dim,
429
  compute_dtype=compute_dtype,
430
- is_cross_attn=True,
431
  out_embed_dim=dec_embed_dim,
432
  )
433
  # MLP
434
  self.mlp = MlpBlock(
435
  embed_dim=dec_embed_dim,
436
- intermediate_dim=dec_config.n_hidden,
437
  compute_dtype=compute_dtype,
438
  )
439
 
@@ -444,37 +719,39 @@ class DecoderLayer(nn.Module):
444
  self_attn_cache: KVCache | None = None,
445
  cross_attn_cache: KVCache | None = None,
446
  prefill: bool = False,
 
447
  ) -> torch.Tensor:
448
  residual = x
449
- x_norm = self.pre_sa_norm(x)
 
 
450
 
451
  sa_out = self.self_attention(
452
- Xq=x_norm, # (2, 1, D)
453
- Xkv=x_norm, # (2, 1, D)
454
  q_positions=state.dec_positions, # (2, 1)
455
  kv_positions=state.dec_positions, # (2, 1)
456
- attn_mask=None,
457
  cache=self_attn_cache,
458
  prefill=prefill,
459
  is_causal=prefill,
 
460
  )
461
 
462
  x = residual + sa_out
463
 
464
  residual = x
465
- x_norm = self.pre_ca_norm(x)
466
  ca_out = self.cross_attention(
467
  Xq=x_norm,
468
- Xkv=state.enc_out,
469
  q_positions=state.dec_positions,
470
  kv_positions=state.enc_positions,
471
- attn_mask=state.dec_cross_attn_mask,
472
  cache=cross_attn_cache,
473
  )
474
  x = residual + ca_out
475
 
476
  residual = x
477
- x_norm = self.pre_mlp_norm(x)
478
  mlp_out = self.mlp(x_norm)
479
  x = residual + mlp_out
480
 
@@ -487,16 +764,14 @@ class Decoder(nn.Module):
487
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
488
  super().__init__()
489
  self.config = config
490
- model_config = config.model
491
- dec_config = config.model.decoder
492
- data_config = config.data
493
- self.num_channels = data_config.channels
494
- self.num_layers = dec_config.n_layer
495
 
496
  self.embeddings = nn.ModuleList(
497
  [
498
  nn.Embedding(
499
- model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
500
  )
501
  for _ in range(self.num_channels)
502
  ]
@@ -509,14 +784,14 @@ class Decoder(nn.Module):
509
  )
510
 
511
  self.norm = RMSNorm(
512
- dec_config.n_embd,
513
- eps=model_config.normalization_layer_epsilon,
514
  dtype=torch.float32,
515
  )
516
 
517
  self.logits_dense = DenseGeneral(
518
- in_shapes=(dec_config.n_embd,),
519
- out_features=(self.num_channels, model_config.tgt_vocab_size),
520
  axis=(-1,),
521
  weight_dtype=compute_dtype,
522
  )
@@ -524,7 +799,6 @@ class Decoder(nn.Module):
524
  def precompute_cross_attn_cache(
525
  self,
526
  enc_out: torch.Tensor, # (B, S, E)
527
- enc_positions: torch.Tensor, # (B, S)
528
  ) -> list[KVCache]:
529
  """
530
  Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
@@ -536,7 +810,6 @@ class Decoder(nn.Module):
536
  k_proj = cross_attn_module.k_proj(enc_out)
537
  v_proj = cross_attn_module.v_proj(enc_out)
538
 
539
- k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
540
  k = k_proj.transpose(1, 2)
541
  v = v_proj.transpose(1, 2)
542
 
@@ -548,10 +821,10 @@ class Decoder(nn.Module):
548
  self,
549
  tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
550
  state: DecoderInferenceState,
 
551
  ) -> torch.Tensor:
552
  """
553
  Performs a single decoding step, managing KV caches layer by layer.
554
-
555
  Returns:
556
  A tuple containing:
557
  - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
@@ -571,6 +844,7 @@ class Decoder(nn.Module):
571
  state,
572
  self_attn_cache=self_cache,
573
  cross_attn_cache=cross_cache,
 
574
  )
575
 
576
  x = self.norm(x)
@@ -583,7 +857,6 @@ class Decoder(nn.Module):
583
  ) -> torch.Tensor:
584
  """
585
  Forward pass for the Decoder stack, managing KV caches.
586
-
587
  Args:
588
  tgt_ids_BxTxC: Target token IDs (B, T, C).
589
  encoder_out: Output from the encoder (B, S, E).
@@ -597,7 +870,6 @@ class Decoder(nn.Module):
597
  precomputed_cross_attn_kv: A single tuple containing the pre-computed K/V cache
598
  derived from `encoder_out`. This is passed identically
599
  to all layers.
600
-
601
  Returns:
602
  A tuple containing:
603
  - logits: The final output logits (B, T, C * V), cast to float32.
@@ -632,7 +904,19 @@ class Decoder(nn.Module):
632
  return logits_BxTxCxV.to(torch.float32)
633
 
634
 
635
- class DiaModel(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
636
  """PyTorch Dia Model using DenseGeneral."""
637
 
638
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+ from huggingface_hub import PyTorchModelHubMixin
5
  from torch import Tensor
6
  from torch.nn import RMSNorm
7
 
8
+ from .config import DecoderConfig, DiaConfig, EncoderConfig
9
  from .state import DecoderInferenceState, EncoderInferenceState, KVCache
10
 
11
 
 
16
  class DenseGeneral(nn.Module):
17
  """
18
  PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
 
19
  Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
20
  for the generalized matrix multiplication. Weight/bias shapes are calculated
21
  and parameters created during initialization based on config.
22
  `load_weights` validates shapes and copies data.
 
23
  Attributes:
24
  axis (Tuple[int, ...]): Input axis or axes to contract.
25
  in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
 
45
 
46
  factory_kwargs = {"device": device, "dtype": weight_dtype}
47
  self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
 
48
 
49
  def forward(self, inputs: Tensor) -> Tensor:
50
  norm_axis = _normalize_axes(self.axis, inputs.ndim)
 
110
  self.embedding_dims = embedding_dims
111
  self.min_timescale = min_timescale
112
  self.max_timescale = max_timescale
113
+ self.compute_dtype = dtype
114
 
115
  half_embedding_dim = embedding_dims // 2
116
  fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
117
+ timescale = (
118
+ self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
119
+ ).to(torch.float32)
120
+ self.register_buffer("timescale", timescale, persistent=False)
 
 
 
 
 
121
 
122
  def forward(self, inputs: torch.Tensor, position: torch.Tensor):
123
  """Applies RoPE."""
124
  position = position.unsqueeze(-1).unsqueeze(-1)
125
+ sinusoid_inp = position / self.timescale
126
+ sin = torch.sin(sinusoid_inp)
127
+ cos = torch.cos(sinusoid_inp)
128
+ first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
129
+ first_part = first_half * cos - second_half * sin
130
+ second_part = second_half * cos + first_half * sin
131
+ return torch.cat(
132
+ (first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)),
133
+ dim=-1,
134
+ )
135
+
136
+ def apply_rope(self, inputs: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor):
137
+ first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
138
  first_part = first_half * cos - second_half * sin
139
  second_part = second_half * cos + first_half * sin
140
+ return torch.cat(
141
+ (first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)),
142
+ dim=-1,
143
+ )
144
 
145
 
146
+ def custom_scaled_dot_product_attention(
147
+ query: torch.Tensor,
148
+ key: torch.Tensor,
149
+ value: torch.Tensor,
150
+ attn_mask: torch.Tensor | None = None,
151
+ scale: float = 1.0,
152
+ is_causal: bool = False,
153
+ num_gqa_groups: int = 1,
154
+ ) -> torch.Tensor:
155
+ """
156
+ Custom scaled dot-product attention with GQA support for MPS compatibility.
157
+
158
+ Args:
159
+ query: (B, N_q, T, H) - Query tensor, N_q = num_query_heads
160
+ key: (B, N_kv, S, H) - Key tensor, N_kv = num_kv_heads
161
+ value: (B, N_kv, S, H) - Value tensor
162
+ attn_mask: (B, 1, T, S) - Attention mask, optional
163
+ scale: Scaling factor for attention scores
164
+ is_causal: If True, apply causal masking
165
+ num_gqa_groups: Number of query groups per KV head (N_q / N_kv)
166
+
167
+ Returns:
168
+ output: (B, N_q, T, H) - Attention output
169
+ """
170
+ B, N_q, T, H = query.shape
171
+ _, N_kv, S, _ = key.shape
172
+
173
+ # For GQA, repeat key and value tensors to match query heads
174
+ if num_gqa_groups > 1:
175
+ key = key.repeat_interleave(num_gqa_groups, dim=1) # (B, N_q, S, H)
176
+ value = value.repeat_interleave(num_gqa_groups, dim=1) # (B, N_q, S, H)
177
+
178
+ # Compute attention scores: (B, N_q, T, H) @ (B, N_q, H, S) -> (B, N_q, T, S)
179
+ scores = torch.matmul(query, key.transpose(-1, -2)) * scale
180
+
181
+ # Apply causal mask if needed
182
+ if is_causal:
183
+ causal_mask = torch.tril(
184
+ torch.ones(T, S, dtype=torch.bool, device=query.device)
185
+ )
186
+ scores = scores.masked_fill(~causal_mask, float("-inf"))
187
+
188
+ # Apply attention mask if provided
189
+ if attn_mask is not None:
190
+ scores = scores.masked_fill(~attn_mask, float("-inf"))
191
+
192
+ # Softmax over the last dimension (S)
193
+ attn_weights = F.softmax(scores, dim=-1)
194
+
195
+ # Compute output: (B, N_q, T, S) @ (B, N_q, S, H) -> (B, N_q, T, H)
196
+ output = torch.matmul(attn_weights, value)
197
+
198
+ return output
199
+
200
+
201
+ class CrossAttention(nn.Module):
202
+ """Cross-Attention using DenseGeneral."""
203
 
204
  def __init__(
205
  self,
206
+ config: EncoderConfig | DecoderConfig,
207
  q_embed_dim: int,
208
  kv_embed_dim: int,
209
  num_query_heads: int,
210
  num_kv_heads: int,
211
  head_dim: int,
212
  compute_dtype: torch.dtype,
 
213
  out_embed_dim: int | None = None,
214
  ):
215
  super().__init__()
216
  self.num_query_heads = num_query_heads
217
  self.num_kv_heads = num_kv_heads
218
  self.head_dim = head_dim
 
219
  self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
220
  self.projected_query_dim = num_query_heads * head_dim
221
  if num_query_heads % num_kv_heads != 0:
 
253
  # --- Rotary Embedding ---
254
  self.rotary_emb = RotaryEmbedding(
255
  embedding_dims=self.head_dim,
256
+ max_timescale=config.rope_theta,
 
257
  dtype=compute_dtype,
258
  )
259
 
260
  def forward(
261
  self,
262
  Xq: torch.Tensor, # (B, T, D) T = 1 in AR generation
 
263
  q_positions: torch.Tensor, # (B, T)
264
  kv_positions: torch.Tensor | None = None, # (B, S)
265
  attn_mask: torch.Tensor
266
  | None = None, # None in Decoder Self Attention, Valid mask in Others
267
  cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
 
268
  is_causal: bool = False,
269
  ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
270
  """
 
277
  kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
278
  attn_mask: Attention mask.
279
  cache: KVCache.
 
280
 
281
  Returns:
282
  A tuple containing:
 
288
  original_dtype = Xq.dtype
289
 
290
  Xq_BxTxNxH = self.q_proj(Xq)
 
291
  Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
292
 
293
  attn_k: torch.Tensor | None = None
294
  attn_v: torch.Tensor | None = None
295
 
296
+ attn_k, attn_v = cache.k, cache.v
297
+
298
+ # Use custom attention for MPS backend, otherwise use optimized PyTorch function
299
+ is_mps = Xq.device.type == "mps" and torch.backends.mps.is_available()
300
+ if is_mps:
301
+ attn_output = custom_scaled_dot_product_attention(
302
+ query=Xq_BxNxTxH,
303
+ key=attn_k,
304
+ value=attn_v,
305
+ attn_mask=attn_mask if not is_causal else None,
306
+ scale=1.0,
307
+ is_causal=is_causal,
308
+ num_gqa_groups=self.num_gqa_groups,
309
+ )
310
  else:
311
+ attn_output = F.scaled_dot_product_attention(
312
+ Xq_BxNxTxH,
313
+ attn_k,
314
+ attn_v,
315
+ attn_mask=attn_mask if not is_causal else None,
316
+ scale=1.0,
317
+ enable_gqa=self.num_gqa_groups > 1,
318
+ is_causal=is_causal,
319
+ )
320
+
321
+ attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
322
+ output = self.o_proj(attn_output)
323
+
324
+ return output.to(original_dtype)
325
+
326
+
327
+ class FusedQKV(nn.Module):
328
+ def __init__(
329
+ self,
330
+ in_features: int,
331
+ out_features: int,
332
+ bias: bool = False,
333
+ num_q_heads: int = 1,
334
+ q_head_dim: int = 1,
335
+ num_kv_heads: int = 1,
336
+ kv_head_dim: int = 1,
337
+ ):
338
+ super().__init__()
339
+ self.num_q_heads = num_q_heads
340
+ self.q_head_dim = q_head_dim
341
+ self.num_kv_heads = num_kv_heads
342
+ self.kv_head_dim = kv_head_dim
343
+ self.q_output_dim = num_q_heads * q_head_dim
344
+ self.kv_output_dim = num_kv_heads * kv_head_dim
345
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
346
+
347
+ def forward(
348
+ self, inputs: torch.Tensor
349
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
350
+ x = self.linear(inputs)
351
+
352
+ q, k, v = x.split(
353
+ [self.q_output_dim, self.kv_output_dim, self.kv_output_dim], dim=-1
354
  )
355
 
356
+ q = q.reshape(q.shape[:-1] + (self.num_q_heads, self.q_head_dim))
357
+ k = k.reshape(k.shape[:-1] + (self.num_kv_heads, self.kv_head_dim))
358
+ v = v.reshape(v.shape[:-1] + (self.num_kv_heads, self.kv_head_dim))
359
+
360
+ return q, k, v
361
+
362
+
363
+ class SelfAttention(nn.Module):
364
+ """Attention using DenseGeneral."""
365
+
366
+ def __init__(
367
+ self,
368
+ config: DiaConfig,
369
+ q_embed_dim: int,
370
+ kv_embed_dim: int,
371
+ num_query_heads: int,
372
+ num_kv_heads: int,
373
+ head_dim: int,
374
+ compute_dtype: torch.dtype,
375
+ out_embed_dim: int | None = None,
376
+ ):
377
+ super().__init__()
378
+ self.num_query_heads = num_query_heads
379
+ self.num_kv_heads = num_kv_heads
380
+ self.head_dim = head_dim
381
+ self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
382
+ self.projected_query_dim = num_query_heads * head_dim
383
+ if num_query_heads % num_kv_heads != 0:
384
+ raise ValueError(
385
+ f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
386
+ )
387
+ self.num_gqa_groups = num_query_heads // num_kv_heads
388
+ self.kv_embed_dim = kv_embed_dim
389
+ self.q_embed_dim = q_embed_dim
390
+
391
+ # --- Projection Layers using DenseGeneral ---
392
+ self.q_proj = DenseGeneral(
393
+ in_shapes=(q_embed_dim,),
394
+ out_features=(num_query_heads, head_dim),
395
+ axis=(-1,),
396
+ weight_dtype=compute_dtype,
397
+ )
398
+ self.k_proj = DenseGeneral(
399
+ in_shapes=(kv_embed_dim,),
400
+ out_features=(num_kv_heads, head_dim),
401
+ axis=(-1,),
402
+ weight_dtype=compute_dtype,
403
+ )
404
+ self.v_proj = DenseGeneral(
405
+ in_shapes=(kv_embed_dim,),
406
+ out_features=(num_kv_heads, head_dim),
407
+ axis=(-1,),
408
+ weight_dtype=compute_dtype,
409
+ )
410
+ self.o_proj = DenseGeneral(
411
+ in_shapes=(num_query_heads, head_dim),
412
+ out_features=(self.output_dim,),
413
+ axis=(-2, -1),
414
+ weight_dtype=compute_dtype,
415
+ )
416
+
417
+ # --- Rotary Embedding ---
418
+ self.rotary_emb = RotaryEmbedding(
419
+ embedding_dims=self.head_dim,
420
+ max_timescale=config.rope_theta,
421
+ dtype=compute_dtype,
422
+ )
423
+
424
+ self.is_fused_qkv = False
425
+
426
+ def get_linear_weight(self, dense: DenseGeneral):
427
+ W_dg = dense.weight.data
428
+
429
+ out_features = 1
430
+ input_features = 1
431
+ for dim in dense.out_features:
432
+ out_features *= dim
433
+ for dim in dense.in_shapes:
434
+ input_features *= dim
435
+
436
+ W_dg_reshaped_for_linear_T = W_dg.reshape(input_features, out_features)
437
+ linear_weight = W_dg_reshaped_for_linear_T.transpose(0, 1).contiguous()
438
+ return linear_weight
439
+
440
+ def patch_fused_qkv(self):
441
+ q_proj_weight = self.get_linear_weight(self.q_proj)
442
+ k_proj_weight = self.get_linear_weight(self.k_proj)
443
+ v_proj_weight = self.get_linear_weight(self.v_proj)
444
+
445
+ self.qkv = FusedQKV(
446
+ self.kv_embed_dim,
447
+ (
448
+ self.num_query_heads * self.head_dim
449
+ + 2 * (self.num_kv_heads * self.head_dim)
450
+ ),
451
+ bias=False,
452
+ num_q_heads=self.num_query_heads,
453
+ q_head_dim=self.head_dim,
454
+ num_kv_heads=self.num_kv_heads,
455
+ kv_head_dim=self.head_dim,
456
+ )
457
+ self.qkv.linear.weight.data = torch.cat(
458
+ [q_proj_weight, k_proj_weight, v_proj_weight], dim=0
459
+ )
460
+
461
+ # print(f"qkv.weight.shape: {self.qkv.linear.weight.shape}")
462
+ self.is_fused_qkv = True
463
+
464
+ def forward(
465
+ self,
466
+ X: torch.Tensor, # (B, T, D) T = 1 in AR generation
467
+ q_positions: torch.Tensor, # (B, T)
468
+ kv_positions: torch.Tensor | None = None, # (B, S)
469
+ attn_mask: torch.Tensor
470
+ | None = None, # None in Decoder Self Attention, Valid mask in Others
471
+ cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
472
+ prefill: bool = False,
473
+ is_causal: bool = False,
474
+ current_idx: torch.Tensor | None = None,
475
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
476
+ """
477
+ Performs attention calculation with optional KV caching.
478
+ Args:
479
+ Xq: Query tensor (B, T, D). T=1 during single-step decoding.
480
+ Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
481
+ q_positions: Positions for queries (B, T).
482
+ kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
483
+ attn_mask: Attention mask.
484
+ cache: KVCache.
485
+ prefill: If True, use prefill mode.
486
+ Returns:
487
+ A tuple containing:
488
+ - output: The attention output tensor (B, T, output_dim).
489
+ - present_kv: The K/V state to be cached for the next step ((B, N, S_new, H), (B, N, S_new, H)). For self-attn, S_new = S_past + S. For cross-attn, S_new = S_kv.
490
+ """
491
+ if kv_positions is None:
492
+ kv_positions = q_positions
493
+
494
+ original_dtype = X.dtype
495
+
496
+ if self.is_fused_qkv:
497
+ Xq_BxTxNxH, Xk_BxSxKxH, Xv_BxSxKxH = self.qkv(X)
498
+ else:
499
+ Xq_BxTxNxH = self.q_proj(X)
500
+ Xk_BxSxKxH = self.k_proj(X)
501
+ Xv_BxSxKxH = self.v_proj(X)
502
+
503
+ position = q_positions.unsqueeze(-1).unsqueeze(-1)
504
+ sinusoid_inp = position / self.rotary_emb.timescale
505
+ sin = torch.sin(sinusoid_inp)
506
+ cos = torch.cos(sinusoid_inp)
507
+
508
+ Xq_BxTxNxH = self.rotary_emb.apply_rope(Xq_BxTxNxH, sin, cos)
509
+ Xk_BxSxKxH = self.rotary_emb.apply_rope(Xk_BxSxKxH, sin, cos)
510
+
511
+ Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
512
+
513
+ attn_k: torch.Tensor | None = None
514
+ attn_v: torch.Tensor | None = None
515
+
516
+ Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
517
+ Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
518
+
519
+ if cache is None:
520
+ attn_k = Xk_BxKxSxH
521
+ attn_v = Xv_BxKxSxH
522
+ elif prefill:
523
+ attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
524
+ cache.prefill(attn_k, attn_v)
525
+ else:
526
+ attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH, current_idx)
527
+
528
+ # Use custom attention for MPS backend, otherwise use optimized PyTorch function
529
+ is_mps = Xv_BxSxKxH.device.type == "mps" and torch.backends.mps.is_available()
530
+ if is_mps:
531
+ attn_output = custom_scaled_dot_product_attention(
532
+ query=Xq_BxNxTxH,
533
+ key=attn_k,
534
+ value=attn_v,
535
+ attn_mask=attn_mask if not is_causal else None,
536
+ scale=1.0,
537
+ is_causal=is_causal,
538
+ num_gqa_groups=self.num_gqa_groups,
539
+ )
540
+ else:
541
+ attn_output = F.scaled_dot_product_attention(
542
+ Xq_BxNxTxH,
543
+ attn_k,
544
+ attn_v,
545
+ attn_mask=attn_mask if not is_causal else None,
546
+ scale=1.0,
547
+ enable_gqa=self.num_gqa_groups > 1,
548
+ is_causal=is_causal,
549
+ )
550
+
551
  attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
552
  output = self.o_proj(attn_output)
553
 
 
560
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
561
  super().__init__()
562
  self.config = config
563
+ enc_config = config.encoder_config
564
+ embed_dim = enc_config.hidden_size
565
+ self.compute_dtype = compute_dtype
566
 
567
  self.pre_sa_norm = RMSNorm(
568
  embed_dim,
569
+ eps=enc_config.norm_eps,
570
  dtype=torch.float32,
571
  )
572
+ self.self_attention = SelfAttention(
573
+ enc_config,
574
  q_embed_dim=embed_dim,
575
  kv_embed_dim=embed_dim,
576
+ num_query_heads=enc_config.num_attention_heads,
577
+ num_kv_heads=enc_config.num_key_value_heads,
578
  head_dim=enc_config.head_dim,
579
  compute_dtype=compute_dtype,
 
580
  out_embed_dim=embed_dim,
581
  )
582
  self.post_sa_norm = RMSNorm(
583
  embed_dim,
584
+ eps=enc_config.norm_eps,
585
  dtype=torch.float32,
586
  )
587
  self.mlp = MlpBlock(
588
  embed_dim=embed_dim,
589
+ intermediate_dim=enc_config.intermediate_size,
590
  compute_dtype=compute_dtype,
591
  )
592
 
 
596
  state: EncoderInferenceState,
597
  ) -> torch.Tensor:
598
  residual = x
599
+ x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
600
+
601
  sa_out = self.self_attention(
602
+ X=x_norm,
 
603
  q_positions=state.positions,
604
  kv_positions=state.positions,
605
  attn_mask=state.attn_mask,
 
607
  x = residual + sa_out
608
 
609
  residual = x
610
+ x_norm = self.post_sa_norm(x).to(self.compute_dtype)
611
  mlp_out = self.mlp(x_norm)
612
  x = residual + mlp_out
613
 
 
620
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
621
  super().__init__()
622
  self.config = config
623
+ enc_config = config.encoder_config
624
+ self.compute_dtype = compute_dtype
625
 
626
  self.embedding = nn.Embedding(
627
+ enc_config.vocab_size,
628
+ enc_config.hidden_size,
629
  dtype=compute_dtype,
630
  )
631
  self.layers = nn.ModuleList(
632
+ [
633
+ EncoderLayer(config, compute_dtype)
634
+ for _ in range(enc_config.num_hidden_layers)
635
+ ]
636
  )
637
  self.norm = RMSNorm(
638
+ enc_config.hidden_size,
639
+ eps=enc_config.norm_eps,
640
  dtype=torch.float32,
641
  )
642
 
 
650
  for layer in self.layers:
651
  x = layer(x, state)
652
 
653
+ x = self.norm(x).to(self.compute_dtype)
654
  return x
655
 
656
 
 
660
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
661
  super().__init__()
662
  self.config = config
663
+ dec_config = config.decoder_config
664
+ enc_config = config.encoder_config
665
+ dec_embed_dim = dec_config.hidden_size
666
+ enc_embed_dim = enc_config.hidden_size
667
+ self.compute_dtype = compute_dtype
668
 
669
  # Norms
670
  self.pre_sa_norm = RMSNorm(
671
  dec_embed_dim,
672
+ eps=dec_config.norm_eps,
673
  dtype=torch.float32,
674
  )
675
  self.pre_ca_norm = RMSNorm(
676
  dec_embed_dim,
677
+ eps=dec_config.norm_eps,
678
  dtype=torch.float32,
679
  )
680
  self.pre_mlp_norm = RMSNorm(
681
  dec_embed_dim,
682
+ eps=dec_config.norm_eps,
683
  dtype=torch.float32,
684
  )
685
 
686
  # Self-Attention (GQA) with Causal Masking
687
+ self.self_attention = SelfAttention(
688
+ dec_config,
689
  q_embed_dim=dec_embed_dim,
690
  kv_embed_dim=dec_embed_dim,
691
+ num_query_heads=dec_config.num_attention_heads,
692
+ num_kv_heads=dec_config.num_key_value_heads,
693
+ head_dim=dec_config.head_dim,
694
  compute_dtype=compute_dtype,
 
695
  out_embed_dim=dec_embed_dim,
696
  )
697
  # Cross-Attention (MHA)
698
+ self.cross_attention = CrossAttention(
699
+ dec_config,
700
  q_embed_dim=dec_embed_dim,
701
  kv_embed_dim=enc_embed_dim, # Note kv_embed_dim
702
+ num_query_heads=dec_config.cross_num_attention_heads,
703
+ num_kv_heads=dec_config.cross_num_key_value_heads,
704
  head_dim=dec_config.cross_head_dim,
705
  compute_dtype=compute_dtype,
 
706
  out_embed_dim=dec_embed_dim,
707
  )
708
  # MLP
709
  self.mlp = MlpBlock(
710
  embed_dim=dec_embed_dim,
711
+ intermediate_dim=dec_config.intermediate_size,
712
  compute_dtype=compute_dtype,
713
  )
714
 
 
719
  self_attn_cache: KVCache | None = None,
720
  cross_attn_cache: KVCache | None = None,
721
  prefill: bool = False,
722
+ current_idx: int = 0,
723
  ) -> torch.Tensor:
724
  residual = x
725
+ x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
726
+
727
+ self_attn_mask = state.casual_attn_mask[None, None, current_idx]
728
 
729
  sa_out = self.self_attention(
730
+ X=x_norm, # (2, 1, D)
 
731
  q_positions=state.dec_positions, # (2, 1)
732
  kv_positions=state.dec_positions, # (2, 1)
733
+ attn_mask=self_attn_mask,
734
  cache=self_attn_cache,
735
  prefill=prefill,
736
  is_causal=prefill,
737
+ current_idx=current_idx,
738
  )
739
 
740
  x = residual + sa_out
741
 
742
  residual = x
743
+ x_norm = self.pre_ca_norm(x).to(self.compute_dtype)
744
  ca_out = self.cross_attention(
745
  Xq=x_norm,
 
746
  q_positions=state.dec_positions,
747
  kv_positions=state.enc_positions,
748
+ attn_mask=state.cross_attn_mask,
749
  cache=cross_attn_cache,
750
  )
751
  x = residual + ca_out
752
 
753
  residual = x
754
+ x_norm = self.pre_mlp_norm(x).to(self.compute_dtype)
755
  mlp_out = self.mlp(x_norm)
756
  x = residual + mlp_out
757
 
 
764
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
765
  super().__init__()
766
  self.config = config
767
+ dec_config = config.decoder_config
768
+ self.num_channels = dec_config.num_channels
769
+ self.num_layers = dec_config.num_hidden_layers
 
 
770
 
771
  self.embeddings = nn.ModuleList(
772
  [
773
  nn.Embedding(
774
+ dec_config.vocab_size, dec_config.hidden_size, dtype=compute_dtype
775
  )
776
  for _ in range(self.num_channels)
777
  ]
 
784
  )
785
 
786
  self.norm = RMSNorm(
787
+ dec_config.hidden_size,
788
+ eps=dec_config.norm_eps,
789
  dtype=torch.float32,
790
  )
791
 
792
  self.logits_dense = DenseGeneral(
793
+ in_shapes=(dec_config.hidden_size,),
794
+ out_features=(self.num_channels, dec_config.vocab_size),
795
  axis=(-1,),
796
  weight_dtype=compute_dtype,
797
  )
 
799
  def precompute_cross_attn_cache(
800
  self,
801
  enc_out: torch.Tensor, # (B, S, E)
 
802
  ) -> list[KVCache]:
803
  """
804
  Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
 
810
  k_proj = cross_attn_module.k_proj(enc_out)
811
  v_proj = cross_attn_module.v_proj(enc_out)
812
 
 
813
  k = k_proj.transpose(1, 2)
814
  v = v_proj.transpose(1, 2)
815
 
 
821
  self,
822
  tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
823
  state: DecoderInferenceState,
824
+ current_idx: int,
825
  ) -> torch.Tensor:
826
  """
827
  Performs a single decoding step, managing KV caches layer by layer.
 
828
  Returns:
829
  A tuple containing:
830
  - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
 
844
  state,
845
  self_attn_cache=self_cache,
846
  cross_attn_cache=cross_cache,
847
+ current_idx=current_idx,
848
  )
849
 
850
  x = self.norm(x)
 
857
  ) -> torch.Tensor:
858
  """
859
  Forward pass for the Decoder stack, managing KV caches.
 
860
  Args:
861
  tgt_ids_BxTxC: Target token IDs (B, T, C).
862
  encoder_out: Output from the encoder (B, S, E).
 
870
  precomputed_cross_attn_kv: A single tuple containing the pre-computed K/V cache
871
  derived from `encoder_out`. This is passed identically
872
  to all layers.
 
873
  Returns:
874
  A tuple containing:
875
  - logits: The final output logits (B, T, C * V), cast to float32.
 
904
  return logits_BxTxCxV.to(torch.float32)
905
 
906
 
907
+ class DiaModel(
908
+ nn.Module,
909
+ PyTorchModelHubMixin,
910
+ repo_url="https://github.com/nari-labs/dia",
911
+ pipeline_tag="text-to-speech",
912
+ license="apache-2.0",
913
+ coders={
914
+ DiaConfig: (
915
+ lambda x: x.model_dump(),
916
+ lambda data: DiaConfig.model_validate(data),
917
+ ),
918
+ },
919
+ ):
920
  """PyTorch Dia Model using DenseGeneral."""
921
 
922
  def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
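Two of the larger additions in dia/layers.py are the MPS-friendly custom_scaled_dot_product_attention (a plain matmul/softmax path that uses repeat_interleave for GQA) and the FusedQKV/patch_fused_qkv machinery on SelfAttention. A quick way to sanity-check the fallback against PyTorch's fused kernel on CPU is sketched below; it assumes a PyTorch build where enable_gqa is available (2.5+) and is not part of the commit itself:

import torch
import torch.nn.functional as F

from dia.layers import custom_scaled_dot_product_attention

B, n_q, n_kv, T, S, H = 2, 16, 4, 1, 7, 128
q = torch.randn(B, n_q, T, H)
k = torch.randn(B, n_kv, S, H)
v = torch.randn(B, n_kv, S, H)

# Reference: fused kernel with grouped-query attention, same unit scale as the diff uses.
ref = F.scaled_dot_product_attention(q, k, v, scale=1.0, enable_gqa=True)
# Fallback path added for MPS in this commit.
out = custom_scaled_dot_product_attention(q, k, v, scale=1.0, num_gqa_groups=n_q // n_kv)

print(torch.allclose(ref, out, atol=1e-5))  # expected to print True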
dia/model.py CHANGED
@@ -1,17 +1,16 @@
1
  import time
2
  from enum import Enum
 
3
 
4
- import dac
5
  import numpy as np
6
  import torch
 
7
  import torchaudio
8
- from huggingface_hub import hf_hub_download
9
 
10
  from .audio import (
11
  apply_audio_delay,
12
  build_delay_indices,
13
  build_revert_indices,
14
- decode,
15
  revert_audio_delay,
16
  )
17
  from .config import DiaConfig
@@ -20,6 +19,7 @@ from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
20
 
21
 
22
  DEFAULT_SAMPLE_RATE = 44100
 
23
 
24
 
25
  def _get_default_device():
@@ -34,16 +34,29 @@ def _sample_next_token(
34
  logits_BCxV: torch.Tensor,
35
  temperature: float,
36
  top_p: float,
37
- cfg_filter_top_k: int | None = None,
 
38
  ) -> torch.Tensor:
39
  if temperature == 0.0:
40
  return torch.argmax(logits_BCxV, dim=-1)
41
 
42
  logits_BCxV = logits_BCxV / temperature
43
- if cfg_filter_top_k is not None:
44
- _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
45
  mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
46
- mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
47
  logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
48
 
49
  if top_p < 1.0:
@@ -54,13 +67,15 @@ def _sample_next_token(
54
  cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
55
 
56
  sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
57
- sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
58
- ..., :-1
59
- ].clone()
60
- sorted_indices_to_remove_BCxV[..., 0] = 0
 
 
61
 
62
  indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
63
- indices_to_remove_BCxV.scatter_(
64
  dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
65
  )
66
  logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
@@ -94,12 +109,15 @@ class Dia:
94
  config: DiaConfig,
95
  compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
96
  device: torch.device | None = None,
 
97
  ):
98
  """Initializes the Dia model.
99
 
100
  Args:
101
  config: The configuration object for the model.
 
102
  device: The device to load the model onto. If None, will automatically select the best available device.
 
103
 
104
  Raises:
105
  RuntimeError: If there is an error loading the DAC model.
@@ -110,8 +128,16 @@ class Dia:
110
  if isinstance(compute_dtype, str):
111
  compute_dtype = ComputeDtype(compute_dtype)
112
  self.compute_dtype = compute_dtype.to_dtype()
113
- self.model = DiaModel(config, self.compute_dtype)
114
  self.dac_model = None
 
 
 
 
 
 
 
 
115
 
116
  @classmethod
117
  def from_local(
@@ -120,13 +146,16 @@ class Dia:
120
  checkpoint_path: str,
121
  compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
122
  device: torch.device | None = None,
 
123
  ) -> "Dia":
124
  """Loads the Dia model from local configuration and checkpoint files.
125
 
126
  Args:
127
  config_path: Path to the configuration JSON file.
128
  checkpoint_path: Path to the model checkpoint (.pth) file.
 
129
  device: The device to load the model onto. If None, will automatically select the best available device.
 
130
 
131
  Returns:
132
  An instance of the Dia model loaded with weights and set to eval mode.
@@ -139,7 +168,7 @@ class Dia:
139
  if config is None:
140
  raise FileNotFoundError(f"Config file not found at {config_path}")
141
 
142
- dia = cls(config, compute_dtype, device)
143
 
144
  try:
145
  state_dict = torch.load(checkpoint_path, map_location=dia.device)
@@ -153,15 +182,17 @@ class Dia:
153
 
154
  dia.model.to(dia.device)
155
  dia.model.eval()
156
- dia._load_dac_model()
 
157
  return dia
158
 
159
  @classmethod
160
  def from_pretrained(
161
  cls,
162
- model_name: str = "nari-labs/Dia-1.6B",
163
  compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
164
  device: torch.device | None = None,
 
165
  ) -> "Dia":
166
  """Loads the Dia model from a Hugging Face Hub repository.
167
 
@@ -169,8 +200,10 @@ class Dia:
169
  repository ID and then loads the model.
170
 
171
  Args:
172
- model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
 
173
  device: The device to load the model onto. If None, will automatically select the best available device.
 
174
 
175
  Returns:
176
  An instance of the Dia model loaded with weights and set to eval mode.
@@ -179,110 +212,192 @@ class Dia:
              FileNotFoundError: If config or checkpoint download/loading fails.
              RuntimeError: If there is an error loading the checkpoint.
          """
-         config_path = hf_hub_download(repo_id=model_name, filename="config.json")
-         checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
-         return cls.from_local(config_path, checkpoint_path, compute_dtype, device)

      def _load_dac_model(self):
          try:
              dac_model_path = dac.utils.download()
              dac_model = dac.DAC.load(dac_model_path).to(self.device)
          except Exception as e:
              raise RuntimeError("Failed to load DAC model") from e
          self.dac_model = dac_model
193
 
194
- def _prepare_text_input(self, text: str) -> torch.Tensor:
195
- """Encodes text prompt, pads, and creates attention mask and positions."""
196
- text_pad_value = self.config.data.text_pad_value
197
- max_len = self.config.data.text_length
198
 
199
  byte_text = text.encode("utf-8")
 
 
200
  replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
201
  text_tokens = list(replaced_bytes)
 
 
 
 
 
202
 
203
- current_len = len(text_tokens)
204
- padding_needed = max_len - current_len
205
- if padding_needed <= 0:
206
- text_tokens = text_tokens[:max_len]
207
- padded_text_np = np.array(text_tokens, dtype=np.uint8)
208
- else:
209
- padded_text_np = np.pad(
210
- text_tokens,
211
- (0, padding_needed),
212
- mode="constant",
213
- constant_values=text_pad_value,
214
- ).astype(np.uint8)
215
-
216
- src_tokens = (
217
- torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
218
- ) # [1, S]
219
  return src_tokens
220
 
221
  def _prepare_audio_prompt(
222
- self, audio_prompt: torch.Tensor | None
223
- ) -> tuple[torch.Tensor, int]:
224
- num_channels = self.config.data.channels
225
- audio_bos_value = self.config.data.audio_bos_value
226
- audio_pad_value = self.config.data.audio_pad_value
227
- delay_pattern = self.config.data.delay_pattern
228
- max_delay_pattern = max(delay_pattern)
229
 
230
- prefill = torch.full(
231
- (1, num_channels),
232
- fill_value=audio_bos_value,
233
- dtype=torch.int,
234
- device=self.device,
235
- )
 
236
 
237
- prefill_step = 1
238
 
239
- if audio_prompt is not None:
240
- prefill_step += audio_prompt.shape[0]
241
- prefill = torch.cat([prefill, audio_prompt], dim=0)
 
 
242
 
243
- delay_pad_tensor = torch.full(
244
- (max_delay_pattern, num_channels),
245
  fill_value=-1,
246
  dtype=torch.int,
247
  device=self.device,
248
  )
249
- prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
250
 
251
  delay_precomp = build_delay_indices(
252
- B=1,
253
- T=prefill.shape[0],
254
  C=num_channels,
255
  delay_pattern=delay_pattern,
256
  )
257
 
258
- prefill = apply_audio_delay(
259
- audio_BxTxC=prefill.unsqueeze(0),
260
- pad_value=audio_pad_value,
261
  bos_value=audio_bos_value,
262
  precomp=delay_precomp,
263
- ).squeeze(0)
264
 
265
- return prefill, prefill_step
266
 
267
  def _prepare_generation(
268
- self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
 
 
 
 
269
  ):
270
- enc_input_cond = self._prepare_text_input(text)
271
- enc_input_uncond = torch.zeros_like(enc_input_cond)
272
- enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)
273
 
274
- if isinstance(audio_prompt, str):
275
- audio_prompt = self.load_audio(audio_prompt)
276
- prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)
 
277
 
278
- if verbose:
279
- print("generate: data loaded")
 
280
 
281
  enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
282
  encoder_out = self.model.encoder(enc_input, enc_state)
283
 
284
  dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
285
- encoder_out, enc_state.positions
286
  )
287
  dec_state = DecoderInferenceState.new(
288
  self.config,
@@ -290,15 +405,18 @@ class Dia:
290
  encoder_out,
291
  dec_cross_attn_cache,
292
  self.compute_dtype,
 
293
  )
294
- dec_output = DecoderOutput.new(self.config, self.device)
295
- dec_output.prefill(prefill, prefill_step)
296
 
297
- dec_step = prefill_step - 1
 
 
 
298
  if dec_step > 0:
299
  dec_state.prepare_step(0, dec_step)
300
- tokens_BxTxC = (
301
- dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
302
  )
303
  self.model.decoder.forward(tokens_BxTxC, dec_state)
304
 
@@ -311,43 +429,114 @@ class Dia:
          cfg_scale: float,
          temperature: float,
          top_p: float,
-         cfg_filter_top_k: int,
      ) -> torch.Tensor:
-         audio_eos_value = self.config.data.audio_eos_value
-         logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)

-         logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
-         uncond_logits_CxV = logits_last_BxCxV[0, :, :]
-         cond_logits_CxV = logits_last_BxCxV[1, :, :]

-         logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
-         logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
-         logits_CxV[1:, audio_eos_value:] = -torch.inf

-         pred_C = _sample_next_token(
-             logits_CxV.float(),
              temperature=temperature,
              top_p=top_p,
-             cfg_filter_top_k=cfg_filter_top_k,
          )
-         return pred_C

-     def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
-         num_channels = self.config.data.channels
-         seq_length = generated_codes.shape[0]
-         delay_pattern = self.config.data.delay_pattern
-         audio_pad_value = self.config.data.audio_pad_value
          max_delay_pattern = max(delay_pattern)

          revert_precomp = build_revert_indices(
-             B=1,
              T=seq_length,
              C=num_channels,
              delay_pattern=delay_pattern,
          )

          codebook = revert_audio_delay(
-             audio_BxTxC=generated_codes.unsqueeze(0),
              pad_value=audio_pad_value,
              precomp=revert_precomp,
              T=seq_length,
@@ -358,20 +547,85 @@ class Dia:
          invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
          codebook[invalid_mask] = 0

-         audio = decode(self.dac_model, codebook.transpose(1, 2))

-         return audio.squeeze().cpu().numpy()

      def load_audio(self, audio_path: str) -> torch.Tensor:
          audio, sr = torchaudio.load(audio_path, channels_first=True)  # C, T
          if sr != DEFAULT_SAMPLE_RATE:
              audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
-         audio = audio.to(self.device).unsqueeze(0)  # 1, C, T
-         audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
-         _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)  # 1, C, T
-         return encoded_frame.squeeze(0).transpose(0, 1)

      def save_audio(self, path: str, audio: np.ndarray):
          import soundfile as sf

          sf.write(path, audio, DEFAULT_SAMPLE_RATE)
@@ -379,23 +633,63 @@ class Dia:
      @torch.inference_mode()
      def generate(
          self,
-         text: str,
-         max_tokens: int | None = None,
          cfg_scale: float = 3.0,
-         temperature: float = 1.3,
          top_p: float = 0.95,
          use_torch_compile: bool = False,
-         cfg_filter_top_k: int = 35,
-         audio_prompt: str | torch.Tensor | None = None,
-         audio_prompt_path: str | None = None,
          use_cfg_filter: bool | None = None,
          verbose: bool = False,
-     ) -> np.ndarray:
-         audio_eos_value = self.config.data.audio_eos_value
-         audio_pad_value = self.config.data.audio_pad_value
-         delay_pattern = self.config.data.delay_pattern
-         max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
          max_delay_pattern = max(delay_pattern)
          self.model.eval()

          if audio_prompt_path:
@@ -407,82 +701,179 @@ class Dia:
          if verbose:
              total_start_time = time.time()

-         dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
-         dec_step = dec_output.prefill_step - 1
-
-         bos_countdown = max_delay_pattern
-         eos_detected = False
-         eos_countdown = -1

-         if use_torch_compile:
-             step_fn = torch.compile(self._decoder_step, mode="default")
          else:
-             step_fn = self._decoder_step

          if verbose:
              print("generate: starting generation loop")
              if use_torch_compile:
                  print(
-                     "generate: by using use_torch_compile=True, the first step would take long"
                  )
              start_time = time.time()

          while dec_step < max_tokens:
              dec_state.prepare_step(dec_step)
-             tokens_Bx1xC = (
-                 dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
-             )
-             pred_C = step_fn(
                  tokens_Bx1xC,
                  dec_state,
                  cfg_scale,
                  temperature,
                  top_p,
                  cfg_filter_top_k,
              )

-             if (
-                 not eos_detected and pred_C[0] == audio_eos_value
-             ) or dec_step == max_tokens - max_delay_pattern - 1:
-                 eos_detected = True
-                 eos_countdown = max_delay_pattern
-
-             if eos_countdown > 0:
-                 step_after_eos = max_delay_pattern - eos_countdown
-                 for i, d in enumerate(delay_pattern):
-                     if step_after_eos == d:
-                         pred_C[i] = audio_eos_value
-                     elif step_after_eos > d:
-                         pred_C[i] = audio_pad_value
-                 eos_countdown -= 1
-
-             bos_countdown = max(0, bos_countdown - 1)
-             dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)
-
-             if eos_countdown == 0:
-                 break

              dec_step += 1
              if verbose and dec_step % 86 == 0:
                  duration = time.time() - start_time
-                 print(
-                     f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
-                 )
                  start_time = time.time()

-         if dec_output.prefill_step >= dec_step + 1:
-             print("Warning: Nothing generated")
-             return None

-         generated_codes = dec_output.generated_tokens[
-             dec_output.prefill_step : dec_step + 1, :
-         ]

-         if verbose:
-             total_step = dec_step + 1 - dec_output.prefill_step
-             total_duration = time.time() - total_start_time
-             print(
-                 f"generate: total step={total_step}, total duration={total_duration:.3f}s"
              )

-         return self._generate_output(generated_codes)
1
  import time
2
  from enum import Enum
3
+ from typing import Callable
4
 
 
5
  import numpy as np
6
  import torch
7
+ import torch.nn.functional as F
8
  import torchaudio
 
9
 
10
  from .audio import (
11
  apply_audio_delay,
12
  build_delay_indices,
13
  build_revert_indices,
 
14
  revert_audio_delay,
15
  )
16
  from .config import DiaConfig
 
19
 
20
 
21
  DEFAULT_SAMPLE_RATE = 44100
22
+ SAMPLE_RATE_RATIO = 512
23
 
24
 
25
  def _get_default_device():
 
34
  logits_BCxV: torch.Tensor,
35
  temperature: float,
36
  top_p: float,
37
+ top_k: int | None,
38
+ audio_eos_value: int,
39
  ) -> torch.Tensor:
40
  if temperature == 0.0:
41
  return torch.argmax(logits_BCxV, dim=-1)
42
 
43
  logits_BCxV = logits_BCxV / temperature
44
+
45
+ if audio_eos_value is not None and audio_eos_value >= 0:
46
+ top_logit_indices_BC = torch.argmax(logits_BCxV, dim=-1)
47
+ eos_not_highest_mask_BC = top_logit_indices_BC != audio_eos_value
48
+ mask_eos_unless_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
49
+ mask_eos_unless_highest_BCxV[eos_not_highest_mask_BC, audio_eos_value] = True
50
+ logits_BCxV = logits_BCxV.masked_fill(mask_eos_unless_highest_BCxV, -torch.inf)
51
+ eos_highest_mask_BC = top_logit_indices_BC == audio_eos_value
52
+ mask_eos_highest_BCxV = torch.zeros_like(logits_BCxV, dtype=torch.bool)
53
+ mask_eos_highest_BCxV[eos_highest_mask_BC, :audio_eos_value] = True
54
+ logits_BCxV = logits_BCxV.masked_fill(mask_eos_highest_BCxV, -torch.inf)
55
+
56
+ if top_k is not None:
57
+ _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=top_k, dim=-1)
58
  mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
59
+ mask = mask.scatter(dim=-1, index=top_k_indices_BCxV, value=False)
60
  logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
61
 
62
  if top_p < 1.0:
 
67
  cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
68
 
69
  sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
70
+ sorted_indices_to_remove_BCxV = torch.roll(
71
+ sorted_indices_to_remove_BCxV, shifts=1, dims=-1
72
+ )
73
+ sorted_indices_to_remove_BCxV[..., 0] = torch.zeros_like(
74
+ sorted_indices_to_remove_BCxV[..., 0]
75
+ )
76
 
77
  indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
78
+ indices_to_remove_BCxV = indices_to_remove_BCxV.scatter(
79
  dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
80
  )
81
  logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
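
The reworked top-p filter above replaces the previous in-place slice-and-`clone()` shift with an out-of-place `torch.roll` followed by a non-in-place `scatter`, presumably to behave better under `torch.compile`/CUDA graphs. A minimal editor's sketch on a toy tensor (not taken from the commit) showing the two shifts agree:

import torch

sorted_remove = torch.tensor([[False, False, True, True]])

# Old shift: in-place slice assignment with a clone.
old = sorted_remove.clone()
old[..., 1:] = old[..., :-1].clone()
old[..., 0] = False

# New shift: out-of-place roll, then clear the first position so the top token survives.
new = torch.roll(sorted_remove, shifts=1, dims=-1)
new[..., 0] = torch.zeros_like(new[..., 0])

assert torch.equal(old, new)
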
 
109
  config: DiaConfig,
110
  compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
111
  device: torch.device | None = None,
112
+ load_dac: bool = True,
113
  ):
114
  """Initializes the Dia model.
115
 
116
  Args:
117
  config: The configuration object for the model.
118
+ compute_dtype: The computation dtype to use.
119
  device: The device to load the model onto. If None, will automatically select the best available device.
120
+ load_dac: Whether to load the DAC model.
121
 
122
  Raises:
123
  RuntimeError: If there is an error loading the DAC model.
 
128
  if isinstance(compute_dtype, str):
129
  compute_dtype = ComputeDtype(compute_dtype)
130
  self.compute_dtype = compute_dtype.to_dtype()
131
+ self.model: DiaModel = DiaModel(config, self.compute_dtype)
132
  self.dac_model = None
133
+ self._compiled_step = None
134
+ self.load_dac = load_dac
135
+
136
+ if not self.load_dac:
137
+ print("Warning: DAC model will not be loaded. This is not recommended.")
138
+
139
+ if torch.cuda.is_available():
140
+ torch.backends.cuda.matmul.allow_tf32 = True
141
 
142
  @classmethod
143
  def from_local(
 
146
  checkpoint_path: str,
147
  compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
148
  device: torch.device | None = None,
149
+ load_dac: bool = True,
150
  ) -> "Dia":
151
  """Loads the Dia model from local configuration and checkpoint files.
152
 
153
  Args:
154
  config_path: Path to the configuration JSON file.
155
  checkpoint_path: Path to the model checkpoint (.pth) file.
156
+ compute_dtype: The computation dtype to use.
157
  device: The device to load the model onto. If None, will automatically select the best available device.
158
+ load_dac: Whether to load the DAC model.
159
 
160
  Returns:
161
  An instance of the Dia model loaded with weights and set to eval mode.
 
168
  if config is None:
169
  raise FileNotFoundError(f"Config file not found at {config_path}")
170
 
171
+ dia = cls(config, compute_dtype, device, load_dac)
172
 
173
  try:
174
  state_dict = torch.load(checkpoint_path, map_location=dia.device)
 
182
 
183
  dia.model.to(dia.device)
184
  dia.model.eval()
185
+ if load_dac:
186
+ dia._load_dac_model()
187
  return dia
188
 
189
  @classmethod
190
  def from_pretrained(
191
  cls,
192
+ model_name: str = "nari-labs/Dia-1.6B-0626",
193
  compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
194
  device: torch.device | None = None,
195
+ load_dac: bool = True,
196
  ) -> "Dia":
197
  """Loads the Dia model from a Hugging Face Hub repository.
198
 
 
200
  repository ID and then loads the model.
201
 
202
  Args:
203
+ model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B-0626").
204
+ compute_dtype: The computation dtype to use.
205
  device: The device to load the model onto. If None, will automatically select the best available device.
206
+ load_dac: Whether to load the DAC model.
207
 
208
  Returns:
209
  An instance of the Dia model loaded with weights and set to eval mode.
 
212
  FileNotFoundError: If config or checkpoint download/loading fails.
213
  RuntimeError: If there is an error loading the checkpoint.
214
  """
215
+ if isinstance(compute_dtype, str):
216
+ compute_dtype = ComputeDtype(compute_dtype)
217
+
218
+ # Load model directly using DiaModel's from_pretrained which handles HF download
219
+ try:
220
+ loaded_model = DiaModel.from_pretrained(
221
+ model_name, compute_dtype=compute_dtype.to_dtype()
222
+ )
223
+ except Exception as e:
224
+ raise RuntimeError(
225
+ f"Error loading model from Hugging Face Hub ({model_name})"
226
+ ) from e
227
+
228
+ config = loaded_model.config # Get config from the loaded model
229
+ dia = cls(config, compute_dtype, device, load_dac)
230
+
231
+ dia.model = loaded_model # Assign the already loaded model
232
+ dia.model.to(dia.device)
233
+ dia.model.eval()
234
+ if load_dac:
235
+ dia._load_dac_model()
236
+ return dia
237
 
238
  def _load_dac_model(self):
239
+ """Loads the Descript Audio Codec (DAC) model.
240
+
241
+ Downloads the DAC model if necessary and loads it onto the specified device.
242
+ Sets the DAC model to evaluation mode.
243
+
244
+ Raises:
245
+ RuntimeError: If downloading or loading the DAC model fails.
246
+ """
247
+ import dac
248
+
249
  try:
250
  dac_model_path = dac.utils.download()
251
  dac_model = dac.DAC.load(dac_model_path).to(self.device)
252
+ dac_model.eval() # Ensure DAC is in eval mode
253
  except Exception as e:
254
  raise RuntimeError("Failed to load DAC model") from e
255
  self.dac_model = dac_model
256
 
257
+ def _encode_text(self, text: str) -> torch.Tensor:
258
+ """Encodes the input text string into a tensor of token IDs using byte-level encoding.
259
+
260
+ Special tokens [S1] and [S2] are replaced by their byte values. The resulting
261
+ sequence is truncated to the maximum configured text length.
262
+
263
+ Args:
264
+ text: The input text string.
265
+
266
+ Returns:
267
+ A tensor containing the encoded byte token IDs.
268
+ """
269
+ max_len = self.config.encoder_config.max_position_embeddings
270
 
271
  byte_text = text.encode("utf-8")
272
+ # Replace special tokens with their byte values if needed by the specific tokenizer/config
273
+ # Assuming byte values 1 and 2 are correct placeholders based on original code
274
  replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
275
  text_tokens = list(replaced_bytes)
276
+ return torch.tensor(
277
+ text_tokens[:max_len],
278
+ dtype=torch.long,
279
+ device=self.device,
280
+ )
281
 
282
+ def _pad_text_input(self, text_tokens: list[torch.Tensor]) -> torch.Tensor:
283
+ """Pads the text input to the maximum length."""
284
+ text_pad_value = 0
285
+ max_len = self.config.encoder_config.max_position_embeddings
286
+ batch_size = len(text_tokens)
287
+
288
+ src_tokens = torch.full(
289
+ (batch_size, 1, max_len),
290
+ fill_value=text_pad_value,
291
+ dtype=torch.long,
292
+ device=self.device,
293
+ )
294
+ for i in range(batch_size):
295
+ current_len = len(text_tokens[i])
296
+ src_tokens[i, 0, :current_len] = text_tokens[i]
 
297
  return src_tokens
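
A toy sketch of the byte-level encoding performed by `_encode_text` above; the example string is hypothetical:

text = "[S1] Hello. [S2] Hi there."
byte_text = text.encode("utf-8")
replaced = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
tokens = list(replaced)  # speaker tags become byte values 1 and 2
print(tokens[:8])  # [1, 32, 72, 101, 108, 108, 111, 46]
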
298
 
299
  def _prepare_audio_prompt(
300
+ self, audio_prompts: list[torch.Tensor | None]
301
+ ) -> tuple[torch.Tensor, list[int]]:
302
+ """Prepares the audio prompt tensor for the decoder.
 
 
 
 
303
 
304
+ Handles padding, adds the beginning-of-sequence (BOS) token, applies the
305
+ delay pattern, and determines the number of prefill steps for each item
306
+ in the batch.
307
+
308
+ Args:
309
+ audio_prompts: A list of audio prompt tensors (encoded DAC frames) or None.
310
+ Each tensor should have shape [T, C].
311
 
312
+ Returns:
313
+ A tuple containing:
314
+ - delayed_batch (torch.Tensor): The prepared audio prompt tensor with
315
+ delays applied, shape [B, T_max_padded, C].
316
+ - prefill_steps (list[int]): A list containing the number of valid
317
+ tokens (including BOS) for each prompt in the batch.
318
+ """
319
+ num_channels = self.config.decoder_config.num_channels
320
+ audio_bos_value = self.config.bos_token_id
321
+ delay_pattern = self.config.delay_pattern
322
+ max_delay_pattern = max(delay_pattern)
323
+ batch_size = len(audio_prompts)
324
 
325
+ max_len = (
326
+ max(p.shape[0] if p is not None else 0 for p in audio_prompts)
327
+ + max_delay_pattern
328
+ )
329
+ prefill_steps = []
330
 
331
+ prefill = torch.full(
332
+ (batch_size, max_len, num_channels),
333
  fill_value=-1,
334
  dtype=torch.int,
335
  device=self.device,
336
  )
337
+
338
+ prefill[:, 0, :] = audio_bos_value
339
+
340
+ for i in range(batch_size):
341
+ prompt = audio_prompts[i]
342
+ if prompt is not None:
343
+ prompt = prompt.to(device=self.device, dtype=torch.int)
344
+ prefill[i, 1 : prompt.shape[0] + 1, :] = prompt
345
+ prefill_steps.append(prompt.shape[0] + 1)
346
+ else:
347
+ prefill_steps.append(1)
348
 
349
  delay_precomp = build_delay_indices(
350
+ B=batch_size,
351
+ T=max_len,
352
  C=num_channels,
353
  delay_pattern=delay_pattern,
354
  )
355
 
356
+ delayed_batch = apply_audio_delay(
357
+ audio_BxTxC=prefill,
358
+ pad_value=-1,
359
  bos_value=audio_bos_value,
360
  precomp=delay_precomp,
361
+ )
362
 
363
+ return delayed_batch, prefill_steps
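
Conceptually, the delay pattern applied above shifts each codebook channel right by its own offset, filling the gap with the BOS marker; the real code does this with precomputed indices via `build_delay_indices`/`apply_audio_delay`. A toy, editor-written illustration with made-up values:

import torch

delay_pattern = [0, 1, 2]                 # toy 3-channel pattern
codes_TxC = torch.arange(12).view(4, 3)   # 4 timesteps, 3 channels
bos = -1

delayed_TxC = torch.full_like(codes_TxC, bos)
for c, d in enumerate(delay_pattern):
    # channel c starts d steps later; earlier positions keep the BOS marker
    delayed_TxC[d:, c] = codes_TxC[: codes_TxC.shape[0] - d, c]
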
364
 
365
  def _prepare_generation(
366
+ self,
367
+ text: torch.Tensor,
368
+ audio_prompts: list[torch.Tensor | None],
369
+ max_tokens: int | None = None,
370
+ attn_fn: Callable = F.scaled_dot_product_attention,
371
  ):
372
+ """Initializes the model state for generation.
 
 
373
 
374
+ Encodes the text input (conditional and unconditional), prepares the
375
+ encoder and decoder states (including KV caches and cross-attention),
376
+ prepares the audio prompt, and performs the initial decoder prefill steps
377
+ based on the audio prompts.
378
 
379
+ Args:
380
+ text: The padded text input tensor, shape [B, 1, T_text].
381
+ audio_prompts: A list of prepared audio prompt tensors or None.
382
+
383
+ Returns:
384
+ A tuple containing:
385
+ - dec_state (DecoderInferenceState): The initialized decoder state.
386
+ - dec_output (DecoderOutput): The initialized decoder output manager,
387
+ containing the prefilled audio tokens.
388
+ """
389
+ batch_size = text.shape[0]
390
+
391
+ enc_input_uncond = torch.zeros_like(text)
392
+ enc_input_cond = text
393
+ stacked_inputs = torch.stack([enc_input_uncond, enc_input_cond], dim=1)
394
+ enc_input = stacked_inputs.view(2 * batch_size, -1)
395
 
396
  enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
397
  encoder_out = self.model.encoder(enc_input, enc_state)
398
 
399
  dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
400
+ encoder_out
401
  )
402
  dec_state = DecoderInferenceState.new(
403
  self.config,
 
405
  encoder_out,
406
  dec_cross_attn_cache,
407
  self.compute_dtype,
408
+ max_generation_length=max_tokens,
409
  )
410
+ prefill, prefill_steps = self._prepare_audio_prompt(audio_prompts)
 
411
 
412
+ dec_output = DecoderOutput.new(batch_size, self.config, self.device)
413
+ dec_output.prefill(prefill, prefill_steps)
414
+
415
+ dec_step = min(prefill_steps) - 1
416
  if dec_step > 0:
417
  dec_state.prepare_step(0, dec_step)
418
+ tokens_BxTxC = dec_output.get_tokens_at(0, dec_step).repeat_interleave(
419
+ 2, dim=0
420
  )
421
  self.model.decoder.forward(tokens_BxTxC, dec_state)
422
 
 
429
  cfg_scale: float,
430
  temperature: float,
431
  top_p: float,
432
+ top_k: int,
433
+ current_idx: int,
434
  ) -> torch.Tensor:
435
+ """Performs a single step of the decoder inference.
436
+
437
+ Takes the tokens from the previous step, runs them through the decoder
438
+ (for both conditional and unconditional paths), applies classifier-free
439
+ guidance (CFG), samples the next token using temperature, top-p, and top-k
440
+ sampling, and applies constraints (e.g., preventing EOS in certain channels).
441
+
442
+ Args:
443
+ tokens_Bx1xC: The input tokens for the current step, shape [2*B, 1, C].
444
+ Repeated for CFG (unconditional and conditional).
445
+ dec_state: The current state of the decoder (KV caches, etc.).
446
+ cfg_scale: The scale factor for classifier-free guidance.
447
+ temperature: The temperature for sampling.
448
+ top_p: The cumulative probability threshold for top-p sampling.
449
+ top_k: The number of top logits to consider for top-k sampling.
450
+ current_idx: The current generation step index.
451
+
452
+ Returns:
453
+ torch.Tensor: The sampled next tokens for each item in the batch,
454
+ shape [B, C].
455
+ """
456
+ B = tokens_Bx1xC.shape[0] // 2
457
+
458
+ audio_eos_value = self.config.eos_token_id
459
+ logits_Bx1xCxV = self.model.decoder.decode_step(
460
+ tokens_Bx1xC, dec_state, current_idx
461
+ )
462
+
463
+ logits_last_2BxCxV = logits_Bx1xCxV[:, -1]
464
+ logits_last_Bx2xCxV = logits_last_2BxCxV.view(
465
+ B, 2, *logits_last_2BxCxV.shape[1:]
466
+ )
467
 
468
+ uncond_logits_BxCxV = logits_last_Bx2xCxV[:, 0, :, :] # Shape [B, C, V]
469
+ cond_logits_BxCxV = logits_last_Bx2xCxV[:, 1, :, :] # Shape [B, C, V]
470
+ logits_BxCxV = cond_logits_BxCxV + cfg_scale * (
471
+ cond_logits_BxCxV - uncond_logits_BxCxV
472
+ )
473
+
474
+ _, top_k_indices_BxCxk = torch.topk(logits_BxCxV, k=top_k, dim=-1)
475
+ mask_BxCxV = torch.ones_like(logits_BxCxV, dtype=torch.bool)
476
+ mask_BxCxV = mask_BxCxV.scatter(dim=-1, index=top_k_indices_BxCxk, value=False)
477
+ logits_BxCxV = cond_logits_BxCxV.masked_fill(mask_BxCxV, -torch.inf)
478
+
479
+ logits_BxCxV[:, :, audio_eos_value + 1 :] = torch.full_like(
480
+ logits_BxCxV[:, :, audio_eos_value + 1 :],
481
+ fill_value=-torch.inf,
482
+ )
483
+ logits_BxCxV[:, 1:, audio_eos_value:] = torch.full_like(
484
+ logits_BxCxV[:, 1:, audio_eos_value:],
485
+ fill_value=-torch.inf,
486
+ )
487
 
488
+ flat_logits_BCxV = logits_BxCxV.view(
489
+ B * self.config.decoder_config.num_channels, -1
490
+ )
491
 
492
+ pred_BC = _sample_next_token(
493
+ flat_logits_BCxV.float(),
494
  temperature=temperature,
495
  top_p=top_p,
496
+ top_k=top_k,
497
+ audio_eos_value=audio_eos_value,
498
  )
 
499
 
500
+ pred_BxC = pred_BC.view(B, self.config.decoder_config.num_channels)
501
+ return pred_BxC
502
+
503
+ def _generate_output(
504
+ self, generated_codes: torch.Tensor, lengths_Bx: torch.Tensor
505
+ ) -> list[np.ndarray]:
506
+ """Converts generated delayed codes into audio waveforms.
507
+
508
+ Reverts the delay pattern applied during generation, decodes the resulting
509
+ codebook using the DAC model (if loaded), and returns a list of audio
510
+ waveforms as NumPy arrays. If DAC is not loaded, returns the raw codebook indices.
511
+
512
+ Args:
513
+ generated_codes: The tensor of generated audio codes with delays,
514
+ shape [B, T_gen, C].
515
+ lengths_Bx: A tensor containing the valid length of generated codes
516
+ (excluding padding and BOS/EOS markers) for each item
517
+ in the batch, shape [B].
518
+
519
+ Returns:
520
+ A list of NumPy arrays, where each array represents the generated audio
521
+ waveform for one item in the batch. If DAC is not loaded, returns the
522
+ raw, reverted codebook indices as NumPy arrays.
523
+ """
524
+ num_channels = self.config.decoder_config.num_channels
525
+ batch_size = generated_codes.shape[0]
526
+ seq_length = generated_codes.shape[1]
527
+ delay_pattern = self.config.delay_pattern
528
+ audio_pad_value = self.config.pad_token_id
529
  max_delay_pattern = max(delay_pattern)
530
 
531
  revert_precomp = build_revert_indices(
532
+ B=batch_size,
533
  T=seq_length,
534
  C=num_channels,
535
  delay_pattern=delay_pattern,
536
  )
537
 
538
  codebook = revert_audio_delay(
539
+ audio_BxTxC=generated_codes,
540
  pad_value=audio_pad_value,
541
  precomp=revert_precomp,
542
  T=seq_length,
 
547
  invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
548
  codebook[invalid_mask] = 0
549
 
550
+ audios = []
551
 
552
+ if self.load_dac:
553
+ for i in range(batch_size):
554
+ audio = self._decode(codebook[i, : lengths_Bx[i], :])
555
+ audio_np = audio.cpu().numpy()
556
+ audios.append(audio_np)
557
+ else:
558
+ for i in range(batch_size):
559
+ audios.append(codebook[i, : lengths_Bx[i], :].cpu().numpy())
560
+ return audios
561
+
562
+ @torch.no_grad()
563
+ @torch.inference_mode()
564
+ def _encode(self, audio: torch.Tensor) -> torch.Tensor:
565
+ """
566
+ Encodes the given audio waveform into a tensor of DAC codebook indices
567
+ """
568
+ audio = audio.unsqueeze(0)
569
+ audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
570
+ _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)
571
+ encoded_frame: torch.Tensor
572
+ return encoded_frame.squeeze(0).transpose(0, 1)
573
+
574
+ @torch.no_grad()
575
+ @torch.inference_mode()
576
+ def _decode(self, audio_codes: torch.Tensor) -> torch.Tensor:
577
+ """
578
+ Decodes the given frames into an output audio waveform
579
+ """
580
+ audio_codes = audio_codes.unsqueeze(0).transpose(1, 2)
581
+ audio_values, _, _ = self.dac_model.quantizer.from_codes(audio_codes)
582
+ audio_values = self.dac_model.decode(audio_values)
583
+ audio_values: torch.Tensor
584
+ return audio_values.squeeze()
585
 
586
  def load_audio(self, audio_path: str) -> torch.Tensor:
587
+ """Loads and preprocesses an audio file for use as a prompt.
588
+
589
+ Loads the audio file, resamples it to the target sample rate if necessary,
590
+ preprocesses it using the DAC model's preprocessing, and encodes it into
591
+ DAC codebook indices.
592
+
593
+ Args:
594
+ audio_path: Path to the audio file.
595
+
596
+ Returns:
597
+ torch.Tensor: The encoded audio prompt as DAC codebook indices,
598
+ shape [T, C].
599
+
600
+ Raises:
601
+ RuntimeError: If the DAC model is not loaded (`load_dac=False` during init).
602
+ FileNotFoundError: If the audio file cannot be found.
603
+ Exception: If there's an error during loading or processing.
604
+ """
605
+ if self.dac_model is None:
606
+ raise RuntimeError(
607
+ "DAC model is required for loading audio prompts but was not loaded."
608
+ )
609
  audio, sr = torchaudio.load(audio_path, channels_first=True) # C, T
610
  if sr != DEFAULT_SAMPLE_RATE:
611
  audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
612
+ # Convert to mono if stereo
613
+ if audio.shape[0] > 1:
614
+ audio = torch.mean(
615
+ audio, dim=0, keepdim=True
616
+ ) # Average channels to get mono
617
+ return self._encode(audio.to(self.device))
618
 
619
  def save_audio(self, path: str, audio: np.ndarray):
620
+ """Saves the generated audio waveform to a file.
621
+
622
+ Uses the soundfile library to write the NumPy audio array to the specified
623
+ path with the default sample rate.
624
+
625
+ Args:
626
+ path: The path where the audio file will be saved.
627
+ audio: The audio waveform as a NumPy array.
628
+ """
629
  import soundfile as sf
630
 
631
  sf.write(path, audio, DEFAULT_SAMPLE_RATE)
 
633
  @torch.inference_mode()
634
  def generate(
635
  self,
636
+ text: str | list[str],
637
+ max_tokens: int = 3072,
638
  cfg_scale: float = 3.0,
639
+ temperature: float = 1.2,
640
  top_p: float = 0.95,
641
  use_torch_compile: bool = False,
642
+ cfg_filter_top_k: int = 45,
643
+ audio_prompt: list[str | torch.Tensor | None]
644
+ | str
645
+ | torch.Tensor
646
+ | None = None,
647
+ audio_prompt_path: list[str | torch.Tensor | None]
648
+ | str
649
+ | torch.Tensor
650
+ | None = None,
651
  use_cfg_filter: bool | None = None,
652
  verbose: bool = False,
653
+ ) -> np.ndarray | list[np.ndarray]:
654
+ """Generates audio corresponding to the input text.
655
+
656
+ Args:
657
+ text: The input text prompt, or a list of text prompts for batch generation.
658
+ max_tokens: The maximum number of audio tokens to generate per prompt.
659
+ Defaults to the model's configured audio length if None.
660
+ cfg_scale: The scale factor for classifier-free guidance (CFG). Higher values
661
+ lead to stronger guidance towards the text prompt.
662
+ temperature: The temperature for sampling. Higher values increase randomness.
663
+ top_p: The cumulative probability threshold for nucleus (top-p) sampling.
664
+ use_torch_compile: Whether to compile the generation steps using torch.compile.
665
+ Can significantly speed up generation after the initial
666
+ compilation overhead. Defaults to False.
667
+ cfg_filter_top_k: The number of top logits to consider during CFG filtering.
668
+ (Note: This parameter name might be slightly misleading based
669
+ on the code; it's used in the `_sample_next_token` function.)
670
+ audio_prompt: An audio prompt or list of prompts to condition the generation.
671
+ Can be a file path (str), a pre-loaded tensor (DAC codes), or None.
672
+ If a list, its length must match the batch size of the text input.
673
+ audio_prompt_path: (Deprecated) Use `audio_prompt` instead.
674
+ use_cfg_filter: (Deprecated) This parameter is no longer used.
675
+ verbose: If True, prints progress information during generation, including
676
+ speed metrics.
677
+
678
+ Returns:
679
+ If a single text prompt was provided, returns a NumPy array containing the
680
+ generated audio waveform.
681
+ If a list of text prompts was provided, returns a list of NumPy arrays,
682
+ each corresponding to a prompt in the input list. Returns None for a
683
+ sequence if no audio was generated for it.
684
+ """
685
+ batch_size = len(text) if isinstance(text, list) else 1
686
+ audio_eos_value = self.config.eos_token_id
687
+ audio_pad_value = self.config.pad_token_id
688
+ delay_pattern = self.config.delay_pattern
689
  max_delay_pattern = max(delay_pattern)
690
+ delay_pattern_Cx = torch.tensor(
691
+ delay_pattern, device=self.device, dtype=torch.long
692
+ )
693
  self.model.eval()
694
 
695
  if audio_prompt_path:
 
701
  if verbose:
702
  total_start_time = time.time()
703
 
704
+ if use_torch_compile and not hasattr(self, "_compiled"):
705
+ # Compilation can take about a minute.
706
+ self._prepare_generation = torch.compile(
707
+ self._prepare_generation, dynamic=True, fullgraph=True
708
+ )
709
+ self._decoder_step = torch.compile(
710
+ self._decoder_step, fullgraph=True, mode="max-autotune"
711
+ )
712
+ self._compiled = True
713
+
714
+ if isinstance(audio_prompt, list):
715
+ audio_prompt = [
716
+ self.load_audio(p) if isinstance(p, str) else p for p in audio_prompt
717
+ ]
718
+ elif isinstance(audio_prompt, str):
719
+ audio_prompt = [self.load_audio(audio_prompt)]
720
+ elif isinstance(audio_prompt, torch.Tensor):
721
+ audio_prompt = [audio_prompt]
722
+ elif audio_prompt is None:
723
+ audio_prompt = [None] * batch_size
724
+
725
+ assert len(audio_prompt) == batch_size, (
726
+ "Number of audio prompts must match batch size"
727
+ )
728
 
729
+ if isinstance(text, list):
730
+ text = [self._encode_text(t) for t in text]
731
  else:
732
+ text = [self._encode_text(text)]
733
+ text = self._pad_text_input(text)
734
+
735
+ dec_state, dec_output = self._prepare_generation(
736
+ text, audio_prompt, max_tokens=max_tokens
737
+ )
738
+ dec_step = min(dec_output.prefill_steps) - 1
739
+ current_idx = torch.tensor([dec_step], device=self.device)
740
+
741
+ eos_detected_Bx = torch.zeros(
742
+ (batch_size,), dtype=torch.bool, device=self.device
743
+ )
744
+ eos_countdown_Bx = torch.full(
745
+ (batch_size,), -1, dtype=torch.long, device=self.device
746
+ )
747
+ finished_step_Bx = torch.full(
748
+ (batch_size,), -1, dtype=torch.long, device=self.device
749
+ )
750
+
751
+ bos_over = False
752
 
753
  if verbose:
754
  print("generate: starting generation loop")
755
  if use_torch_compile:
756
  print(
757
+ "generate: using use_torch_compile=True, the first step may be slow"
758
  )
759
  start_time = time.time()
760
 
761
+ # --- Generation Loop ---
762
  while dec_step < max_tokens:
763
+ if (eos_countdown_Bx == 0).all():
764
+ break
765
+
766
+ current_step_idx = dec_step + 1
767
+ torch.compiler.cudagraph_mark_step_begin()
768
  dec_state.prepare_step(dec_step)
769
+ tokens_Bx1xC = dec_output.get_tokens_at(dec_step).repeat_interleave(
770
+ 2, dim=0
771
+ ) # Repeat for CFG
772
+
773
+ pred_BxC = self._decoder_step(
774
  tokens_Bx1xC,
775
  dec_state,
776
  cfg_scale,
777
  temperature,
778
  top_p,
779
  cfg_filter_top_k,
780
+ current_idx,
781
  )
782
 
783
+ current_idx += 1
784
+
785
+ active_mask_Bx = eos_countdown_Bx != 0
786
+ eos_trigger_Bx = torch.zeros_like(active_mask_Bx)
787
+ if active_mask_Bx.any():
788
+ is_eos_token = (~eos_detected_Bx[active_mask_Bx]) & (
789
+ pred_BxC[active_mask_Bx, 0] == audio_eos_value
790
+ )
791
+ is_max_len = current_step_idx >= max_tokens - max_delay_pattern
792
+ eos_trigger_Bx[active_mask_Bx] = is_eos_token | is_max_len
793
+ eos_detected_Bx |= eos_trigger_Bx
794
+ start_countdown_mask_Bx = eos_trigger_Bx & (eos_countdown_Bx < 0)
795
+ if start_countdown_mask_Bx.any():
796
+ eos_countdown_Bx[start_countdown_mask_Bx] = max_delay_pattern
797
+ finished_step_Bx[start_countdown_mask_Bx] = current_step_idx
798
+
799
+ padding_mask_Bx = eos_countdown_Bx > 0
800
+ if padding_mask_Bx.any():
801
+ pred_active_BxC = pred_BxC[padding_mask_Bx].clone()
802
+ countdown_active_Bx = eos_countdown_Bx[padding_mask_Bx]
803
+ step_after_eos_Bx = max_delay_pattern - countdown_active_Bx
804
+ step_after_eos_Bx_ = step_after_eos_Bx.unsqueeze(1)
805
+ delay_pattern_Cx_ = delay_pattern_Cx.unsqueeze(0)
806
+ eos_mask_NxC = step_after_eos_Bx_ == delay_pattern_Cx_
807
+ pad_mask_NxC = step_after_eos_Bx_ > delay_pattern_Cx_
808
+ pred_active_BxC[eos_mask_NxC] = audio_eos_value
809
+ pred_active_BxC[pad_mask_NxC] = audio_pad_value
810
+ pred_BxC[padding_mask_Bx] = pred_active_BxC
811
+ eos_countdown_Bx[padding_mask_Bx] -= 1
812
+
813
+ # --- Update BOS flag (Original) ---
814
+ if not bos_over:
815
+ bos_over = all(
816
+ dec_step - prefill_step > max_delay_pattern
817
+ for prefill_step in dec_output.prefill_steps
818
+ )
819
+
820
+ dec_output.update_one(pred_BxC, current_step_idx, not bos_over)
821
 
822
  dec_step += 1
823
+
824
  if verbose and dec_step % 86 == 0:
825
  duration = time.time() - start_time
826
+ if duration > 0:
827
+ print(
828
+ f"generate step {dec_step}: speed={86 * batch_size / duration:.3f} tokens/s, realtime factor={batch_size / duration:.3f}x"
829
+ )
830
  start_time = time.time()
831
 
832
+ # --- Finalize and Extract Output ---
833
+ final_step = dec_step + 1
 
834
 
835
+ finished_step_Bx[finished_step_Bx == -1] = final_step - max_delay_pattern
 
 
836
 
837
+ prefill_steps_tensor = torch.tensor(
838
+ dec_output.prefill_steps, device=self.device
839
+ )
840
+ lengths_Bx = finished_step_Bx - prefill_steps_tensor
841
+ lengths_Bx = torch.clamp(lengths_Bx, min=0)
842
+
843
+ max_len = lengths_Bx.max().item() + max_delay_pattern
844
+ outputs = []
845
+
846
+ if max_len > 0:
847
+ num_channels = self.config.decoder_config.num_channels
848
+ audio_pad_value = self.config.pad_token_id
849
+ generated_codes = torch.full(
850
+ (batch_size, max_len, num_channels),
851
+ fill_value=audio_pad_value,
852
+ dtype=torch.long,
853
+ device=self.device,
854
  )
855
 
856
+ for i in range(batch_size):
857
+ start_step = dec_output.prefill_steps[i]
858
+ actual_len = lengths_Bx[i].item() + max_delay_pattern
859
+ if actual_len > 0:
860
+ tokens_to_copy = dec_output.generated_tokens[
861
+ i, start_step : start_step + actual_len, :
862
+ ]
863
+ generated_codes[i, :actual_len, :] = tokens_to_copy
864
+
865
+ if verbose:
866
+ avg_steps = lengths_Bx.float().mean().item()
867
+ total_duration = time.time() - total_start_time
868
+ print(
869
+ f"generate: avg steps={avg_steps:.1f}, total duration={total_duration:.3f}s"
870
+ )
871
+
872
+ del dec_state
873
+
874
+ outputs = self._generate_output(generated_codes, lengths_Bx)
875
+ else:
876
+ print("Warning: Nothing generated for any sequence in the batch.")
877
+ outputs = [None] * batch_size
878
+
879
+ return outputs if batch_size > 1 else outputs[0]
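
A short usage sketch of the batched `generate` API defined above; the prompts and output filenames are hypothetical:

from dia.model import Dia

model = Dia.from_pretrained("nari-labs/Dia-1.6B-0626", compute_dtype="float16")

texts = [
    "[S1] First item in the batch. [S2] Hello!",
    "[S1] Second item in the batch.",
]
audios = model.generate(texts, cfg_scale=3.0, temperature=1.2, top_p=0.95, verbose=True)

for i, audio in enumerate(audios):  # one waveform (or None) per prompt
    if audio is not None:
        model.save_audio(f"output_{i}.wav", audio)
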
dia/state.py CHANGED
@@ -1,4 +1,5 @@
1
  from dataclasses import dataclass
 
2
 
3
  import torch
4
 
@@ -14,29 +15,18 @@ def create_attn_mask(
14
  """
15
  Creates the attention mask (self or cross) mimicking JAX segment ID logic.
16
  """
17
- B1, Tq = q_padding_mask_1d.shape
18
- B2, Tk = k_padding_mask_1d.shape
19
- assert B1 == B2, "Query and key batch dimensions must match"
20
 
21
  p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
22
  p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
23
 
24
- # Condition A: Non-padding query attends to non-padding key
25
- non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
26
-
27
- # Condition B: Padding query attends to padding key
28
- pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
29
-
30
- # Combine: True if padding status is compatible (both non-pad OR both pad)
31
- mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
32
-
33
  if is_causal:
34
- assert Tq == Tk, (
35
- "Causal mask requires query and key sequence lengths to be equal"
36
- )
37
  causal_mask_2d = torch.tril(
38
- torch.ones((Tq, Tk), dtype=torch.bool, device=device)
39
- ) # Shape [Tq, Tk]
40
  causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
41
  return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
42
  else:
@@ -58,19 +48,18 @@ class EncoderInferenceState:
58
  """Creates EtorchrInferenceParams from DiaConfig and a device."""
59
  device = cond_src.device
60
 
61
- positions = (
62
- torch.arange(config.data.text_length, device=device)
63
- .to(torch.long)
64
- .unsqueeze(0)
65
- .expand(2, -1)
66
- )
67
- padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
68
  attn_mask = create_attn_mask(
69
  padding_mask, padding_mask, device, is_causal=False
70
  )
71
 
72
  return cls(
73
- max_seq_len=config.data.text_length,
74
  device=device,
75
  positions=positions,
76
  padding_mask=padding_mask,
@@ -78,9 +67,13 @@ class EncoderInferenceState:
78
  )
79
 
80
 
81
- class KVCache:
      def __init__(
          self,
          num_heads: int,
          max_len: int,
          head_dim: int,
@@ -89,21 +82,33 @@ class KVCache:
          k: torch.Tensor | None = None,
          v: torch.Tensor | None = None,
      ):
-         self.k = (
-             torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
              if k is None
              else k
          )
-         self.v = (
-             torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
              if v is None
              else v
          )
-         self.current_idx = torch.tensor(0)

      @classmethod
      def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
          return cls(
              num_heads=k.shape[1],
              max_len=k.shape[2],
              head_dim=k.shape[3],
@@ -114,20 +119,17 @@ class KVCache:
          )

      def update(
-         self, k: torch.Tensor, v: torch.Tensor
      ) -> tuple[torch.Tensor, torch.Tensor]:
-         self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
-         self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
-         self.current_idx += 1
-         return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]

-     def prefill(
-         self, k: torch.Tensor, v: torch.Tensor
-     ) -> tuple[torch.Tensor, torch.Tensor]:
          prefill_len = k.shape[2]
          self.k[:, :, :prefill_len, :] = k
          self.v[:, :, :prefill_len, :] = v
-         self.current_idx = prefill_len - 1
  @dataclass
@@ -139,9 +141,10 @@ class DecoderInferenceState:
139
  enc_out: torch.Tensor
140
  enc_positions: torch.Tensor
141
  dec_positions: torch.Tensor
142
- dec_cross_attn_mask: torch.Tensor
143
  self_attn_cache: list[KVCache]
144
  cross_attn_cache: list[KVCache]
 
 
145
 
146
  @classmethod
147
  def new(
@@ -151,28 +154,36 @@ class DecoderInferenceState:
151
  enc_out: torch.Tensor,
152
  dec_cross_attn_cache: list[KVCache],
153
  compute_dtype: torch.dtype,
 
154
  ) -> "DecoderInferenceState":
155
  """Creates DecoderInferenceParams from DiaConfig and a device."""
156
  device = enc_out.device
157
- max_audio_len = config.data.audio_length
 
 
 
158
 
159
  dec_positions = torch.full(
160
- (2, 1), fill_value=0, dtype=torch.long, device=device
 
 
 
161
  )
162
- tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device)
163
- dec_cross_attn_mask = create_attn_mask(
164
- tgt_padding_mask, enc_state.padding_mask, device, is_causal=False
165
  )
166
 
167
  self_attn_cache = [
168
  KVCache(
169
- config.model.decoder.kv_heads,
 
170
  max_audio_len,
171
- config.model.decoder.gqa_head_dim,
172
  compute_dtype,
173
  device,
174
  )
175
- for _ in range(config.model.decoder.n_layer)
176
  ]
177
 
178
  return cls(
@@ -181,54 +192,56 @@ class DecoderInferenceState:
181
  enc_out=enc_out,
182
  enc_positions=enc_state.positions,
183
  dec_positions=dec_positions,
184
- dec_cross_attn_mask=dec_cross_attn_mask,
185
  self_attn_cache=self_attn_cache,
186
  cross_attn_cache=dec_cross_attn_cache,
 
 
187
  )
188
 
189
  def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
190
  if step_to is None:
191
  step_to = step_from + 1
192
- self.dec_positions = (
193
- torch.arange(step_from, step_to, device=self.device)
194
- .unsqueeze(0)
195
- .expand(2, -1)
196
- )
197
 
198
 
199
  @dataclass
200
  class DecoderOutput:
201
  generated_tokens: torch.Tensor
202
- prefill_step: int
203
 
204
  @classmethod
205
- def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput":
206
- max_audio_len = config.data.audio_length
 
 
207
  return cls(
208
  generated_tokens=torch.full(
209
- (max_audio_len, config.data.channels),
210
  fill_value=-1,
211
  dtype=torch.int,
212
  device=device,
213
  ),
214
- prefill_step=0,
215
  )
216
 
217
  def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
218
  if step_to is None:
219
  step_to = step_from + 1
220
- return self.generated_tokens[step_from:step_to, :]
221
 
222
  def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
 
223
  if apply_mask:
224
- mask = self.generated_tokens[step : step + 1, :] == -1
225
- self.generated_tokens[step : step + 1, :] = torch.where(
226
- mask, dec_out, self.generated_tokens[step : step + 1, :]
227
  )
228
  else:
229
- self.generated_tokens[step : step + 1, :] = dec_out
230
 
231
- def prefill(self, dec_out: torch.Tensor, prefill_step: int):
232
- length = dec_out.shape[0]
233
- self.generated_tokens[0:length, :] = dec_out
234
- self.prefill_step = prefill_step
 
1
  from dataclasses import dataclass
2
+ from typing import Optional
3
 
4
  import torch
5
 
 
15
  """
16
  Creates the attention mask (self or cross) mimicking JAX segment ID logic.
17
  """
18
+ # B1, Tq = q_padding_mask_1d.shape
19
+ # B2, Tk = k_padding_mask_1d.shape
 
20
 
21
  p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
22
  p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
23
 
24
+ mask = p_mask_q & p_mask_k
 
 
 
 
 
 
 
 
25
  if is_causal:
26
+ # assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
 
 
27
  causal_mask_2d = torch.tril(
28
+ torch.ones_like(mask[0], dtype=torch.bool, device=device)
29
+ ) # Shape [B, Tq, Tk]
30
  causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
31
  return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
32
  else:
 
48
  """Creates EtorchrInferenceParams from DiaConfig and a device."""
49
  device = cond_src.device
50
 
51
+ positions = torch.arange(
52
+ config.encoder_config.max_position_embeddings,
53
+ dtype=torch.float32,
54
+ device=device,
55
+ ).unsqueeze(0)
56
+ padding_mask = (cond_src.squeeze(1) != 0).to(device).repeat_interleave(2, dim=0)
 
57
  attn_mask = create_attn_mask(
58
  padding_mask, padding_mask, device, is_causal=False
59
  )
60
 
61
  return cls(
62
+ max_seq_len=config.encoder_config.max_position_embeddings,
63
  device=device,
64
  positions=positions,
65
  padding_mask=padding_mask,
 
67
  )
68
 
69
 
70
+ class KVCache(torch.nn.Module):
71
+ k: torch.Tensor
72
+ v: torch.Tensor
73
+
74
  def __init__(
75
  self,
76
+ batch_size: int,
77
  num_heads: int,
78
  max_len: int,
79
  head_dim: int,
 
82
  k: torch.Tensor | None = None,
83
  v: torch.Tensor | None = None,
84
  ):
85
+ k = (
86
+ torch.zeros(
87
+ (2 * batch_size, num_heads, max_len, head_dim),
88
+ dtype=dtype,
89
+ device=device,
90
+ )
91
  if k is None
92
  else k
93
  )
94
+ v = (
95
+ torch.zeros(
96
+ (2 * batch_size, num_heads, max_len, head_dim),
97
+ dtype=dtype,
98
+ device=device,
99
+ )
100
  if v is None
101
  else v
102
  )
103
+ super().__init__()
104
+
105
+ self.register_buffer("k", k)
106
+ self.register_buffer("v", v)
107
 
108
  @classmethod
109
  def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
110
  return cls(
111
+ batch_size=k.shape[0] // 2,
112
  num_heads=k.shape[1],
113
  max_len=k.shape[2],
114
  head_dim=k.shape[3],
 
119
  )
120
 
121
  def update(
122
+ self, k: torch.Tensor, v: torch.Tensor, current_idx: torch.Tensor
123
  ) -> tuple[torch.Tensor, torch.Tensor]:
124
+ k_out, v_out = self.k, self.v
125
+ k_out[:, :, current_idx, :] = k
126
+ v_out[:, :, current_idx, :] = v
127
+ return self.k, self.v
128
 
129
+ def prefill(self, k: torch.Tensor, v: torch.Tensor):
 
 
130
  prefill_len = k.shape[2]
131
  self.k[:, :, :prefill_len, :] = k
132
  self.v[:, :, :prefill_len, :] = v
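
Registering `k`/`v` as buffers (instead of plain attributes, as before) means the pre-allocated cache tensors follow `.to(device)` moves and appear in `state_dict()` without becoming trainable parameters; a minimal, self-contained sketch of that behavior:

import torch

class TinyCache(torch.nn.Module):
    def __init__(self, max_len: int, head_dim: int):
        super().__init__()
        # Pre-allocated storage that is module state but not a parameter.
        self.register_buffer("k", torch.zeros(max_len, head_dim))

cache = TinyCache(8, 4)
print(list(cache.state_dict().keys()))          # ['k']
print(sum(p.numel() for p in cache.parameters()))  # 0
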
 
133
 
134
 
135
  @dataclass
 
141
  enc_out: torch.Tensor
142
  enc_positions: torch.Tensor
143
  dec_positions: torch.Tensor
 
144
  self_attn_cache: list[KVCache]
145
  cross_attn_cache: list[KVCache]
146
+ casual_attn_mask: torch.Tensor
147
+ cross_attn_mask: torch.Tensor
148
 
149
  @classmethod
150
  def new(
 
154
  enc_out: torch.Tensor,
155
  dec_cross_attn_cache: list[KVCache],
156
  compute_dtype: torch.dtype,
157
+ max_generation_length: Optional[int] = None,
158
  ) -> "DecoderInferenceState":
159
  """Creates DecoderInferenceParams from DiaConfig and a device."""
160
  device = enc_out.device
161
+ max_audio_len = (
162
+ max_generation_length or config.decoder_config.max_position_embeddings
163
+ )
164
+ batch_size = enc_out.shape[0] // 2
165
 
166
  dec_positions = torch.full(
167
+ (2 * batch_size, 1), fill_value=0, dtype=torch.int32, device=device
168
+ )
169
+ causal_mask = torch.tril(
170
+ torch.ones(max_audio_len, max_audio_len, dtype=torch.bool, device=device)
171
  )
172
+ dec_mask = torch.ones((2 * batch_size, 1), dtype=torch.bool, device=device)
173
+ cross_attn_mask = create_attn_mask(
174
+ dec_mask, enc_state.padding_mask, device, is_causal=False
175
  )
176
 
177
  self_attn_cache = [
178
  KVCache(
179
+ batch_size,
180
+ config.decoder_config.num_key_value_heads,
181
  max_audio_len,
182
+ config.decoder_config.head_dim,
183
  compute_dtype,
184
  device,
185
  )
186
+ for _ in range(config.decoder_config.num_hidden_layers)
187
  ]
188
 
189
  return cls(
 
192
  enc_out=enc_out,
193
  enc_positions=enc_state.positions,
194
  dec_positions=dec_positions,
 
195
  self_attn_cache=self_attn_cache,
196
  cross_attn_cache=dec_cross_attn_cache,
197
+ casual_attn_mask=causal_mask,
198
+ cross_attn_mask=cross_attn_mask,
199
  )
200
 
201
  def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
202
  if step_to is None:
203
  step_to = step_from + 1
204
+ self.dec_positions = torch.arange(
205
+ step_from, step_to, dtype=torch.int32, device=self.device
206
+ ).unsqueeze(0)
 
 
207
 
208
 
209
  @dataclass
210
  class DecoderOutput:
211
  generated_tokens: torch.Tensor
212
+ prefill_steps: list[int]
213
 
214
  @classmethod
215
+ def new(
216
+ cls, batch_size: int, config: DiaConfig, device: torch.device
217
+ ) -> "DecoderOutput":
218
+ max_audio_len = config.decoder_config.max_position_embeddings
219
  return cls(
220
  generated_tokens=torch.full(
221
+ (batch_size, max_audio_len, config.decoder_config.num_channels),
222
  fill_value=-1,
223
  dtype=torch.int,
224
  device=device,
225
  ),
226
+ prefill_steps=[],
227
  )
228
 
229
  def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
230
  if step_to is None:
231
  step_to = step_from + 1
232
+ return self.generated_tokens[:, step_from:step_to, :]
233
 
234
  def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
235
+ dec_out = dec_out.to(self.generated_tokens.dtype)
236
  if apply_mask:
237
+ mask = self.generated_tokens[:, step, :] == -1
238
+ self.generated_tokens[:, step, :] = torch.where(
239
+ mask, dec_out, self.generated_tokens[:, step, :]
240
  )
241
  else:
242
+ self.generated_tokens[:, step, :] = dec_out
243
 
244
+ def prefill(self, dec_out: torch.Tensor, prefill_steps: list[int]):
245
+ length = dec_out.shape[1]
246
+ self.generated_tokens[:, :length, :] = dec_out
247
+ self.prefill_steps = prefill_steps
requirements.txt CHANGED
@@ -4,6 +4,7 @@ huggingface-hub>=0.30.2
  numpy>=2.2.4
  pydantic>=2.11.3
  soundfile>=0.13.1
- torchaudio>=2.0.0
- torch>=2.0.0
  gradio-dialogue>=0.0.4

  numpy>=2.2.4
  pydantic>=2.11.3
  soundfile>=0.13.1
+ torchaudio==2.6.0
+ torch==2.6.0
+ triton==3.2.0
  gradio-dialogue>=0.0.4
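
Since torch, torchaudio, and triton are now pinned to exact versions, a quick runtime sanity check of the installed stack can look like this editor's sketch:

import torch
import torchaudio

assert torch.__version__.startswith("2.6"), torch.__version__
assert torchaudio.__version__.startswith("2.6"), torchaudio.__version__
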