AndreasXi commited on
Commit
5306fb5
·
1 Parent(s): ef27943

fix flow direction

Browse files
Files changed (1) hide show
  1. meanaudio/model/flow_matching.py +23 -10
meanaudio/model/flow_matching.py CHANGED
@@ -10,15 +10,18 @@ log = logging.getLogger()
10
  ## partially from https://github.com/gle-bellier/flow-matching
11
  class FlowMatching:
12
 
13
- def __init__(self, min_sigma: float = 0.0, inference_mode='euler', num_steps: int = 25):
 
 
 
14
  # inference_mode: 'euler' or 'adaptive'
15
  # num_steps: number of steps in the euler inference mode
 
16
  super().__init__()
17
  self.min_sigma = min_sigma
18
  self.inference_mode = inference_mode
19
  self.num_steps = num_steps
20
-
21
- # self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=min_sigma)
22
 
23
  assert self.inference_mode in ['euler', 'adaptive']
24
  if self.inference_mode == 'adaptive' and num_steps > 0:
@@ -28,12 +31,18 @@ class FlowMatching:
28
  t: torch.Tensor) -> torch.Tensor:
29
  # which is psi_t(x), eq 22 in flow matching for generative models
30
  t = t[:, None, None].expand_as(x0)
31
- return (1 - (1 - self.min_sigma) * t) * x0 + t * x1 # (1-(1-sigma)*t)*x0 + t*x1
 
 
 
32
 
33
  def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
34
  # return the mean error without reducing the batch dimension
35
  reduce_dim = list(range(1, len(predicted_v.shape)))
36
- target_v = x1 - (1 - self.min_sigma) * x0
 
 
 
37
  return (predicted_v - target_v).pow(2).mean(dim=reduce_dim)
38
 
39
  def get_x0_xt_c(
@@ -49,10 +58,16 @@ class FlowMatching:
49
  return x0, x1, xt, Cs
50
 
51
  def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
52
- return self.run_t0_to_t1(fn, x1, 1, 0)
 
 
 
53
 
54
  def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
55
- return self.run_t0_to_t1(fn, x0, 0, 1)
 
 
 
56
 
57
  def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
58
  # fn: a function that takes (t, x) and returns the direction x0->x1
@@ -66,8 +81,6 @@ class FlowMatching:
66
  flow = fn(t, x)
67
  next_t = steps[ti + 1]
68
  dt = next_t - t
69
-
70
- # fix: we need to subtract the flow since we learn the reverse trajectory
71
- x = x - dt * flow
72
 
73
  return x
 
10
  ## partially from https://github.com/gle-bellier/flow-matching
11
  class FlowMatching:
12
 
13
+ def __init__(self, min_sigma: float = 0.0,
14
+ inference_mode='euler',
15
+ num_steps: int = 25,
16
+ reverse_flow: bool = True):
17
  # inference_mode: 'euler' or 'adaptive'
18
  # num_steps: number of steps in the euler inference mode
19
+ # !TODO activate min_sigma
20
  super().__init__()
21
  self.min_sigma = min_sigma
22
  self.inference_mode = inference_mode
23
  self.num_steps = num_steps
24
+ self.reverse_flow = reverse_flow
 
25
 
26
  assert self.inference_mode in ['euler', 'adaptive']
27
  if self.inference_mode == 'adaptive' and num_steps > 0:
 
31
  t: torch.Tensor) -> torch.Tensor:
32
  # which is psi_t(x), eq 22 in flow matching for generative models
33
  t = t[:, None, None].expand_as(x0)
34
+ if self.reverse_flow:
35
+ return (1 - t) * x1 + t * x0 # xt = (1-t)*x1 + t*x0 -> vt = x0 - x1
36
+ else:
37
+ return (1 - t) * x0 + t * x1 # xt = (1-t)*x0 + t*x1 -> vt = x1 - x0
38
 
39
  def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
40
  # return the mean error without reducing the batch dimension
41
  reduce_dim = list(range(1, len(predicted_v.shape)))
42
+ if self.reverse_flow:
43
+ target_v = x0 - x1
44
+ else:
45
+ target_v = x1 - x0
46
  return (predicted_v - target_v).pow(2).mean(dim=reduce_dim)
47
 
48
  def get_x0_xt_c(
 
58
  return x0, x1, xt, Cs
59
 
60
  def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
61
+ if self.reverse_flow:
62
+ return self.run_t0_to_t1(fn, x1, 0, 1)
63
+ else:
64
+ return self.run_t0_to_t1(fn, x1, 1, 0)
65
 
66
  def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
67
+ if self.reverse_flow:
68
+ return self.run_t0_to_t1(fn, x0, 1, 0)
69
+ else:
70
+ return self.run_t0_to_t1(fn, x0, 0, 1)
71
 
72
  def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
73
  # fn: a function that takes (t, x) and returns the direction x0->x1
 
81
  flow = fn(t, x)
82
  next_t = steps[ti + 1]
83
  dt = next_t - t
84
+ x = x + dt * flow
 
 
85
 
86
  return x