File size: 1,752 Bytes
3a1da90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from typing import Optional

import torch


def log_normal_sample(x: torch.Tensor,
                      generator: Optional[torch.Generator] = None,
                      m: float = 0.0,
                      s: float = 1.0) -> torch.Tensor:
    bs = x.shape[0]
    s = torch.randn(bs, device=x.device, generator=generator) * s + m
    return torch.sigmoid(s)
import torch
from typing import Optional, Tuple

def log_normal_sample_r_t(
    x: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    m: float = 0.0,
    s: float = 1.0,
    epsilon: float = 1.0  # 控制第二个张量的最小增量
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    生成两个张量,确保第二个张量的每个元素都大于第一个张量。
    
    参数:
    x (torch.Tensor): 输入张量(用于确定 batch_size 和设备)
    generator (torch.Generator, optional): 随机数生成器
    m (float): 正态分布的均值(默认为 0)
    s (float): 正态分布的标准差(默认为 1)
    epsilon (float): 控制第二个张量的最小增量(默认为 1)
    
    返回:
    Tuple[torch.Tensor, torch.Tensor]: 两个经过 sigmoid 处理的张量,第二个的每个元素均大于第一个
    """
    bs = x.shape[0]
    device = x.device
    
    # 生成第一个张量的原始值
    s1 = torch.randn(bs, device=device, generator=generator) * s + m
    
    # 生成第二个张量,确保每个元素比第一个大:
    # 使用绝对值正态分布作为增量,保证非负性
    increment = torch.abs(torch.randn(bs, device=device, generator=generator)) * epsilon
    s2 = s1 + increment
    
    # 应用 sigmoid 并返回
    #第二个比第一个大
    return torch.sigmoid(s1), torch.sigmoid(s2)