"""Run a trained FCN segmenter over FACET person crops and save binary masks.

For every occupation class in the FACET annotations, each annotated person whose
perceived gender presentation is labelled masculine or feminine is segmented with
the FCN loaded from ``./models/fcn_model_2.pt``. The thresholded mask is written to
``./paper_output/FACET_FCN_skin/<person_id>.jpg`` (presumably a skin mask, judging
by the output directory name).
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
import sys
import os
import cv2
# Kept so torch.load can resolve the pickled model classes.
from FCN import FCN8s, FCN16s, FCN32s, FCNs, VGGNet
from tqdm import tqdm
import pandas as pd

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location lets a checkpoint pickled on a GPU load on a CPU-only machine.
model = torch.load('./models/fcn_model_2.pt', map_location=device)  # Load the model
model = model.to(device)
# Inference mode: use running BatchNorm statistics / disable dropout, if present.
model.eval()

# HWC uint8 image -> CHW float tensor in [0, 1], then ImageNet mean/std normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

INPUT_SIZE = (160, 160)  # (width, height) fed to the network
ANNOTATIONS_CSV = "../../../datasets/facet/annotations/annotations.csv"
IMAGE_ROOT = "../../../datasets/facet/images_bb_small"  # For FACET
OUTPUT_DIR = './paper_output/FACET_FCN_skin'

OCCUPATIONS = [
    'backpacker', 'ballplayer', 'bartender', 'basketball_player', 'boatman',
    'carpenter', 'cheerleader', 'climber', 'computer_user', 'craftsman',
    'dancer', 'disk_jockey', 'doctor', 'drummer', 'electrician', 'farmer',
    'fireman', 'flutist', 'gardener', 'guard', 'guitarist', 'gymnast',
    'hairdresser', 'horseman', 'judge', 'laborer', 'lawman', 'lifeguard',
    'machinist', 'motorcyclist', 'nurse', 'painter', 'patient', 'prayer',
    'referee', 'repairman', 'reporter', 'retailer', 'runner', 'sculptor',
    'seller', 'singer', 'skateboarder', 'soccer_player', 'soldier', 'speaker',
    'student', 'teacher', 'tennis_player', 'trumpeter', 'waiter',
]


def _segment(image_bgr):
    """Return a uint8 0/255 mask with the same height and width as *image_bgr*.

    The image is resized to INPUT_SIZE, passed through the module-level model,
    and a pixel is set to 255 when channel 0's share of the two sigmoid outputs
    exceeds 127/255 after resizing back to the original resolution.

    NOTE(review): cv2.imread yields BGR while the Normalize stats are the usual
    RGB ImageNet values; left unchanged because the training pipeline is not
    visible here -- confirm it was also trained on BGR input.
    """
    height, width = image_bgr.shape[:2]
    batch = transform(cv2.resize(image_bgr, INPUT_SIZE)).unsqueeze(0).to(device)

    # no_grad: pure inference, so don't build an autograd graph per image.
    with torch.no_grad():
        probs = torch.sigmoid(model(batch))

    # Drop the batch axis; the indexing below assumes two output channels
    # (TODO confirm against the model definition).
    probs_np = np.squeeze(probs.cpu().numpy())
    channel_a = probs_np[0]
    channel_b = probs_np[1]
    # Relative score of channel 0 rescaled to [0, 255].
    ratio = channel_a / (channel_a + channel_b) * 255

    ratio = cv2.resize(ratio, (width, height))
    _, mask = cv2.threshold(ratio, 127, 255, cv2.THRESH_BINARY)
    # JPEG encoding needs 8-bit data; make the conversion explicit.
    return mask.astype(np.uint8)


def main():
    """Segment every eligible FACET person crop and write masks to OUTPUT_DIR."""
    facet = pd.read_csv(ANNOTATIONS_CSV, header=0).rename(columns={'Unnamed: 0': 'sample_idx'})  # Bounding boxes
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for category in OCCUPATIONS:
        # Filter once per category instead of once per row.
        rows = facet[facet['class1'] == category]
        print(rows.shape[0])

        for _, row in tqdm(rows.iterrows(), total=len(rows)):
            # Only people annotated as masculine- or feminine-presenting are processed.
            if int(row['gender_presentation_masc']) != 1 and int(row['gender_presentation_fem']) != 1:
                continue

            img_id = str(row['person_id'])
            img_path = os.path.join(IMAGE_ROOT, img_id + ".jpg")

            image = cv2.imread(img_path)
            # cv2.imread returns None instead of raising for missing/unreadable files.
            if image is None:
                tqdm.write(f"skipping unreadable image: {img_path}")
                continue
            print(img_id)

            cv2.imwrite(f"{OUTPUT_DIR}/{img_id}.jpg", _segment(image))


if __name__ == '__main__':
    main()