ONNX
thecollabagepatch commited on
Commit
2a53414
·
verified ·
1 Parent(s): 02e995a

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -0
README.md CHANGED
@@ -14,3 +14,58 @@ when using the initial version, the decoder ((autoencoder_arm.onnx)) crashes the
14
 
15
  nothing to see here, yet... just wanted a place to store these.
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  nothing to see here, yet... just wanted a place to store these.
16
 
17
+ like everything else i do...pure vibes zero real knowledge.
18
+
19
+ Here's a python script i used to validate outputs against the original pytorch model.
20
+
21
+ there's another one using cfg stuff that gets essentially the same outputs.
22
+
23
+ ```
24
+
25
+ #!/usr/bin/env python
26
+ import numpy as np, soundfile as sf, onnxruntime as ort
27
+ from transformers import AutoTokenizer
28
+
29
+ # Load ONNX models
30
+ dit = ort.InferenceSession("diffusion_dit_arm.onnx")
31
+ cond = ort.InferenceSession("conditioners.onnx")
32
+ dec = ort.InferenceSession("autoencoder_arm.onnx")
33
+
34
+ # Config
35
+ prompt = "lo-fi hip-hop beat with pianos 90bpm"
36
+ steps = 10
37
+ rng = np.random.RandomState(12345)
38
+ x = rng.randn(1, 64, 256).astype(np.float32)
39
+
40
+ # Conditioning
41
+ tok = AutoTokenizer.from_pretrained("t5-base")
42
+ tokens = tok(prompt, truncation=True, padding="max_length", max_length=128, return_tensors="np")
43
+ conds = cond.run(None, {
44
+ "input_ids": tokens["input_ids"].astype(np.int64),
45
+ "attention_mask": tokens["attention_mask"].astype(np.int64),
46
+ "seconds_total": np.array([10.0], dtype=np.float32)
47
+ })
48
+ cross, _, glob = conds
49
+
50
+ # Run 10 steps with linear t, no CFG
51
+ for i in range(steps):
52
+ t_val = 1.0 - i / (steps - 1)
53
+ t = np.array([t_val], dtype=np.float32)
54
+
55
+ v = dit.run(None, {
56
+ "x": x, "t": t,
57
+ "cross_attn_cond": cross,
58
+ "global_cond": glob
59
+ })[0]
60
+
61
+ x -= 0.1 * v # fixed Euler step
62
+
63
+ # Decode
64
+ audio = dec.run(None, {'sampled': x})[0]
65
+ if audio.shape[0] == 2:
66
+ audio = audio.T
67
+ audio /= np.abs(audio).max()
68
+ sf.write("onnx_lofi_linear.wav", audio, 44100)
69
+ print("✅ onnx_lofi_linear.wav written!")
70
+
71
+ ```