fix flow direction
Browse files- meanaudio/model/flow_matching.py +23 -10
meanaudio/model/flow_matching.py
CHANGED
@@ -10,15 +10,18 @@ log = logging.getLogger()
|
|
10 |
## partially from https://github.com/gle-bellier/flow-matching
|
11 |
class FlowMatching:
|
12 |
|
13 |
-
def __init__(self, min_sigma: float = 0.0,
|
|
|
|
|
|
|
14 |
# inference_mode: 'euler' or 'adaptive'
|
15 |
# num_steps: number of steps in the euler inference mode
|
|
|
16 |
super().__init__()
|
17 |
self.min_sigma = min_sigma
|
18 |
self.inference_mode = inference_mode
|
19 |
self.num_steps = num_steps
|
20 |
-
|
21 |
-
# self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=min_sigma)
|
22 |
|
23 |
assert self.inference_mode in ['euler', 'adaptive']
|
24 |
if self.inference_mode == 'adaptive' and num_steps > 0:
|
@@ -28,12 +31,18 @@ class FlowMatching:
|
|
28 |
t: torch.Tensor) -> torch.Tensor:
|
29 |
# which is psi_t(x), eq 22 in flow matching for generative models
|
30 |
t = t[:, None, None].expand_as(x0)
|
31 |
-
|
|
|
|
|
|
|
32 |
|
33 |
def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
|
34 |
# return the mean error without reducing the batch dimension
|
35 |
reduce_dim = list(range(1, len(predicted_v.shape)))
|
36 |
-
|
|
|
|
|
|
|
37 |
return (predicted_v - target_v).pow(2).mean(dim=reduce_dim)
|
38 |
|
39 |
def get_x0_xt_c(
|
@@ -49,10 +58,16 @@ class FlowMatching:
|
|
49 |
return x0, x1, xt, Cs
|
50 |
|
51 |
def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
|
52 |
-
|
|
|
|
|
|
|
53 |
|
54 |
def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
|
55 |
-
|
|
|
|
|
|
|
56 |
|
57 |
def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
|
58 |
# fn: a function that takes (t, x) and returns the direction x0->x1
|
@@ -66,8 +81,6 @@ class FlowMatching:
|
|
66 |
flow = fn(t, x)
|
67 |
next_t = steps[ti + 1]
|
68 |
dt = next_t - t
|
69 |
-
|
70 |
-
# fix: we need to subtract the flow since we learn the reverse trajectory
|
71 |
-
x = x - dt * flow
|
72 |
|
73 |
return x
|
|
|
10 |
## partially from https://github.com/gle-bellier/flow-matching
|
11 |
class FlowMatching:
|
12 |
|
13 |
+
def __init__(self, min_sigma: float = 0.0,
|
14 |
+
inference_mode='euler',
|
15 |
+
num_steps: int = 25,
|
16 |
+
reverse_flow: bool = True):
|
17 |
# inference_mode: 'euler' or 'adaptive'
|
18 |
# num_steps: number of steps in the euler inference mode
|
19 |
+
# !TODO activate min_sigma
|
20 |
super().__init__()
|
21 |
self.min_sigma = min_sigma
|
22 |
self.inference_mode = inference_mode
|
23 |
self.num_steps = num_steps
|
24 |
+
self.reverse_flow = reverse_flow
|
|
|
25 |
|
26 |
assert self.inference_mode in ['euler', 'adaptive']
|
27 |
if self.inference_mode == 'adaptive' and num_steps > 0:
|
|
|
31 |
t: torch.Tensor) -> torch.Tensor:
|
32 |
# which is psi_t(x), eq 22 in flow matching for generative models
|
33 |
t = t[:, None, None].expand_as(x0)
|
34 |
+
if self.reverse_flow:
|
35 |
+
return (1 - t) * x1 + t * x0 # xt = (1-t)*x1 + t*x0 -> vt = x0 - x1
|
36 |
+
else:
|
37 |
+
return (1 - t) * x0 + t * x1 # xt = (1-t)*x0 + t*x1 -> vt = x1 - x0
|
38 |
|
39 |
def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
    """Return the squared velocity error, averaged per leading index.

    The regression target is ``x0 - x1`` when ``self.reverse_flow`` is set
    and ``x1 - x0`` otherwise. The squared difference between
    ``predicted_v`` and that target is averaged over every dimension except
    the first, so the first dimension is left unreduced.
    """
    target_v = (x0 - x1) if self.reverse_flow else (x1 - x0)
    squared_error = (predicted_v - target_v).pow(2)
    # average over every dim but dim 0 (not reducing the batch dimension)
    non_batch_dims = list(range(1, len(predicted_v.shape)))
    return squared_error.mean(dim=non_batch_dims)
|
47 |
|
48 |
def get_x0_xt_c(
|
|
|
58 |
return x0, x1, xt, Cs
|
59 |
|
60 |
def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
    """Integrate ``x1`` along the flow defined by ``fn`` towards the prior.

    Delegates to ``self.run_t0_to_t1``. With ``self.reverse_flow`` set the
    integration runs from t=0 to t=1; otherwise it runs from t=1 to t=0.
    """
    # pick the (start, end) times according to the flow orientation
    t_start, t_end = (0, 1) if self.reverse_flow else (1, 0)
    return self.run_t0_to_t1(fn, x1, t_start, t_end)
|
65 |
|
66 |
def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
    """Integrate ``x0`` along the flow defined by ``fn`` towards the data.

    Delegates to ``self.run_t0_to_t1``. With ``self.reverse_flow`` set the
    integration runs from t=1 to t=0; otherwise it runs from t=0 to t=1.
    """
    # pick the (start, end) times according to the flow orientation
    t_start, t_end = (1, 0) if self.reverse_flow else (0, 1)
    return self.run_t0_to_t1(fn, x0, t_start, t_end)
|
71 |
|
72 |
def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
|
73 |
# fn: a function that takes (t, x) and returns the velocity d(xt)/dt, i.e. x0 - x1 when reverse_flow is set and x1 - x0 otherwise
|
|
|
81 |
flow = fn(t, x)
|
82 |
next_t = steps[ti + 1]
|
83 |
dt = next_t - t
|
84 |
+
x = x + dt * flow
|
|
|
|
|
85 |
|
86 |
return x
|