File size: 3,292 Bytes
3a1da90
 
 
 
 
 
 
 
 
 
 
 
5306fb5
 
 
 
3a1da90
 
ef423d2
3a1da90
 
 
 
5306fb5
3a1da90
 
 
 
 
 
 
 
 
5306fb5
 
 
 
3a1da90
 
 
 
5306fb5
 
 
 
3a1da90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5306fb5
 
 
 
3a1da90
 
5306fb5
 
 
 
3a1da90
 
 
 
 
 
 
 
 
 
 
 
 
5306fb5
3a1da90
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import logging
from typing import Callable, Optional

import torch
from torchdiffeq import odeint

# Module-scoped logger so records can be filtered/configured per module
# instead of being emitted directly on the root logger.
log = logging.getLogger(__name__)


## partially from https://github.com/gle-bellier/flow-matching
class FlowMatching:
    """Linear flow-matching path, training loss and ODE samplers.

    The conditional path interpolates linearly between a noise sample ``x0``
    (drawn from a standard normal in ``get_x0_xt_c``) and a data sample ``x1``.
    With ``reverse_flow=True`` the data sits at ``t = 0`` and the noise at
    ``t = 1``; with ``reverse_flow=False`` noise is at ``t = 0`` and data at
    ``t = 1``.
    """

    def __init__(self, min_sigma: float = 0.0,
                 inference_mode: str = 'euler',
                 num_steps: int = 25,
                 reverse_flow: bool = True):
        """Configure the flow.

        Args:
            min_sigma: Euler integration stops this far short of the target
                time. It is not (yet) used by the conditional path or loss.
            inference_mode: ``'euler'`` (fixed-step) or ``'adaptive'``
                (``torchdiffeq.odeint``).
            num_steps: number of Euler steps; ignored in adaptive mode.
            reverse_flow: if True, data is at t=0 and noise at t=1.
        """
        # !TODO activate min_sigma for flow matching
        super().__init__()
        self.min_sigma = min_sigma
        self.inference_mode = inference_mode
        self.num_steps = num_steps
        self.reverse_flow = reverse_flow

        assert self.inference_mode in ['euler', 'adaptive'], \
            f"unknown inference_mode: {self.inference_mode!r}"
        if self.inference_mode == 'adaptive' and num_steps > 0:
            log.info('The number of steps is ignored in adaptive inference mode ')

    def get_conditional_flow(self, x0: torch.Tensor, x1: torch.Tensor,
                             t: torch.Tensor) -> torch.Tensor:
        """Return x_t on the straight path between x0 and x1.

        This is psi_t(x), eq. 22 in "Flow Matching for Generative Models".

        Args:
            x0: noise sample with a leading batch dimension B.
            x1: data sample, same shape as ``x0``.
            t: per-sample times, shape (B,) (or a single time).

        Returns:
            Tensor with the same shape as ``x0``.
        """
        # Reshape t to (B, 1, ..., 1) so it broadcasts over every non-batch
        # dim of x0, whatever its rank (previously only 3-D inputs worked).
        t = t.reshape(-1, *([1] * (x0.dim() - 1)))
        if self.reverse_flow:
            return (1 - t) * x1 + t * x0  # xt = (1-t)*x1 + t*x0 -> vt = x0 - x1
        return (1 - t) * x0 + t * x1  # xt = (1-t)*x0 + t*x1 -> vt = x1 - x0

    def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor:
        """Per-sample mean squared error between predicted and target velocity.

        The batch (first) dimension is kept; all other dimensions are averaged.
        """
        target_v = x0 - x1 if self.reverse_flow else x1 - x0
        sq_err = (predicted_v - target_v).pow(2)
        reduce_dim = list(range(1, predicted_v.dim()))
        if not reduce_dim:
            # Batch-only input: nothing to average per sample. mean(dim=[])
            # would reduce over *all* dims and lose the batch dimension.
            return sq_err
        return sq_err.mean(dim=reduce_dim)

    def get_x0_xt_c(
        self,
        x1: torch.Tensor,
        t: torch.Tensor,
        Cs: list[torch.Tensor],
        generator: Optional[torch.Generator] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]:
        """Sample noise x0 and build x_t for training.

        Returns:
            ``(x0, x1, xt, Cs)`` where ``x0`` is standard-normal noise shaped
            like ``x1``, ``xt`` lies on the conditional path at time ``t``, and
            ``x1`` / ``Cs`` are passed through unchanged.
        """
        x0 = torch.empty_like(x1).normal_(generator=generator)
        xt = self.get_conditional_flow(x0, x1, t)
        return x0, x1, xt, Cs

    def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor:
        """Integrate from the data end of the path to the noise end."""
        if self.reverse_flow:
            return self.run_t0_to_t1(fn, x1, 0, 1)
        return self.run_t0_to_t1(fn, x1, 1, 0)

    def to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor:
        """Integrate from the noise end of the path to the data end."""
        if self.reverse_flow:
            return self.run_t0_to_t1(fn, x0, 1, 0)
        return self.run_t0_to_t1(fn, x0, 0, 1)

    def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor:
        """Integrate dx/dt = fn(t, x) from time t0 to t1, starting at x0.

        Args:
            fn: callable taking ``(t, x)`` and returning the velocity at x.
            x0: initial state at time ``t0``.
            t0: start time.
            t1: end time (may be smaller than ``t0``).

        Returns:
            The state at the end time, with the same shape as ``x0``.

        Raises:
            ValueError: if ``self.inference_mode`` is not recognised.
        """
        if self.inference_mode == 'adaptive':
            times = torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype)
            # odeint returns the solution at every requested time, stacked
            # along a new leading dim; keep only the state at t1 so the result
            # matches the Euler branch and the shape of x0.
            return odeint(fn, x0, times)[-1]
        if self.inference_mode == 'euler':
            # Stop min_sigma short of t1 *in the direction of integration*;
            # always subtracting would overshoot past t1 when integrating
            # backwards (e.g. 1 -> 0).
            direction = 1.0 if t1 >= t0 else -1.0
            steps = torch.linspace(t0, t1 - direction * self.min_sigma, self.num_steps + 1)
            x = x0
            for t, next_t in zip(steps[:-1], steps[1:]):
                x = x + (next_t - t) * fn(t, x)
            return x
        raise ValueError(f"unknown inference_mode: {self.inference_mode!r}")