# NOTE: removed non-Python scrape residue (file-size banner, git-blame hashes,
# and blame line-number gutter) that made this file syntactically invalid.
import logging
from typing import Callable, Optional
import torch
from torchdiffeq import odeint
log = logging.getLogger()
## partially from https://github.com/gle-bellier/flow-matching
class FlowMatching:
    """Conditional flow-matching schedule with Euler / adaptive ODE sampling.

    Partially from https://github.com/gle-bellier/flow-matching.

    With ``reverse_flow=True`` (the default), t=0 corresponds to the data
    sample ``x1`` and t=1 to the Gaussian prior ``x0``; the target velocity
    is ``x0 - x1``. With ``reverse_flow=False`` the convention is mirrored.
    """

    def __init__(self, min_sigma: float = 0.0,
                 inference_mode='euler',
                 num_steps: int = 25,
                 reverse_flow: bool = True):
        # min_sigma: minimum noise level -- currently unused, see TODO below
        # inference_mode: 'euler' (fixed-step) or 'adaptive' (torchdiffeq odeint)
        # num_steps: number of steps in the euler inference mode
        # reverse_flow: flips which endpoint of the path is data vs. prior
        # !TODO activate min_sigma for flow matching
        super().__init__()
        self.min_sigma = min_sigma
        self.inference_mode = inference_mode
        self.num_steps = num_steps
        self.reverse_flow = reverse_flow
        assert self.inference_mode in ['euler', 'adaptive']
        if self.inference_mode == 'adaptive' and num_steps > 0:
            log.info('The number of steps is ignored in adaptive inference mode ')

    def get_conditional_flow(self, x0: torch.Tensor, x1: torch.Tensor,
                             t: torch.Tensor) -> torch.Tensor:
        """Return the point x_t on the straight path between x0 and x1.

        This is psi_t(x), eq 22 in "Flow Matching for Generative Modeling".
        t has shape (B,); x0/x1 are assumed 3-D (B, C, T) -- the (B, 1, 1)
        broadcast below only supports rank-3 inputs.
        """
        t = t[:, None, None].expand_as(x0)
        if self.reverse_flow:
            return (1 - t) * x1 + t * x0  # xt = (1-t)*x1 + t*x0 -> vt = x0 - x1
        else:
            return (1 - t) * x0 + t * x1  # xt = (1-t)*x0 + t*x1 -> vt = x1 - x0

    def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        """Per-sample MSE between predicted and target velocity.

        Returns the mean squared error reduced over all non-batch
        dimensions, i.e. a tensor of shape (B,).
        """
        reduce_dim = list(range(1, len(predicted_v.shape)))
        if self.reverse_flow:
            target_v = x0 - x1
        else:
            target_v = x1 - x0
        return (predicted_v - target_v).pow(2).mean(dim=reduce_dim)

    def get_x0_xt_c(
        self,
        x1: torch.Tensor,
        t: torch.Tensor,
        Cs: list[torch.Tensor],
        generator: Optional[torch.Generator] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]:
        """Sample a Gaussian prior x0 and the interpolated x_t for training.

        Returns (x0, x1, xt, Cs); Cs (the conditioning tensors) is passed
        through unchanged. `generator` makes the prior sample reproducible.
        NOTE: return annotation fixed -- the last element is the list Cs,
        not a single tensor.
        """
        x0 = torch.empty_like(x1).normal_(generator=generator)
        xt = self.get_conditional_flow(x0, x1, t)
        return x0, x1, xt, Cs

    def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
        """Integrate a data sample x1 back to the prior using velocity fn."""
        if self.reverse_flow:
            return self.run_t0_to_t1(fn, x1, 0, 1)
        else:
            return self.run_t0_to_t1(fn, x1, 1, 0)

    def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
        """Integrate a prior sample x0 to the data distribution using fn."""
        if self.reverse_flow:
            return self.run_t0_to_t1(fn, x0, 1, 0)
        else:
            return self.run_t0_to_t1(fn, x0, 0, 1)

    def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
        """Integrate dx/dt = fn(t, x) from t0 to t1 and return the FINAL state.

        fn: a function that takes (t, x) and returns the direction x0->x1.
        """
        if self.inference_mode == 'adaptive':
            # BUGFIX: odeint returns the solution at every requested time
            # point (shape [len(ts), *x0.shape]); take the last entry so
            # both inference modes consistently return the final state.
            ts = torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype)
            return odeint(fn, x0, ts)[-1]
        elif self.inference_mode == 'euler':
            x = x0
            # Build the time grid on x0's device/dtype so fn(t, x) and the
            # dt * flow update never mix devices (the adaptive branch
            # already does this for its time tensor).
            # NOTE(review): min_sigma is subtracted from t1 regardless of
            # integration direction -- harmless at the default 0.0, but
            # confirm the intended sign before activating min_sigma.
            steps = torch.linspace(t0, t1 - self.min_sigma, self.num_steps + 1,
                                   device=x0.device, dtype=x0.dtype)
            for ti, t in enumerate(steps[:-1]):
                flow = fn(t, x)
                next_t = steps[ti + 1]
                dt = next_t - t
                x = x + dt * flow
            return x