anonymous-upload-neurips-2025 committed
Commit cb3baf9 · verified · 1 Parent(s): d0a9518

Upload zero_shot_classification.py

Files changed (1)
  1. zero_shot_classification.py +339 -0
zero_shot_classification.py ADDED
@@ -0,0 +1,339 @@
+ import os
+ import clip
+ import torch
+ import open_clip
+
+ import numpy as np
+ from torchvision.datasets import CIFAR100
+ from tqdm import tqdm
+ import torchvision.transforms as transforms
+ import warnings
+ with warnings.catch_warnings():
+     warnings.simplefilter(action='ignore', category=FutureWarning)
+     with warnings.catch_warnings():
+         warnings.simplefilter(action='ignore', category=UserWarning)
+         import torchvision
+
+ import pandas as pd
+ from pathlib import Path
+ from PIL import Image
+ from torch.utils.data import Dataset, DataLoader
+ import pickle
+
+
+ class FACET(Dataset):
+     """FACET dataset of in-painted images, indexed by person id."""
+
+     def __init__(self, paths, labels, root_dir, file_extension=".jpg", transform=None):
+         """
+         Arguments:
+             paths (sequence): Image identifiers (file names without the extension).
+             labels (sequence): Class label for each image.
+             root_dir (string): Directory with all the images.
+             file_extension (string): Extension appended to each identifier.
+             transform (callable, optional): Optional transform to be applied
+                 on a sample.
+         """
+         self.fpaths = paths
+         self.extension = file_extension
+         self.labels = labels
+         self.root_dir = root_dir
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.fpaths)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+
+         img_name = os.path.join(self.root_dir,
+                                 str(self.fpaths[idx]) + self.extension)
+         image = self.transform(Image.open(img_name).convert('RGB'))
+         label = self.labels[idx]
+
+         return image, label
+
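+ # Hypothetical usage sketch (paths, root_dir and transform are illustrative placeholders):
+ #   ds = FACET(paths=[1234], labels=torch.tensor([0]), root_dir="/data/facet",
+ #              file_extension="_original.png", transform=transforms.ToTensor())
+ #   image, label = ds[0]  # loads /data/facet/1234_original.png
+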
+ imagenet_templates = [
+     'a bad photo of a {}.',
+     'a photo of many {}.',
+     'a sculpture of a {}.',
+     'a photo of the hard to see {}.',
+     'a low resolution photo of the {}.',
+     'a rendering of a {}.',
+     'graffiti of a {}.',
+     'a bad photo of the {}.',
+     'a cropped photo of the {}.',
+     'a tattoo of a {}.',
+     'the embroidered {}.',
+     'a photo of a hard to see {}.',
+     'a bright photo of a {}.',
+     'a photo of a clean {}.',
+     'a photo of a dirty {}.',
+     'a dark photo of the {}.',
+     'a drawing of a {}.',
+     'a photo of my {}.',
+     'the plastic {}.',
+     'a photo of the cool {}.',
+     'a close-up photo of a {}.',
+     'a black and white photo of the {}.',
+     'a painting of the {}.',
+     'a painting of a {}.',
+     'a pixelated photo of the {}.',
+     'a sculpture of the {}.',
+     'a bright photo of the {}.',
+     'a cropped photo of a {}.',
+     'a plastic {}.',
+     'a photo of the dirty {}.',
+     'a jpeg corrupted photo of a {}.',
+     'a blurry photo of the {}.',
+     'a photo of the {}.',
+     'a good photo of the {}.',
+     'a rendering of the {}.',
+     'a {} in a video game.',
+     'a photo of one {}.',
+     'a doodle of a {}.',
+     'a close-up photo of the {}.',
+     'a photo of a {}.',
+     'the origami {}.',
+     'the {} in a video game.',
+     'a sketch of a {}.',
+     'a doodle of the {}.',
+     'a origami {}.',
+     'a low resolution photo of a {}.',
+     'the toy {}.',
+     'a rendition of the {}.',
+     'a photo of the clean {}.',
+     'a photo of a large {}.',
+     'a rendition of a {}.',
+     'a photo of a nice {}.',
+     'a photo of a weird {}.',
+     'a blurry photo of a {}.',
+     'a cartoon {}.',
+     'art of a {}.',
+     'a sketch of the {}.',
+     'a embroidered {}.',
+     'a pixelated photo of a {}.',
+     'itap of the {}.',
+     'a jpeg corrupted photo of the {}.',
+     'a good photo of a {}.',
+     'a plushie {}.',
+     'a photo of the nice {}.',
+     'a photo of the small {}.',
+     'a photo of the weird {}.',
+     'the cartoon {}.',
+     'art of the {}.',
+     'a drawing of the {}.',
+     'a photo of the large {}.',
+     'a black and white photo of a {}.',
+     'the plushie {}.',
+     'a dark photo of a {}.',
+     'itap of a {}.',
+     'graffiti of the {}.',
+     'a toy {}.',
+     'itap of my {}.',
+     'a photo of a cool {}.',
+     'a photo of a small {}.',
+     'a tattoo of the {}.',
+ ]
+
+ models = (
+     # CLIP OpenAI
+     "ViT-B/16",
+     "ViT-B/32",
+     "ViT-L/14",
+     "RN50",
+     "RN101",
+
+     # CLIP OpenCLIP
+     "vit_b_16_400m",
+     "vit_b_16_2b",
+     "vit_l_14_400m",
+     "vit_l_14_2b",
+     "vit_b_32_400m",
+     "vit_b_32_2b",
+ )
+ weights = (
+     # CLIP OpenAI
+     "OpenAI hub",
+     "OpenAI hub",
+     "OpenAI hub",
+     "OpenAI hub",
+     "OpenAI hub",
+     # CLIP OpenCLIP
+     "OpenCLIP hub",
+     "OpenCLIP hub",
+     "OpenCLIP hub",
+     "OpenCLIP hub",
+     "OpenCLIP hub",
+     "OpenCLIP hub",
+ )
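+ # `models` and `weights` are paired positionally via zip() in the main loop below; `weight`
+ # only records which hub a checkpoint comes from and is not otherwise used.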
+
+
+ facet_annotations_file_path = "INSERT_HERE/annotations.csv"
+ facet_root = "?????"  # where the in-painted images are stored; the following structure is expected:
+ # facet_root/
+ #     facet_paper_skin_ours/
+ #     facet_paper_clothes_only/
+ #     facet_paper_skin_ours_occupation_prompt/
+ #     facet_paper_clothes_only_occupation_prompt/
+ #     facet_paper_whole_body/
+ #     facet_paper_whole_body_occupation_prompt/
+
+ facet = pd.read_csv(facet_annotations_file_path, header=0).rename(columns={'Unnamed: 0': 'sample_idx'})
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ experiments = ["facet_paper_skin_ours", "facet_paper_clothes_only", "facet_paper_skin_ours_occupation_prompt", "facet_paper_clothes_only_occupation_prompt",
+                "facet_paper_whole_body", "facet_paper_whole_body_occupation_prompt"
+                ]
+
+
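+ # Optional sanity check (illustrative sketch; assumes the directory layout described above):
+ # fail early if an expected experiment sub-directory is missing under facet_root.
+ for expected_dir in experiments:
+     assert os.path.isdir(os.path.join(facet_root, expected_dir)), f"missing directory: {expected_dir}"
+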
+ for experiment in experiments:
+     for model_name, weight in zip(models, weights):
+         print("\n\n", model_name, experiment)
+         preprocess = None
+         if model_name == "vit_b_16_400m":
+             model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion400m_e32')
+             clip_src = "OpenCLIP"
+         elif model_name == "vit_b_16_2b":
+             model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained='laion2b_s34b_b88k')
+             clip_src = "OpenCLIP"
+         elif model_name == "vit_b_32_400m":
+             model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion400m_e32')
+             clip_src = "OpenCLIP"
+         elif model_name == "vit_b_32_2b":
+             model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
+             clip_src = "OpenCLIP"
+         elif model_name == "vit_l_14_400m":
+             model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion400m_e32')
+             clip_src = "OpenCLIP"
+         elif model_name == "vit_l_14_2b":
+             model, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')
+             clip_src = "OpenCLIP"
+         elif "ViT" in model_name:
+             model, preprocess = clip.load(model_name, device)
+             clip_src = "OpenAI"
+         elif "RN" in model_name:
+             model, preprocess = clip.load(model_name, device)
+             clip_src = "OpenAI"
+         else:
+             raise NotImplementedError
+
+         model.cuda()
+
+         occupations = ['backpacker', 'ballplayer', 'bartender', 'basketball_player', 'boatman', 'carpenter', 'cheerleader', 'climber', 'computer_user', 'craftsman', 'dancer', 'disk_jockey', 'doctor', 'drummer', 'electrician', 'farmer', 'fireman', 'flutist', 'gardener', 'guard', 'guitarist', 'gymnast', 'hairdresser', 'horseman', 'judge', 'laborer', 'lawman', 'lifeguard', 'machinist', 'motorcyclist', 'nurse', 'painter', 'patient', 'prayer', 'referee', 'repairman', 'reporter', 'retailer', 'runner', 'sculptor', 'seller', 'singer', 'skateboarder', 'soccer_player', 'soldier', 'speaker', 'student', 'teacher', 'tennis_player', 'trumpeter', 'waiter']
+
+         tokens_occupations = clip.tokenize([f"A photo of a {occupation}" for occupation in occupations]).cuda()
+
+         facet_img_root = facet_root + "/" + experiment + "/"
+         out_dir = experiment + "_zero_shot"
+         if not os.path.exists(out_dir):
+             os.makedirs(out_dir)
+
+
+         fnames = list(os.listdir(facet_img_root))
+
+         for attribute_value in ["only_original_male", "only_original_female", "original", "male_to_female", "male_to_male", "female_to_female", "female_to_male"]:
+             print(f"----{attribute_value}----")
+             facet = pd.read_csv(facet_annotations_file_path, header=0).rename(columns={'Unnamed: 0': 'sample_idx'})  # Bounding boxes
+             extension = ".png"
+
+             processed_synthetic_samples = set()
+
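+             # Image file names are expected to follow "<person_id>_<target_attribute>.png",
+             # e.g. "1234_male_to_female.png"; the loop below collects the person ids available
+             # for the current attribute_value.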
+             for fname in fnames:
+                 bbid, target_attr = fname.split("_")[0], "_".join(fname.split("_")[1:]).split(".")[0]
+
+                 if "only" in attribute_value:
+                     if target_attr == "original" and bbid not in processed_synthetic_samples:
+                         processed_synthetic_samples.add(int(bbid))
+                 elif target_attr == attribute_value and bbid not in processed_synthetic_samples:
+                     processed_synthetic_samples.add(int(bbid))
+
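+             # Drop samples with "na" or non-binary gender-presentation annotations; the
+             # "only_original_*" settings additionally keep a single gender-presentation group.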
+             if attribute_value == "only_original_male":
+                 facet = facet[facet.person_id.isin(processed_synthetic_samples)]
+                 facet = facet[facet.gender_presentation_na != 1]
+                 facet = facet[facet.gender_presentation_non_binary != 1]
+                 facet = facet[(facet.gender_presentation_masc == 1)]
+             elif attribute_value == "only_original_female":
+                 facet = facet[facet.person_id.isin(processed_synthetic_samples)]
+                 facet = facet[facet.gender_presentation_na != 1]
+                 facet = facet[facet.gender_presentation_non_binary != 1]
+                 facet = facet[(facet.gender_presentation_fem == 1)]
+             else:
+                 facet = facet[facet.person_id.isin(processed_synthetic_samples)]
+                 facet = facet[facet.gender_presentation_na != 1]
+                 facet = facet[facet.gender_presentation_non_binary != 1]
+                 facet = facet[(facet.gender_presentation_masc == 1) | (facet.gender_presentation_fem == 1)]
+
+
+             facet["class1"] = facet["class1"].apply(lambda val: int(occupations.index(val)))
+
+             bsize = 512
+             predictions = []
+             acc = my_acc = 0
+             n_batches = 0
+
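+             # Standard CLIP prompt ensembling: embed every template for each class name, average
+             # the L2-normalised text embeddings, and renormalise to get one weight vector per class.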
+             def zeroshot_classifier(classnames, templates):
+                 with torch.no_grad():
+                     zeroshot_weights = []
+                     for classname in tqdm(classnames):
+                         texts = [template.format(classname) for template in templates]  # format with class
+                         texts = clip.tokenize(texts).cuda()  # tokenize
+                         class_embeddings = model.encode_text(texts)  # embed with text encoder
+                         class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+                         class_embedding = class_embeddings.mean(dim=0)
+                         class_embedding /= class_embedding.norm()
+                         zeroshot_weights.append(class_embedding)
+                     zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
+                 return zeroshot_weights
+
+
+             if "only" in attribute_value:
+                 dataset = FACET(facet.person_id.values, torch.tensor(facet.class1.values), facet_img_root, transform=preprocess, file_extension="_original.png")
+             else:
+                 dataset = FACET(facet.person_id.values, torch.tensor(facet.class1.values), facet_img_root, transform=preprocess, file_extension=f"_{attribute_value}.png")
+             dataloader = DataLoader(dataset, batch_size=bsize, shuffle=False, num_workers=6, drop_last=False)
+
+             zeroshot_weights = zeroshot_classifier(occupations[:39], imagenet_templates)
+
+
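+             # Zero-shot prediction: similarity between the normalised image embedding and each class
+             # weight vector, scaled by 100 and softmax-ed; the arg-max gives the predicted occupation.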
+             for imgs, labels in tqdm(dataloader):
+
+                 with torch.no_grad(), torch.cuda.amp.autocast():
+                     if clip_src == "OpenAI":
+                         # CLIP
+                         image_features = model.encode_image(imgs.half().cuda())
+                         image_features /= image_features.norm(dim=-1, keepdim=True)
+                         logits = 100. * image_features @ zeroshot_weights
+                         probs = logits.softmax(dim=-1).cpu().numpy()
+                     else:
+                         # OpenCLIP
+                         image_features = model.encode_image(imgs.half().cuda())
+                         image_features /= image_features.norm(dim=-1, keepdim=True)
+                         probs = (100. * image_features @ zeroshot_weights).softmax(dim=-1).cpu().numpy()
+
+                 preds_batch = np.argmax(probs, axis=-1)
+                 predictions += preds_batch.tolist()
+                 acc += torch.sum(torch.tensor(preds_batch).cuda() == labels.cuda()) / preds_batch.shape[0]
+                 n_batches += 1
+
+
+             print(model_name, "acc: ", acc / n_batches, "%")
+
+             results = pd.DataFrame({"person_id": facet.person_id.values,
+                                     "inpainted_attribute": attribute_value,
+                                     "age_presentation_young": facet.age_presentation_young.values,
+                                     "age_presentation_middle": facet.age_presentation_middle.values,
+                                     "age_presentation_older": facet.age_presentation_older.values,
+                                     "gender_presentation_fem": facet.gender_presentation_fem.values,
+                                     "gender_presentation_masc": facet.gender_presentation_masc.values,
+                                     "gt_class_label": facet.class1.values,
+                                     "class_predictions": predictions
+                                     })
+
+             results.to_csv(f'{out_dir}/{model_name.replace("/", "_").replace("-", "_")}_{attribute_value}_predictions.csv')
+
+             with open(f'{out_dir}/{model_name.replace("/", "_").replace("-", "_")}_{attribute_value}_accuracy.txt', "w") as of:
+                 of.write(str((acc / n_batches).item()))
+
+             with open(f'{out_dir}/{model_name.replace("/", "_").replace("-", "_")}_{attribute_value}.pkl', "wb") as f:
+                 pickle.dump(predictions, f)