# LInKAlphabetDemo / LInK / CurveUtils.py
# Author: Open-TO · commit 460c05d ("init") · 859 bytes
import torch
def uniformize(curves: torch.Tensor, n: int = 200) -> torch.Tensor:
    """Resample each curve in a batch to ``n`` points evenly spaced by arc length.

    Each polyline is parameterized by its normalized cumulative arc length in
    [0, 1]. ``n`` equally spaced parameter values are then mapped back to
    positions by linear interpolation between the two neighbouring input points.

    Args:
        curves: Rank-3 tensor of batched polylines indexed as
            ``curves[batch, point, coord]``. Any number of coordinates per
            point is supported.
        n: Number of output points per curve (for ``n >= 1``), including
            both endpoints.

    Returns:
        Tensor of shape ``(batch, n, coord)`` with the dtype and device of
        ``curves``. The first output point is exactly ``curves[:, 0]``.

    Note:
        The computation runs under ``torch.no_grad()``, so the result carries
        no gradient back to ``curves``. A curve with zero total length (all
        points identical) divides by zero and produces NaNs.
    """
    with torch.no_grad():
        batch, _, dim = curves.shape

        # Cumulative arc length at each input point, 0 at the first point.
        seg_len = torch.norm(curves[:, 1:, :] - curves[:, :-1, :], dim=-1)
        l = torch.cumsum(torch.nn.functional.pad(seg_len, [1, 0]), -1)
        # Normalize to [0, 1]; the last entry becomes 1.
        l = l / l[:, -1].unsqueeze(-1)

        # Target parameters. Build them in l's dtype/device so that
        # searchsorted compares like with like (e.g. float64 curves).
        sampling = torch.linspace(0, 1, n, dtype=l.dtype, device=l.device)
        sampling = sampling.unsqueeze(0).tile([batch, 1])

        # For every target after the first, find the index of the first input
        # point whose arc length is >= the target. Because l[:, 0] == 0 < target,
        # the index is >= 1 and l[end - 1] < target <= l[end].
        end_is = torch.searchsorted(l, sampling)[:, 1:]
        # Broadcast the point index over every coordinate (not just 2).
        end_ids = end_is.unsqueeze(-1).tile([1, 1, dim])

        l_end = torch.gather(l, 1, end_is)
        l_start = torch.gather(l, 1, end_is - 1)
        # Fraction of the segment to step back from its end point.
        ws = (l_end - sampling[:, 1:]) / (l_end - l_start)

        end_gather = torch.gather(curves, 1, end_ids)
        start_gather = torch.gather(curves, 1, end_ids - 1)
        interpolated = end_gather - (end_gather - start_gather) * ws.unsqueeze(-1)

        # The first sample (parameter 0) is the first input point itself.
        uniform_curves = torch.cat([curves[:, 0:1, :], interpolated], 1)
    return uniform_curves