# Provenance: uploaded by aiface in commit 907b7f3 ("Upload 11 files"); original file size 994 bytes.
import torch
import numpy as np
# -- mixup data augmentation: compute mixup-augmented inputs and target pairs
# Adapted from https://github.com/hongyi-zhang/mixup/blob/master/cifar/utils.py
def mixup_data(x, y, alpha=1.0, soft_labels=None, use_cuda=False):
    """Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    Each sample is blended with a randomly chosen partner from the same batch:
    ``mixed_x = lam * x + (1 - lam) * x[perm]``, where ``perm`` is a random
    permutation of the batch indices.

    Args:
        x: Input tensor whose first dimension is the batch dimension
            (``x.size(0)`` is used as the batch size). Any rank is accepted.
        y: Targets, indexable along the batch dimension with the same
            permutation as ``x``.
        alpha: Concentration of the Beta(alpha, alpha) distribution that
            ``lam`` is drawn from. When ``alpha <= 0`` mixing is disabled:
            ``lam = 1.0`` and ``mixed_x`` equals ``x``.
        soft_labels: Unused; accepted only for backward compatibility.
        use_cuda: Ignored; kept for backward compatibility. The permutation is
            created on ``x.device``, so it always matches the input's device.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: ``y_a`` is ``y`` unchanged, ``y_b`` is
        ``y`` permuted by the same permutation applied to ``x``, and ``lam``
        is a Python float.
    """
    if alpha > 0.:
        # Draw the mixing coefficient from a Beta(alpha, alpha) distribution;
        # convert the numpy scalar to a plain float before mixing with tensors.
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.
    batch_size = x.size(0)
    # Derive the device from x instead of the use_cuda flag: a CUDA index
    # cannot be used to index a CPU tensor.
    index = torch.randperm(batch_size, device=x.device)
    # Index only the batch dimension so inputs of any rank (including 1-D) work.
    mixed_x = lam * x + (1 - lam) * x[index]
    # Targets may live on a different device than x (e.g. CPU labels);
    # non-tensor targets (e.g. numpy arrays) are indexed with a CPU index.
    if torch.is_tensor(y):
        y_b = y[index.to(y.device)]
    else:
        y_b = y[index.cpu()]
    y_a = y
    return mixed_x, y_a, y_b, lam
# Apply mixup to the loss
def mixup_criterion(y_a, y_b, lam):
    """Build a loss function that blends the losses for two target sets.

    Args:
        y_a: Original targets (as returned by ``mixup_data``).
        y_b: Permuted targets (as returned by ``mixup_data``).
        lam: Weight given to the loss against ``y_a``; ``1 - lam`` weights
            the loss against ``y_b``.

    Returns:
        A callable ``(criterion, pred) -> loss`` that evaluates
        ``criterion(pred, y_a)`` then ``criterion(pred, y_b)`` and returns
        ``lam * loss_a + (1 - lam) * loss_b``.
    """
    def _mixed_loss(criterion, pred):
        loss_a = criterion(pred, y_a)
        loss_b = criterion(pred, y_b)
        return lam * loss_a + (1 - lam) * loss_b

    return _mixed_loss