"""Frame-feature reduction strategies for fitting long videos into a
fixed frame budget: greedy drop/merge, k-means variants, and an
attention-based memory."""

import random

import torch
import torch.nn.functional as F


def drop_feature(img_feature, video_max_frames, img_similarity=None):
    """Greedily drop frames until at most `video_max_frames` remain.

    `img_feature` has shape (T, P, D): T frames of P patch features with
    D channels. At each step the most similar adjacent pair is located
    and one of its two frames (picked at random) is dropped. Returns the
    reduced features, the adjacent-pair similarities, and the surviving
    original-frame indices after every step.
    """
    T, P, D = img_feature.shape
    indices = [[i] for i in range(T)]
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [indices]
    # Start from the first T0 frames and their adjacent cosine similarities.
    cur_feature = img_feature[:T0]
    if img_similarity is not None:
        cur_sim = img_similarity[:T0 - 1]
    else:
        cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D),
                                      cur_feature[1:].view(T0 - 1, P * D))
    cur_indices = indices[:T0]
    step_indices = [cur_indices]
    for i in range(T0, T):
        # Append the incoming frame, then drop one frame of the most
        # similar adjacent pair to return to T0 frames.
        new_feature = img_feature[i]
        new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
        all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
        all_indices = cur_indices + [[i]]
        all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
        # all_sim[idx] scores the pair (idx, idx + 1); drop either side
        # with equal probability.
        idx = torch.argmax(all_sim)
        if random.randint(0, 1) > 0:
            idx = idx + 1
        cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
        if idx == T0:
            # Dropped the last frame: the trailing similarity goes with it.
            cur_sim = all_sim[:T0 - 1]
            cur_indices = all_indices[:-1]
        elif idx == 0:
            # Dropped the first frame: the leading similarity goes with it.
            cur_sim = all_sim[1:]
            cur_indices = all_indices[1:]
        else:
            # Dropped an interior frame: splice out its similarity and
            # rescore the newly adjacent pair.
            cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
            cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
            cur_indices = all_indices[:idx] + all_indices[idx + 1:]
        step_indices.append(cur_indices)
    return cur_feature, cur_sim, step_indices
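

# A minimal, hedged usage sketch (not from the original code): random
# features stand in for real patch embeddings to show the shape contract.
def _demo_drop_feature():
    feats = torch.randn(24, 16, 64)               # T=24 frames, P=16 patches, D=64
    reduced, sim, steps = drop_feature(feats, 8)  # keep at most T0=8 frames
    assert reduced.shape == (8, 16, 64)           # reduced to the frame budget
    assert len(steps[-1]) == 8                    # index lists of the survivors
    return reduced, sim, steps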


def merge_feature(img_feature, video_max_frames, img_similarity=None):
    """Greedily merge frames until at most `video_max_frames` remain.

    Like `drop_feature`, but the most similar adjacent pair is averaged
    into one frame instead of dropping a side, and the merged frame
    keeps the union of the original frame indices.
    """
    T, P, D = img_feature.shape
    indices = [[i] for i in range(T)]
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [indices]
    cur_feature = img_feature[:T0]
    cur_indices = indices[:T0]
    step_indices = [cur_indices]
    if img_similarity is not None:
        cur_sim = img_similarity[:T0 - 1]
    else:
        cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D),
                                      cur_feature[1:].view(T0 - 1, P * D))
    for i in range(T0, T):
        new_feature = img_feature[i]
        new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
        all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
        all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
        all_indices = cur_indices + [[i]]
        # all_sim[idx] scores the pair (idx, idx + 1); average the pair
        # into position idx + 1, then remove position idx.
        idx = torch.argmax(all_sim)
        all_feature[idx + 1] = (all_feature[idx] + all_feature[idx + 1]) / 2.0
        all_indices[idx + 1] = all_indices[idx] + all_indices[idx + 1]
        cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
        cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
        cur_indices = all_indices[:idx] + all_indices[idx + 1:]
        # Rescore the similarities on both sides of the merged frame.
        if idx > 0:
            cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
        if idx + 1 < T0:
            cur_sim[idx] = F.cosine_similarity(all_feature[idx + 1].view(-1), all_feature[idx + 2].view(-1), dim=0)
        step_indices.append(cur_indices)
    return cur_feature, cur_sim, step_indices


def kmeans_feature(img_feature, video_max_frames, img_similarity=None):
    """Reduce to `video_max_frames` frames with k-means over flat frames.

    Each output "frame" is a cluster centroid, and `step_indices` holds
    a single partition mapping every original frame to its cluster.
    """
    def kmeans_torch(X, num_clusters, distance='euclidean', tol=1e-4, max_iter=10):
        # Initialize centroids from a random subset of the points.
        indices = torch.randperm(X.size(0))[:num_clusters]
        centroids = X[indices]
        for i in range(max_iter):
            if distance == 'euclidean':
                dists = torch.cdist(X, centroids, p=2)
            else:
                raise NotImplementedError("Only Euclidean distance is supported yet")
            labels = torch.argmin(dists, dim=1)
            new_centroids = []
            for j in range(num_clusters):
                cluster_points = X[labels == j]
                if len(cluster_points) > 0:
                    new_centroid = cluster_points.mean(0)
                else:
                    # Re-seed an empty cluster with a random point.
                    new_centroid = X[random.randint(0, X.size(0) - 1)]
                new_centroids.append(new_centroid)
            new_centroids = torch.stack(new_centroids)
            diff = torch.norm(centroids - new_centroids, dim=1).sum()
            if diff < tol:
                # Converged: keep the centroids that produced `labels`.
                break
            centroids = new_centroids
        return centroids, labels, i

    T, P, D = img_feature.shape
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [[[i] for i in range(T)]]
    X = img_feature.view(T, -1)
    centroids, labels, exit_step = kmeans_torch(X, T0)
    reduced_feature = centroids.view(T0, P, D)
    # Group the original frame indices by cluster assignment.
    step_indices = [[] for _ in range(T0)]
    for i in range(T0):
        step_indices[i] = [j for j in range(T) if labels[j] == i]
    return reduced_feature, img_similarity, [step_indices]


def weighted_kmeans_feature(img_feature, video_max_frames, weights=None):
    """Weighted k-means reduction to `video_max_frames` frames.

    `weights` holds one weight per frame (defaulting to ones). The
    returned weights are the per-cluster weight sums, so the output can
    be fed straight back in for a further round of reduction.
    """
    if weights is None:
        weights = torch.ones(img_feature.size(0), dtype=img_feature.dtype, device=img_feature.device)

    def weighted_kmeans_torch(X, num_clusters, weights=None, distance='euclidean', tol=1e-4, max_iter=10):
        indices = torch.randperm(X.size(0), device=X.device)[:num_clusters]
        centroids = X[indices]
        for i in range(max_iter):
            if distance == 'euclidean':
                dists = torch.cdist(X, centroids, p=2)
            else:
                raise NotImplementedError("Only Euclidean distance is supported yet")
            labels = torch.argmin(dists, dim=1)
            # Weighted mean of each cluster's points.
            weighted_sum = torch.zeros_like(centroids)
            weights_sum = torch.zeros(num_clusters, dtype=X.dtype, device=X.device)
            for j in range(num_clusters):
                cluster_mask = labels == j
                weighted_sum[j] = torch.sum(weights[cluster_mask, None] * X[cluster_mask], dim=0)
                weights_sum[j] = torch.sum(weights[cluster_mask])
            mask = weights_sum > 0
            new_centroids = torch.zeros_like(weighted_sum)
            new_centroids[mask] = weighted_sum[mask] / weights_sum[mask, None]
            if mask.sum() < num_clusters:
                # Re-seed empty clusters with random points.
                new_centroids[~mask] = torch.stack([X[random.randint(0, X.size(0) - 1)] for _ in range(num_clusters - mask.sum())])
            diff = torch.norm(centroids - new_centroids, dim=1).sum()
            if diff < tol:
                # Converged: keep the centroids that produced `labels`.
                break
            centroids = new_centroids
        return centroids, labels, weights_sum, i

    T, P, D = img_feature.shape
    T0 = video_max_frames
    if T <= T0:
        return img_feature, weights, [[[i] for i in range(T)]]
    X = img_feature.view(T, -1)
    centroids, labels, weights, exit_step = weighted_kmeans_torch(X, T0, weights)
    reduced_feature = centroids.view(T0, P, D)
    # Group the original frame indices by cluster assignment.
    step_indices = [[] for _ in range(T0)]
    for i in range(T0):
        step_indices[i] = [j for j in range(T) if labels[j] == i]
    return reduced_feature, weights, [step_indices]
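

# A hedged sketch (not from the original code): because the returned
# weights are per-cluster sizes, the reduction composes, so it can be
# applied hierarchically; the shapes below are illustrative assumptions.
def _demo_weighted_kmeans_feature():
    feats = torch.randn(100, 16, 64)
    f1, w1, _ = weighted_kmeans_feature(feats, 32)  # 100 -> 32 centroids; w1 holds cluster sizes
    f2, w2, _ = weighted_kmeans_feature(f1, 8, w1)  # 32 -> 8, weighted by cluster size
    assert f2.shape == (8, 16, 64)
    assert float(w2.sum()) == float(feats.size(0))  # total weight is preserved
    return f2, w2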


def k_drop_feature(img_feature, video_max_frames, img_similarity=None):
    """Drop frames using the full pairwise similarity matrix.

    Unlike `drop_feature`, which only compares temporally adjacent
    frames, this maintains a (T0, T0) cosine-similarity matrix (with the
    diagonal masked to -100) and, at each step, drops one frame of the
    globally most similar pair.
    """
    T, P, D = img_feature.shape
    indices = [[i] for i in range(T)]
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [indices]
    cur_feature = img_feature[:T0]
    normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
    cur_sim = torch.mm(normed_cur_features, normed_cur_features.T)
    cur_sim.fill_diagonal_(-100.0)
    cur_indices = indices[:T0]
    step_indices = [cur_indices]
    for i in range(T0, T):
        # Extend the similarity matrix with the incoming frame.
        new_feature = img_feature[i]
        normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
        new_sim = torch.mm(normed_cur_features, normed_new_feature.T)
        all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
        normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
        all_indices = cur_indices + [[i]]
        all_sim_1 = torch.cat([cur_sim, new_sim], dim=1)
        all_sim = torch.cat([all_sim_1, torch.full_like(all_sim_1[-1:], -100.0)], dim=0)
        all_sim[-1, :-1] = new_sim.T
        # Unravel the flat argmax into the (left, right) pair; the matrix
        # is symmetric, so the row-major argmax gives left < right.
        idx = torch.argmax(all_sim)
        left, right = idx // (T0 + 1), idx % (T0 + 1)
        assert all_sim[left, right] == torch.max(all_sim)
        # Drop either side of the pair with equal probability.
        if random.randint(0, 1) > 0:
            idx = left
        else:
            idx = right
        # Remove the dropped frame's features plus its row and column.
        cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
        normed_cur_features = torch.cat([normed_all_features[:idx], normed_all_features[idx + 1:]])
        cur_indices = all_indices[:idx] + all_indices[idx + 1:]
        cur_sim_1 = torch.cat([all_sim[:idx], all_sim[idx + 1:]], dim=0)
        cur_sim = torch.cat([cur_sim_1[:, :idx], cur_sim_1[:, idx + 1:]], dim=1)
        step_indices.append(cur_indices)
    return cur_feature, None, step_indices


def k_merge_feature(img_feature, video_max_frames, img_similarity=None):
    """Merge frames using the full pairwise similarity matrix.

    The globally most similar pair (left, right) is averaged into
    position `right`, the merged frame's similarity row and column are
    recomputed, and position `left` is removed.
    """
    T, P, D = img_feature.shape
    indices = [[i] for i in range(T)]
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [indices]
    cur_feature = img_feature[:T0]
    normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
    cur_sim = torch.mm(normed_cur_features, normed_cur_features.T)
    cur_sim.fill_diagonal_(-100.0)
    cur_indices = indices[:T0]
    step_indices = [cur_indices]
    for i in range(T0, T):
        # Extend the similarity matrix with the incoming frame.
        new_feature = img_feature[i]
        normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
        new_sim = torch.mm(normed_cur_features, normed_new_feature.T)
        all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
        normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
        all_indices = cur_indices + [[i]]
        all_sim_1 = torch.cat([cur_sim, new_sim], dim=1)
        all_sim = torch.cat([all_sim_1, torch.full_like(all_sim_1[-1:], -100.0)], dim=0)
        all_sim[-1, :-1] = new_sim.T
        # Unravel the flat argmax into the (left, right) pair to merge.
        idx = torch.argmax(all_sim)
        left, right = idx // (T0 + 1), idx % (T0 + 1)
        assert all_sim[left, right] == torch.max(all_sim)
        all_feature[right] = (all_feature[left] + all_feature[right]) / 2.0
        normed_all_features[right] = F.normalize(all_feature[right].view(1, P * D), p=2, dim=1)
        all_indices[right] = all_indices[left] + all_indices[right]
        # Recompute the merged frame's similarities, keeping the diagonal masked.
        new_sim = torch.mm(normed_all_features, normed_all_features[right:right + 1].T)
        all_sim[right, :] = new_sim.T
        all_sim[:, right:right + 1] = new_sim
        all_sim[right, right] = -100.0
        # Remove the absorbed frame's features plus its row and column.
        cur_feature = torch.cat([all_feature[:left], all_feature[left + 1:]])
        normed_cur_features = torch.cat([normed_all_features[:left], normed_all_features[left + 1:]])
        cur_indices = all_indices[:left] + all_indices[left + 1:]
        cur_sim_1 = torch.cat([all_sim[:left], all_sim[left + 1:]], dim=0)
        cur_sim = torch.cat([cur_sim_1[:, :left], cur_sim_1[:, left + 1:]], dim=1)
        step_indices.append(cur_indices)
    return cur_feature, cur_sim, step_indices
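

# A hedged sketch (not from the original code): after k_merge_feature the
# final step's index lists should partition all processed frames, since
# merging unions index lists instead of discarding them.
def _demo_k_merge_feature():
    feats = torch.randn(24, 16, 64)
    reduced, sim, steps = k_merge_feature(feats, 8)
    assert reduced.shape == (8, 16, 64)
    merged = sorted(i for group in steps[-1] for i in group)
    assert merged == list(range(24))  # every original frame is accounted for
    return reduced, sim, steps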


def attention_feature(img_feature, video_max_frames, attention_fn=None, update_ratio=0.2):
    """Compress frames into a fixed-size memory via an attention update.

    The first T0 frames initialize a (T0 * P, D) memory; each subsequent
    chunk of up to T0 frames is folded in by `attention_fn`, which must
    map (memory, new_tokens, update_ratio) to a memory of the same shape.
    """
    T, P, D = img_feature.shape
    T0 = video_max_frames
    if T <= T0:
        return img_feature, None
    cur_feature = img_feature[:T0]
    turing_memory = cur_feature.reshape(T0 * P, D)
    for i in range(T0, T, T0):
        # Fold the next chunk of frames into the memory.
        j = min(i + T0, T)
        new_feature = img_feature[i:j].reshape(-1, D)
        turing_memory = attention_fn(turing_memory, new_feature, update_ratio=update_ratio)
    cur_feature = turing_memory.reshape(T0, P, D)
    return cur_feature, None
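

# `attention_fn` is caller-supplied; the stub below is only a hedged sketch
# of the expected contract (memory (M, D) plus new tokens (N, D) in, updated
# (M, D) memory out), not the attention update from the original project.
def _toy_attention_fn(memory, new_tokens, update_ratio=0.2):
    # Cross-attend memory -> new tokens, then blend into the old memory.
    attn = torch.softmax(memory @ new_tokens.T / memory.size(-1) ** 0.5, dim=-1)
    read = attn @ new_tokens  # (M, D) summary read from the new tokens
    return (1.0 - update_ratio) * memory + update_ratio * read


if __name__ == "__main__":
    # Smoke test on random features (illustrative shapes): every reducer
    # should land on the T0=8 frame budget.
    feats = torch.randn(24, 16, 64)  # T=24 frames, P=16 patches, D=64 dims
    for fn in (drop_feature, merge_feature, kmeans_feature,
               weighted_kmeans_feature, k_drop_feature, k_merge_feature):
        reduced, _, steps = fn(feats, 8)
        print(fn.__name__, tuple(reduced.shape), len(steps))
    reduced, _ = attention_feature(feats, 8, attention_fn=_toy_attention_fn)
    print("attention_feature", tuple(reduced.shape))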