diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..f86af2e3eb03b083099fa5c729efa7f61af0a07f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0003.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0002.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0003.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Poignee_cuvette/Poignee_cuvette_0003.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0001.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0002.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Serrure_Cuvette/Serrure_Cuvette_0001.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Serrure_Cuvette/Serrure_Cuvette_0003.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Serrure_Gaches/Serrure_Gaches_0002.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Serrure_Gaches/Serrure_Gaches_0003.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0001.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0002.jpg filter=lfs diff=lfs merge=lfs -text +imgs/Serrure_pour_Porte/Serrure_pour_Porte_0003.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 7485776b278b5bab36e05fcb27c60e4a4ba34295..c0d6edad11dc5083f882c3833c2bb48afa820f93 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ --- -title: OGSO IA -emoji: 🐠 +title: Test 001 +emoji: 📚 colorFrom: green colorTo: purple sdk: gradio diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..daf9ab4fb2effacef40fccba60181f2d72916efd --- /dev/null +++ b/app.py @@ -0,0 +1,245 @@ +import gdown +import os +import torch +import requests +import numpy as np +import numpy.matlib +import copy +import cv2 +from PIL import Image +from typing import List +import timm +import gradio as gr +import torchvision.transforms as transforms + +from pim_module import PluginMoodel # Assure-toi que ce fichier est prĂ©sent + +# === TĂ©lĂ©chargement automatique depuis Google Drive === +if not os.path.exists("weights.pt"): + print("TĂ©lĂ©chargement des poids depuis Google Drive avec gdown...") + file_id = "1Ck9qyjs4_c_fqgaEpZ0eN9jIV5TiqkXp" + url = f"https://drive.google.com/uc?id={file_id}" + gdown.download(url, "weights.pt", quiet=False) + + + +# === Classes +classes_list = [ + "Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE", + "Ferrage_et_accessoires_Busettes", + "Ferrage_et_accessoires_Butees", + "Ferrage_et_accessoires_Chariots", + "Ferrage_et_accessoires_Charniere", + "Ferrage_et_accessoires_Compas_limiteur", + "Ferrage_et_accessoires_Renvois_d'angle", + "Joints_et_consommables_Equerres_aluminium_moulees", + "Joints_et_consommables_Joints_a_clipser", + "Joints_et_consommables_Joints_a_coller", + "Joints_et_consommables_Joints_a_glisser", + "Joints_et_consommables_Joints_EPDM", + "Joints_et_consommables_Joints_PVC_aluminium", + "Joints_et_consommables_Silicone_pour_vitrage_alu", + "Joints_et_consommables_Visserie_inox_alu", + "Poignee_carre_7_mm", + "Poignee_carre_8_mm", + "Poignee_cremone", + "Poignee_cuvette", + "Poignee_de_tirage", + "Poignee_pour_Levant_Coulissant", + "Serrure_Cremone_multipoints", + "Serrure_Cuvette", + "Serrure_Gaches", + "Serrure_Pene_Crochet", + "Serrure_pour_Porte", + "Serrure_Tringles", +] + +short_classes_list = [ + "Anti-fausse-manoeuvre", + "Busettes", + "ButĂ©es", + "Chariots", + "CharniĂšre", + "Compas-limiteur", + "Renvois d'angle", + "Equerres aluminium moulĂ©es", + "Joints Ă  clipser", + "Joints Ă  coller", + "Joints Ă  glisser", + "Joints EPDM", + "Joints PVC aluminium", + "Silicone pour vitrage alu", + "Visserie inox alu", + "PoignĂ©e carrĂ© 7 mm", + "PoignĂ©e carrĂ© 8 mm", + "PoignĂ©e crĂ©mone", + "PoignĂ©e cuvette", + "PoignĂ©e de tirage", + "PoignĂ©e pour levant coulissant", + "Serrure crĂ©mone multipoints", + "Serrure cuvette", + "Serrure gaches", + "Serrure pene crochet", + "Serrure pour porte", + "Serrure tringles", +] + +data_size = 384 +fpn_size = 1536 +num_classes = 27 +num_selects = {'layer1': 256, 'layer2': 128, 'layer3': 64, 'layer4': 32} +features, grads, module_id_mapper = {}, {}, {} + +def forward_hook(module, inp_hs, out_hs): + layer_id = len(features) + 1 + module_id_mapper[module] = layer_id + features[layer_id] = {"in": inp_hs, "out": out_hs} + +def backward_hook(module, inp_grad, out_grad): + layer_id = module_id_mapper[module] + grads[layer_id] = {"in": inp_grad, "out": out_grad} + +def build_model(path: str): + backbone = timm.create_model('swin_large_patch4_window12_384_in22k', pretrained=True) + model = PluginMoodel( + backbone=backbone, + return_nodes=None, + img_size=data_size, + use_fpn=True, + fpn_size=fpn_size, + proj_type="Linear", + upsample_type="Conv", + use_selection=True, + num_classes=num_classes, + num_selects=num_selects, + use_combiner=True, + comb_proj_size=None + ) + ckpt = torch.load(path, map_location="cpu", weights_only=False) + model.load_state_dict(ckpt["model"], strict=False) + model.eval() + + for layer in [0, 1, 2, 3]: + model.backbone.layers[layer].register_forward_hook(forward_hook) + model.backbone.layers[layer].register_full_backward_hook(backward_hook) + + for i in range(1, 5): + getattr(model.fpn_down, f'Proj_layer{i}').register_forward_hook(forward_hook) + getattr(model.fpn_down, f'Proj_layer{i}').register_full_backward_hook(backward_hook) + getattr(model.fpn_up, f'Proj_layer{i}').register_forward_hook(forward_hook) + getattr(model.fpn_up, f'Proj_layer{i}').register_full_backward_hook(backward_hook) + + return model + +class ImgLoader: + def __init__(self, img_size): + self.transform = transforms.Compose([ + transforms.Resize((510, 510), Image.BILINEAR), + transforms.CenterCrop((img_size, img_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + def load(self, input_img): + if isinstance(input_img, str): + ori_img = cv2.imread(input_img) + img = Image.fromarray(cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)) + elif isinstance(input_img, Image.Image): + img = input_img + else: + raise ValueError("Image invalide") + + if img.mode != "RGB": + img = img.convert("RGB") + + return self.transform(img).unsqueeze(0) + +def cal_backward(out) -> dict: + target_layer_names = ['layer1', 'layer2', 'layer3', 'layer4', + 'FPN1_layer1', 'FPN1_layer2', 'FPN1_layer3', 'FPN1_layer4', 'comb_outs'] + + sum_out = None + for name in target_layer_names: + tmp_out = out[name].mean(1) if name != "comb_outs" else out[name] + tmp_out = torch.softmax(tmp_out, dim=-1) + sum_out = tmp_out if sum_out is None else sum_out + tmp_out + + with torch.no_grad(): + smax = torch.softmax(sum_out, dim=-1) + A = np.transpose(np.matlib.repmat(smax[0], num_classes, 1)) - np.eye(num_classes) + _, _, V = np.linalg.svd(A, full_matrices=True) + V = V[num_classes - 1, :] + if V[0] < 0: + V = -V + V = np.log(V) + V = V - min(V) + V = V / sum(V) + + top5_indices = np.argsort(-V)[:5] + top5_scores = -np.sort(-V)[:5] + + # Construction du dictionnaire pour gr.Label + top5_dict = {classes_list[int(idx)]: float(f"{score:.4f}") for idx, score in zip(top5_indices, top5_scores)} + return top5_dict + +# === Chargement du modĂšle +model = build_model("weights.pt") +img_loader = ImgLoader(data_size) + + + +def predict_image(image: Image.Image): + global features, grads, module_id_mapper + features, grads, module_id_mapper = {}, {}, {} + + if image is None: + return {} +# raise ValueError("Aucune image reçue. VĂ©rifie l'entrĂ©e.") + + if image.mode != "RGB": + image = image.convert("RGB") + + image_path = "temp.jpg" + image.save(image_path) + + img_tensor = img_loader.load(image_path) + out = model(img_tensor) + top5_dict = cal_backward(out) # {classe: score} + + gallery_outputs = [] + for idx, class_name in enumerate(list(top5_dict.keys())): + images = [ + (f"imgs/{class_name}/{class_name}_0001.jpg", f"Exemple {class_name} 1"), + (f"imgs/{class_name}/{class_name}_0002.jpg", f"Exemple {class_name} 2"), + (f"imgs/{class_name}/{class_name}_0003.jpg", f"Exemple {class_name} 3"), + ] + gallery_outputs.append(images) + + return top5_dict, *gallery_outputs + + +# === Interface Gradio +with gr.Blocks(css=""" +.gr-image-upload { display: none !important } +.gallery-container .gr-box { height: auto !important; padding: 0 !important; } +""") as demo: + with gr.Row(): + with gr.Column(scale=1): + with gr.Tab("TĂ©lĂ©versement"): + image_input_upload = gr.Image(type="pil", label="Image Ă  classer (upload)", sources=["upload"]) + with gr.Tab("Webcam"): + image_input_webcam = gr.Image(type="pil", label="Image Ă  classer (webcam)", sources=["webcam"]) + + with gr.Column(scale=1.5): + label_output = gr.Label(label="PrĂ©dictions") + gallery_outputs = [ + gr.Gallery(label=f"", columns=3, height=300, container=True, elem_classes=["gallery-container"]) + for i in range(5) + ] + + image_input_upload.change(fn=predict_image, inputs=image_input_upload, outputs=[label_output] + gallery_outputs) + image_input_webcam.change(fn=predict_image, inputs=image_input_webcam, outputs=[label_output] + gallery_outputs) + +if __name__ == "__main__": + demo.launch() + diff --git a/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0001.jpg b/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c94f8ce7f30972e1016c5bc6593f15b1964924ab Binary files /dev/null and b/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0001.jpg differ diff --git a/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0002.jpg b/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..33a8550bd46cc23a82eebb7384d56fcb49e1ed82 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0002.jpg differ diff --git a/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0003.jpg b/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..db0d82f198792aa12ec532059102e46c979dc459 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE/Ferrage_et_accessoires_ANTI_FAUSSE_MANOEUVRE_0003.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0001.jpg b/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..45ce7f1348d102f4462089868de9077c45fd6cea Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0001.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0002.jpg b/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cc983d551c94c9c76869b3cf1ff54342eb313bac Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0002.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0003.jpg b/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fd70c4dc21abf929d097e72afbc7838846218120 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Busettes/Ferrage_et_accessoires_Busettes_0003.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0001.jpg b/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2109ebb6daa2313a2b4d3038a8e7fbad894b8b7b Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0001.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0002.jpg b/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4f8672d17f258b088501c3509b1693b5464ca665 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0002.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0003.jpg b/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3257961378c4e7c6ad523280b60ba2630d7a3153 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Butees/Ferrage_et_accessoires_Butees_0003.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0001.jpg b/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6ff0a99d88331c307a282cba29571b28a9fa51ca Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0001.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0002.jpg b/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6b3cfee8d5e72d0861498968b38a0a3db6e42f86 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0002.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0003.jpg b/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..49c40079b4b814203fb1638dc10f85e025a8129d Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Chariots/Ferrage_et_accessoires_Chariots_0003.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0001.jpg b/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fb37b071dcb75745dcf3580021210f2fb7e06b4d Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0001.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0002.jpg b/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3add3e68a8c8bd6be0531ff2652d50410d30212c Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0002.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0003.jpg b/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..50ece9ee38e59f815173ec57e4f7847aa563fc03 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Charniere/Ferrage_et_accessoires_Charniere_0003.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0001.jpg b/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..58c4a56445dc38d41b4f5697219e74356e126bf3 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0001.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0002.jpg b/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5040ce8687a29703821488f7b46c76684c592289 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0002.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0003.jpg b/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cfb923cc6f82ebaa08cb5a2da09f8dd90e6f24b5 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Compas_limiteur/Ferrage_et_accessoires_Compas_limiteur_0003.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0001.jpg b/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..903dec279ec87cac0e5d8b071afae8463efcc71b Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0001.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0002.jpg b/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b73c130ef87dc874a23bce2ba64c779d30abb2fb Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0002.jpg differ diff --git a/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0003.jpg b/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7966ee64eb7a709103f43f73daa16a10a8f6f0a4 Binary files /dev/null and b/imgs/Ferrage_et_accessoires_Renvois_d'angle/Ferrage_et_accessoires_Renvois_d'angle_0003.jpg differ diff --git a/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0001.jpg b/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c85b764f781b9e7243d5c4456ae8910e5a5d14ec Binary files /dev/null and b/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0001.jpg differ diff --git a/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0002.jpg b/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a762c5e102454f4b9709304ce9eb4688f59b3b18 Binary files /dev/null and b/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0002.jpg differ diff --git a/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0003.jpg b/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f39d721d80cc8caa23fc24dbdd06e6a939b0dadb Binary files /dev/null and b/imgs/Joints_et_consommables_Equerres_aluminium_moulees/Joints_et_consommables_Equerres_aluminium_moulees_0003.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0001.jpg b/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3938d693d5a2035fd18117e042107dadedfc8584 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0001.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0002.jpg b/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..df9667b920402c5e45cd3bf8fe7cc38fac8fb3af Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0002.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0003.jpg b/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9190ec6ac29f5448f156ad2b6fbcc8d550bad586 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_EPDM/Joints_et_consommables_Joints_EPDM_0003.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0001.jpg b/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1a6f8e358cb656cb883254354182f454b58357da Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0001.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0002.jpg b/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2b32dae2695d46d1f24e60e9521b0a22592bbbe0 --- /dev/null +++ b/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0002.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:638fbeb70ccfd165b38132b41ab252c4b3cb484a81aaa6783ea478a0fb35e662 +size 177728 diff --git a/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0003.jpg b/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..164b301b683c7241546a0b7a0e55b70409369abd Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_PVC_aluminium/Joints_et_consommables_Joints_PVC_aluminium_0003.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0001.jpg b/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0d17c40af79c8f19a7797a356972c68d310266da Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0001.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0002.jpg b/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d88b40e91733c2b2c1a05755cd8b8075e2007915 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0002.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0003.jpg b/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d0b7e4a47faff7138d3eb60b6c2c8b151d1c74e8 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_a_clipser/Joints_et_consommables_Joints_a_clipser_0003.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0001.jpg b/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..aea35c653a68bf5c2c519f0a7875d95671199f50 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0001.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0002.jpg b/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dc19e6feaadc6157e88f605783021ab73de40309 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0002.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0003.jpg b/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d960a29e8a32d63493821ba3bed0449c75451b89 --- /dev/null +++ b/imgs/Joints_et_consommables_Joints_a_coller/Joints_et_consommables_Joints_a_coller_0003.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2446ab7baa77f0e7609b1b3c3a291fe9e39ffee7488e1b06cf276420bc4e431b +size 593462 diff --git a/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0001.jpg b/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ffb6c1af21147923bd9618bc103ae5006231be89 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0001.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0002.jpg b/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d7691b0128fec1ddaeeec0b76228c97449eed3b2 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0002.jpg differ diff --git a/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0003.jpg b/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..799d381a233b00d06f727bad89e41ece38c62142 Binary files /dev/null and b/imgs/Joints_et_consommables_Joints_a_glisser/Joints_et_consommables_Joints_a_glisser_0003.jpg differ diff --git a/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0001.jpg b/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c182411b8935d2c593884d0dbb8e077a618356b Binary files /dev/null and b/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0001.jpg differ diff --git a/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0002.jpg b/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4055cdbb8b2f22470f0e165498dd8c59204734a4 Binary files /dev/null and b/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0002.jpg differ diff --git a/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0003.jpg b/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..653213d2da1b08e2f1e5b1d4e903b02c2fa800a8 Binary files /dev/null and b/imgs/Joints_et_consommables_Silicone_pour_vitrage_alu/Joints_et_consommables_Silicone_pour_vitrage_alu_0003.jpg differ diff --git a/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0001.jpg b/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8a0c078c033e5c770bbd59bf70f9ef428f7c33cc Binary files /dev/null and b/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0001.jpg differ diff --git a/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0002.jpg b/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7d611119ad62600347ec606aad13503c178f5f3e Binary files /dev/null and b/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0002.jpg differ diff --git a/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0003.jpg b/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cf6c251bac429d737ebf2fa81d8eb513c082fcd4 Binary files /dev/null and b/imgs/Joints_et_consommables_Visserie_inox_alu/Joints_et_consommables_Visserie_inox_alu_0003.jpg differ diff --git a/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0001.jpg b/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ce72f93ece72dc0c723052bebefb4020d4317758 Binary files /dev/null and b/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0001.jpg differ diff --git a/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0002.jpg b/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fbf1db378c4a9e8ea8665f8f4e25f1ad36043b7d Binary files /dev/null and b/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0002.jpg differ diff --git a/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0003.jpg b/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1e374eec49a9873efd672e52e869ea8876c1fd0 Binary files /dev/null and b/imgs/Poignee_carre_7_mm/Poignee_carre_7_mm_0003.jpg differ diff --git a/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0001.jpg b/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ef88139f463b5e8fdcd2156df4ddee8befedc99 Binary files /dev/null and b/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0001.jpg differ diff --git a/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0002.jpg b/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..21676eaa67a3e4f4ab1128b72fe1f68d91e099e4 Binary files /dev/null and b/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0002.jpg differ diff --git a/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0003.jpg b/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a197265c8c6f8c936822def6f86f4d1cf46d3d94 --- /dev/null +++ b/imgs/Poignee_carre_8_mm/Poignee_carre_8_mm_0003.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:62258eb13f64a7a5facf5cfe408b0876568e0d5c4297320b02c3a42637e96f26 +size 162146 diff --git a/imgs/Poignee_cremone/Poignee_cremone_0001.jpg b/imgs/Poignee_cremone/Poignee_cremone_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..37f2c56cd90cc2c9bf4749a4214a6d504878a6c4 Binary files /dev/null and b/imgs/Poignee_cremone/Poignee_cremone_0001.jpg differ diff --git a/imgs/Poignee_cremone/Poignee_cremone_0002.jpg b/imgs/Poignee_cremone/Poignee_cremone_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..99996d058ccd657fcb20ddfe58f84f1eb60abd58 Binary files /dev/null and b/imgs/Poignee_cremone/Poignee_cremone_0002.jpg differ diff --git a/imgs/Poignee_cremone/Poignee_cremone_0003.jpg b/imgs/Poignee_cremone/Poignee_cremone_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ba37bef12012c8c4dc8fdfe2ac604ffe9ee68922 Binary files /dev/null and b/imgs/Poignee_cremone/Poignee_cremone_0003.jpg differ diff --git a/imgs/Poignee_cuvette/Poignee_cuvette_0001.jpg b/imgs/Poignee_cuvette/Poignee_cuvette_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dfebee86aedbb5298c94fd3821a0b4a6269b6119 Binary files /dev/null and b/imgs/Poignee_cuvette/Poignee_cuvette_0001.jpg differ diff --git a/imgs/Poignee_cuvette/Poignee_cuvette_0002.jpg b/imgs/Poignee_cuvette/Poignee_cuvette_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a9f5c3f8e798520fb858f2c9eb3027c193c8fac0 Binary files /dev/null and b/imgs/Poignee_cuvette/Poignee_cuvette_0002.jpg differ diff --git a/imgs/Poignee_cuvette/Poignee_cuvette_0003.jpg b/imgs/Poignee_cuvette/Poignee_cuvette_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f51232c78c908007d9515eba5a2b7ade5c695cc3 --- /dev/null +++ b/imgs/Poignee_cuvette/Poignee_cuvette_0003.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:279a4b32b44fb70f4c29b599cef9aac1b659d049e6cba0ed24425d72f732b367 +size 591204 diff --git a/imgs/Poignee_de_tirage/Poignee_de_tirage_0001.jpg b/imgs/Poignee_de_tirage/Poignee_de_tirage_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f123b3112de43482464ce3ed618d6436238d3c3e Binary files /dev/null and b/imgs/Poignee_de_tirage/Poignee_de_tirage_0001.jpg differ diff --git a/imgs/Poignee_de_tirage/Poignee_de_tirage_0002.jpg b/imgs/Poignee_de_tirage/Poignee_de_tirage_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5e248e5ad931445ebbf437769d5e7ba39977d2ec Binary files /dev/null and b/imgs/Poignee_de_tirage/Poignee_de_tirage_0002.jpg differ diff --git a/imgs/Poignee_de_tirage/Poignee_de_tirage_0003.jpg b/imgs/Poignee_de_tirage/Poignee_de_tirage_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5e7a50a00472268b799ebc4360db71dfb49e296c Binary files /dev/null and b/imgs/Poignee_de_tirage/Poignee_de_tirage_0003.jpg differ diff --git a/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0001.jpg b/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8f5bce37580eb6edb77a87491bca51ce7e4b5fa9 --- /dev/null +++ b/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0001.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:15093840d24ebd335bee4d34e9fcdf46034321e8b0506954cbeb33ed9bcd63a3 +size 149863 diff --git a/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0002.jpg b/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..af61689a3d63011763e93e86ab1f162e2ea66d69 Binary files /dev/null and b/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0002.jpg differ diff --git a/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0003.jpg b/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d4c59a22a73eba485f8b014b2c060cbde0cc1fd Binary files /dev/null and b/imgs/Poignee_pour_Levant_Coulissant/Poignee_pour_Levant_Coulissant_0003.jpg differ diff --git a/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0001.jpg b/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b841e6edbbcf9653400ee1f47c7be6a526f60e37 Binary files /dev/null and b/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0001.jpg differ diff --git a/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0002.jpg b/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6690f11b90900c3d0f7f5799d88b4baf4d3e9316 --- /dev/null +++ b/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0002.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2f21d9cacd49797c0b79aa83a70ed804647eb5ff4052177618730e6f0ff9e8d +size 296372 diff --git a/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0003.jpg b/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..faf5e88142ef9a3c45ac8253c0cc3af3f26e13e1 Binary files /dev/null and b/imgs/Serrure_Cremone_multipoints/Serrure_Cremone_multipoints_0003.jpg differ diff --git a/imgs/Serrure_Cuvette/Serrure_Cuvette_0001.jpg b/imgs/Serrure_Cuvette/Serrure_Cuvette_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ce5ad6ecde8e088f959bfb2ba270b6f6348ec9a6 --- /dev/null +++ b/imgs/Serrure_Cuvette/Serrure_Cuvette_0001.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42443b52c19a4285cba38ee0d9b015ee8770a4b701af0cb879a3c35ed51c9a50 +size 206242 diff --git a/imgs/Serrure_Cuvette/Serrure_Cuvette_0002.jpg b/imgs/Serrure_Cuvette/Serrure_Cuvette_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1efa9fcaf265fa5b2d6d253adb83cbe4a05dc4b2 Binary files /dev/null and b/imgs/Serrure_Cuvette/Serrure_Cuvette_0002.jpg differ diff --git a/imgs/Serrure_Cuvette/Serrure_Cuvette_0003.jpg b/imgs/Serrure_Cuvette/Serrure_Cuvette_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9c85e2c62479a3129b9d3576d7ff8b723a3d99c8 --- /dev/null +++ b/imgs/Serrure_Cuvette/Serrure_Cuvette_0003.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:109598a68b0e7d60439217e1787b509dc3817e850e1d906ed296d72d808d80cb +size 113599 diff --git a/imgs/Serrure_Gaches/Serrure_Gaches_0001.jpg b/imgs/Serrure_Gaches/Serrure_Gaches_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..34cd9d365d089be8c61927f6d41e1ee99e42d8da Binary files /dev/null and b/imgs/Serrure_Gaches/Serrure_Gaches_0001.jpg differ diff --git a/imgs/Serrure_Gaches/Serrure_Gaches_0002.jpg b/imgs/Serrure_Gaches/Serrure_Gaches_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d3f3f52e7486fd739195f80a65e6b15b5772a92 --- /dev/null +++ b/imgs/Serrure_Gaches/Serrure_Gaches_0002.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16ca57da25eb0665d518c6c817c92a5d4744809a4c11445d85ef9393fa655f76 +size 590078 diff --git a/imgs/Serrure_Gaches/Serrure_Gaches_0003.jpg b/imgs/Serrure_Gaches/Serrure_Gaches_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..19b98143b45bc845cc197296abdf36137d22e683 --- /dev/null +++ b/imgs/Serrure_Gaches/Serrure_Gaches_0003.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b178f40f76ac3a4469058447250838b77af7f01b53b40375cf6011cfac541599 +size 119643 diff --git a/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0001.jpg b/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76093afc4ff6987a6d59d47cd429b117e833f3c7 --- /dev/null +++ b/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0001.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db4af0f694fbbfde4332d75d36573add9706d9202a8212408ffee59be73789cf +size 180982 diff --git a/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0002.jpg b/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..61301739e691ea3d48d64bf4bdbe52859c60a148 --- /dev/null +++ b/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0002.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:80357f4049bfa3a7bb2ac98b929b0a005a8213586b4c9bb2fab79e1402072a9b +size 400139 diff --git a/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0003.jpg b/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a4d4472d4d89e66f5d30e1bd925d653d8c11aea9 Binary files /dev/null and b/imgs/Serrure_Pene_Crochet/Serrure_Pene_Crochet_0003.jpg differ diff --git a/imgs/Serrure_Tringles/Serrure_Tringles_0001.jpg b/imgs/Serrure_Tringles/Serrure_Tringles_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6d90fa9fe55773e00d60b5e1d21617e0dab50f51 Binary files /dev/null and b/imgs/Serrure_Tringles/Serrure_Tringles_0001.jpg differ diff --git a/imgs/Serrure_Tringles/Serrure_Tringles_0002.jpg b/imgs/Serrure_Tringles/Serrure_Tringles_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4e30f1346a6a84c3dd6e374c3a52554d1576022b Binary files /dev/null and b/imgs/Serrure_Tringles/Serrure_Tringles_0002.jpg differ diff --git a/imgs/Serrure_Tringles/Serrure_Tringles_0003.jpg b/imgs/Serrure_Tringles/Serrure_Tringles_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1ab3e52bbb4705a979b8012568557aa95f169d0b Binary files /dev/null and b/imgs/Serrure_Tringles/Serrure_Tringles_0003.jpg differ diff --git a/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0001.jpg b/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e33b46162cc61cb6857c42cbc37f9a8944a2aac9 Binary files /dev/null and b/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0001.jpg differ diff --git a/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0002.jpg b/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0002.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d98bf5d4f6fc878468afe3b3c759a7717d3acd6d Binary files /dev/null and b/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0002.jpg differ diff --git a/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0003.jpg b/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0003.jpg new file mode 100644 index 0000000000000000000000000000000000000000..05ff03a615ffddda40d74185ba72df1214aa65a0 --- /dev/null +++ b/imgs/Serrure_pour_Porte/Serrure_pour_Porte_0003.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:17a71535c4c9b24cc4a4d2688bbeccba6a9569580f49a1c015fa82aa957e4b0f +size 308504 diff --git a/pim_module.py b/pim_module.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0e7797158b641d34af76b52215c00ab85372fb --- /dev/null +++ b/pim_module.py @@ -0,0 +1,573 @@ +import torch +import torch.nn as nn +import torchvision.models as models +import torch.nn.functional as F +from torchvision.models.feature_extraction import get_graph_node_names +from torchvision.models.feature_extraction import create_feature_extractor +from typing import Union +import copy + +class GCNCombiner(nn.Module): + + def __init__(self, + total_num_selects: int, + num_classes: int, + inputs: Union[dict, None] = None, + proj_size: Union[int, None] = None, + fpn_size: Union[int, None] = None): + """ + If building backbone without FPN, set fpn_size to None and MUST give + 'inputs' and 'proj_size', the reason of these setting is to constrain the + dimension of graph convolutional network input. + """ + super(GCNCombiner, self).__init__() + + assert inputs is not None or fpn_size is not None, \ + "To build GCN combiner, you must give one features dimension." + + ### auto-proj + self.fpn_size = fpn_size + if fpn_size is None: + for name in inputs: + if len(name) == 4: + in_size = inputs[name].size(1) + elif len(name) == 3: + in_size = inputs[name].size(2) + else: + raise ValusError("The size of output dimension of previous must be 3 or 4.") + m = nn.Sequential( + nn.Linear(in_size, proj_size), + nn.ReLU(), + nn.Linear(proj_size, proj_size) + ) + self.add_module("proj_"+name, m) + self.proj_size = proj_size + else: + self.proj_size = fpn_size + + ### build one layer structure (with adaptive module) + num_joints = total_num_selects // 64 + + self.param_pool0 = nn.Linear(total_num_selects, num_joints) + + A = torch.eye(num_joints) / 100 + 1 / 100 + self.adj1 = nn.Parameter(copy.deepcopy(A)) + self.conv1 = nn.Conv1d(self.proj_size, self.proj_size, 1) + self.batch_norm1 = nn.BatchNorm1d(self.proj_size) + + self.conv_q1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1) + self.conv_k1 = nn.Conv1d(self.proj_size, self.proj_size//4, 1) + self.alpha1 = nn.Parameter(torch.zeros(1)) + + ### merge information + self.param_pool1 = nn.Linear(num_joints, 1) + + #### class predict + self.dropout = nn.Dropout(p=0.1) + self.classifier = nn.Linear(self.proj_size, num_classes) + + self.tanh = nn.Tanh() + + def forward(self, x): + """ + """ + hs = [] + names = [] + for name in x: + if "FPN1_" in name: + continue + if self.fpn_size is None: + _tmp = getattr(self, "proj_"+name)(x[name]) + else: + _tmp = x[name] + hs.append(_tmp) + names.append([name, _tmp.size()]) + + hs = torch.cat(hs, dim=1).transpose(1, 2).contiguous() # B, S', C --> B, C, S + # print(hs.size(), names) + hs = self.param_pool0(hs) + ### adaptive adjacency + q1 = self.conv_q1(hs).mean(1) + k1 = self.conv_k1(hs).mean(1) + A1 = self.tanh(q1.unsqueeze(-1) - k1.unsqueeze(1)) + A1 = self.adj1 + A1 * self.alpha1 + ### graph convolution + hs = self.conv1(hs) + hs = torch.matmul(hs, A1) + hs = self.batch_norm1(hs) + ### predict + hs = self.param_pool1(hs) + hs = self.dropout(hs) + hs = hs.flatten(1) + hs = self.classifier(hs) + + return hs + +class WeaklySelector(nn.Module): + + def __init__(self, inputs: dict, num_classes: int, num_select: dict, fpn_size: Union[int, None] = None): + """ + inputs: dictionary contain torch.Tensors, which comes from backbone + [Tensor1(hidden feature1), Tensor2(hidden feature2)...] + Please note that if len(features.size) equal to 3, the order of dimension must be [B,S,C], + S mean the spatial domain, and if len(features.size) equal to 4, the order must be [B,C,H,W] + """ + super(WeaklySelector, self).__init__() + + self.num_select = num_select + + self.fpn_size = fpn_size + ### build classifier + if self.fpn_size is None: + self.num_classes = num_classes + for name in inputs: + fs_size = inputs[name].size() + if len(fs_size) == 3: + in_size = fs_size[2] + elif len(fs_size) == 4: + in_size = fs_size[1] + m = nn.Linear(in_size, num_classes) + self.add_module("classifier_l_"+name, m) + + self.thresholds = {} + for name in inputs: + self.thresholds[name] = [] + + # def select(self, logits, l_name): + # """ + # logits: [B, S, num_classes] + # """ + # probs = torch.softmax(logits, dim=-1) + # scores, _ = torch.max(probs, dim=-1) + # _, ids = torch.sort(scores, -1, descending=True) + # sn = self.num_select[l_name] + # s_ids = ids[:, :sn] + # not_s_ids = ids[:, sn:] + # return s_ids.unsqueeze(-1), not_s_ids.unsqueeze(-1) + + def forward(self, x, logits=None): + """ + x : + dictionary contain the features maps which + come from your choosen layers. + size must be [B, HxW, C] ([B, S, C]) or [B, C, H, W]. + [B,C,H,W] will be transpose to [B, HxW, C] automatically. + """ + if self.fpn_size is None: + logits = {} + selections = {} + for name in x: + # print("[selector]", name, x[name].size()) + if "FPN1_" in name: + continue + if len(x[name].size()) == 4: + B, C, H, W = x[name].size() + x[name] = x[name].view(B, C, H*W).permute(0, 2, 1).contiguous() + C = x[name].size(-1) + if self.fpn_size is None: + logits[name] = getattr(self, "classifier_l_"+name)(x[name]) + + probs = torch.softmax(logits[name], dim=-1) + sum_probs = torch.softmax(logits[name].mean(1), dim=-1) + selections[name] = [] + preds_1 = [] + preds_0 = [] + num_select = self.num_select[name] + for bi in range(logits[name].size(0)): + _, max_ids = torch.max(sum_probs[bi], dim=-1) + confs, ranks = torch.sort(probs[bi, :, max_ids], descending=True) + sf = x[name][bi][ranks[:num_select]] + nf = x[name][bi][ranks[num_select:]] # calculate + selections[name].append(sf) # [num_selected, C] + preds_1.append(logits[name][bi][ranks[:num_select]]) + preds_0.append(logits[name][bi][ranks[num_select:]]) + + if bi >= len(self.thresholds[name]): + self.thresholds[name].append(confs[num_select]) # for initialize + else: + self.thresholds[name][bi] = confs[num_select] + + selections[name] = torch.stack(selections[name]) + preds_1 = torch.stack(preds_1) + preds_0 = torch.stack(preds_0) + + logits["select_"+name] = preds_1 + logits["drop_"+name] = preds_0 + + return selections + + +class FPN(nn.Module): + + def __init__(self, inputs: dict, fpn_size: int, proj_type: str, upsample_type: str): + """ + inputs : dictionary contains torch.Tensor + which comes from backbone output + fpn_size: integer, fpn + proj_type: + in ["Conv", "Linear"] + upsample_type: + in ["Bilinear", "Conv", "Fc"] + for convolution neural network (e.g. ResNet, EfficientNet), recommand 'Bilinear'. + for Vit, "Fc". and Swin-T, "Conv" + """ + super(FPN, self).__init__() + assert proj_type in ["Conv", "Linear"], \ + "FPN projection type {} were not support yet, please choose type 'Conv' or 'Linear'".format(proj_type) + assert upsample_type in ["Bilinear", "Conv"], \ + "FPN upsample type {} were not support yet, please choose type 'Bilinear' or 'Conv'".format(proj_type) + + self.fpn_size = fpn_size + self.upsample_type = upsample_type + inp_names = [name for name in inputs] + + for i, node_name in enumerate(inputs): + ### projection module + if proj_type == "Conv": + m = nn.Sequential( + nn.Conv2d(inputs[node_name].size(1), inputs[node_name].size(1), 1), + nn.ReLU(), + nn.Conv2d(inputs[node_name].size(1), fpn_size, 1) + ) + elif proj_type == "Linear": + in_feat = inputs[node_name] + if isinstance(in_feat, torch.Tensor): + dim = in_feat.size(-1) + else: + raise ValueError(f"EntrĂ©e invalide dans FPN: {type(in_feat)} pour node_name={node_name}") + + m = nn.Sequential( + nn.Linear(dim, dim), + nn.ReLU(), + nn.Linear(dim, fpn_size), + ) + + self.add_module("Proj_"+node_name, m) + + ### upsample module + if upsample_type == "Conv" and i != 0: + assert len(inputs[node_name].size()) == 3 # B, S, C + in_dim = inputs[node_name].size(1) + out_dim = inputs[inp_names[i-1]].size(1) + # if in_dim != out_dim: + m = nn.Conv1d(in_dim, out_dim, 1) # for spatial domain + # else: + # m = nn.Identity() + self.add_module("Up_"+node_name, m) + + if upsample_type == "Bilinear": + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') + + def upsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x1_name: str): + """ + return Upsample(x1) + x1 + """ + if self.upsample_type == "Bilinear": + if x1.size(-1) != x0.size(-1): + x1 = self.upsample(x1) + else: + x1 = getattr(self, "Up_"+x1_name)(x1) + return x1 + x0 + + def forward(self, x): + """ + x : dictionary + { + "node_name1": feature1, + "node_name2": feature2, ... + } + """ + ### project to same dimension + hs = [] + for i, name in enumerate(x): + if "FPN1_" in name: + continue + x[name] = getattr(self, "Proj_"+name)(x[name]) + hs.append(name) + + x["FPN1_" + "layer4"] = x["layer4"] + + for i in range(len(hs)-1, 0, -1): + x1_name = hs[i] + x0_name = hs[i-1] + x[x0_name] = self.upsample_add(x[x0_name], + x[x1_name], + x1_name) + x["FPN1_" + x0_name] = x[x0_name] + + return x + + +class FPN_UP(nn.Module): + + def __init__(self, + inputs: dict, + fpn_size: int): + super(FPN_UP, self).__init__() + + inp_names = [name for name in inputs] + + for i, node_name in enumerate(inputs): + ### projection module + m = nn.Sequential( + nn.Linear(fpn_size, fpn_size), + nn.ReLU(), + nn.Linear(fpn_size, fpn_size), + ) + self.add_module("Proj_"+node_name, m) + + ### upsample module + if i != (len(inputs) - 1): + assert len(inputs[node_name].size()) == 3 # B, S, C + in_dim = inputs[node_name].size(1) + out_dim = inputs[inp_names[i+1]].size(1) + m = nn.Conv1d(in_dim, out_dim, 1) # for spatial domain + self.add_module("Down_"+node_name, m) + # print("Down_"+node_name, in_dim, out_dim) + """ + Down_layer1 2304 576 + Down_layer2 576 144 + Down_layer3 144 144 + """ + + def downsample_add(self, x0: torch.Tensor, x1: torch.Tensor, x0_name: str): + """ + return Upsample(x1) + x1 + """ + # print("[downsample_add] Down_" + x0_name) + x0 = getattr(self, "Down_" + x0_name)(x0) + return x1 + x0 + + def forward(self, x): + """ + x : dictionary + { + "node_name1": feature1, + "node_name2": feature2, ... + } + """ + ### project to same dimension + hs = [] + for i, name in enumerate(x): + if "FPN1_" in name: + continue + x[name] = getattr(self, "Proj_"+name)(x[name]) + hs.append(name) + + # print(hs) + for i in range(0, len(hs) - 1): + x0_name = hs[i] + x1_name = hs[i+1] + # print(x0_name, x1_name) + # print(x[x0_name].size(), x[x1_name].size()) + x[x1_name] = self.downsample_add(x[x0_name], + x[x1_name], + x0_name) + return x + + + + +class PluginMoodel(nn.Module): + + def __init__(self, + backbone: torch.nn.Module, + return_nodes: Union[dict, None], + img_size: int, + use_fpn: bool, + fpn_size: Union[int, None], + proj_type: str, + upsample_type: str, + use_selection: bool, + num_classes: int, + num_selects: dict, + use_combiner: bool, + comb_proj_size: Union[int, None] + ): + """ + * backbone: + torch.nn.Module class (recommand pretrained on ImageNet or IG-3.5B-17k(provided by FAIR)) + * return_nodes: + e.g. + return_nodes = { + # node_name: user-specified key for output dict + 'layer1.2.relu_2': 'layer1', + 'layer2.3.relu_2': 'layer2', + 'layer3.5.relu_2': 'layer3', + 'layer4.2.relu_2': 'layer4', + } # you can see the example on https://pytorch.org/vision/main/feature_extraction.html + !!! if using 'Swin-Transformer', please set return_nodes to None + !!! and please set use_fpn to True + * feat_sizes: + tuple or list contain features map size of each layers. + ((C, H, W)). e.g. ((1024, 14, 14), (2048, 7, 7)) + * use_fpn: + boolean, use features pyramid network or not + * fpn_size: + integer, features pyramid network projection dimension + * num_selects: + num_selects = { + # match user-specified in return_nodes + "layer1": 2048, + "layer2": 512, + "layer3": 128, + "layer4": 32, + } + Note: after selector module (WeaklySelector) , the feature map's size is [B, S', C] which + contained by 'logits' or 'selections' dictionary (S' is selection number, different layer + could be different). + """ + super(PluginMoodel, self).__init__() + + ### = = = = = Backbone = = = = = + self.return_nodes = return_nodes + if return_nodes is not None: + self.backbone = create_feature_extractor(backbone, return_nodes=return_nodes) + else: + self.backbone = backbone + + ### get hidden feartues size + rand_in = torch.randn(1, 3, img_size, img_size) + outs = self.backbone(rand_in) + + ### just original backbone + if not use_fpn and (not use_selection and not use_combiner): + for name in outs: + fs_size = outs[name].size() + if len(fs_size) == 3: + out_size = fs_size.size(-1) + elif len(fs_size) == 4: + out_size = fs_size.size(1) + else: + raise ValusError("The size of output dimension of previous must be 3 or 4.") + self.classifier = nn.Linear(out_size, num_classes) + + ### = = = = = FPN = = = = = + self.use_fpn = use_fpn + if self.use_fpn: + self.fpn_down = FPN(outs, fpn_size, proj_type, upsample_type) + self.build_fpn_classifier_down(outs, fpn_size, num_classes) + self.fpn_up = FPN_UP(outs, fpn_size) + self.build_fpn_classifier_up(outs, fpn_size, num_classes) + + self.fpn_size = fpn_size + + ### = = = = = Selector = = = = = + self.use_selection = use_selection + if self.use_selection: + w_fpn_size = self.fpn_size if self.use_fpn else None # if not using fpn, build classifier in weakly selector + self.selector = WeaklySelector(outs, num_classes, num_selects, w_fpn_size) + + ### = = = = = Combiner = = = = = + self.use_combiner = use_combiner + if self.use_combiner: + assert self.use_selection, "Please use selection module before combiner" + if self.use_fpn: + gcn_inputs, gcn_proj_size = None, None + else: + gcn_inputs, gcn_proj_size = outs, comb_proj_size # redundant, fix in future + total_num_selects = sum([num_selects[name] for name in num_selects]) # sum + self.combiner = GCNCombiner(total_num_selects, num_classes, gcn_inputs, gcn_proj_size, self.fpn_size) + + def build_fpn_classifier_up(self, inputs: dict, fpn_size: int, num_classes: int): + """ + Teh results of our experiments show that linear classifier in this case may cause some problem. + """ + for name in inputs: + m = nn.Sequential( + nn.Conv1d(fpn_size, fpn_size, 1), + nn.BatchNorm1d(fpn_size), + nn.ReLU(), + nn.Conv1d(fpn_size, num_classes, 1) + ) + self.add_module("fpn_classifier_up_"+name, m) + + def build_fpn_classifier_down(self, inputs: dict, fpn_size: int, num_classes: int): + """ + Teh results of our experiments show that linear classifier in this case may cause some problem. + """ + for name in inputs: + m = nn.Sequential( + nn.Conv1d(fpn_size, fpn_size, 1), + nn.BatchNorm1d(fpn_size), + nn.ReLU(), + nn.Conv1d(fpn_size, num_classes, 1) + ) + self.add_module("fpn_classifier_down_" + name, m) + + def forward_backbone(self, x): + return self.backbone(x) + + def fpn_predict_down(self, x: dict, logits: dict): + """ + x: [B, C, H, W] or [B, S, C] + [B, C, H, W] --> [B, H*W, C] + """ + for name in x: + if "FPN1_" not in name: + continue + ### predict on each features point + if len(x[name].size()) == 4: + B, C, H, W = x[name].size() + logit = x[name].view(B, C, H*W) + elif len(x[name].size()) == 3: + logit = x[name].transpose(1, 2).contiguous() + model_name = name.replace("FPN1_", "") + logits[name] = getattr(self, "fpn_classifier_down_" + model_name)(logit) + logits[name] = logits[name].transpose(1, 2).contiguous() # transpose + + def fpn_predict_up(self, x: dict, logits: dict): + """ + x: [B, C, H, W] or [B, S, C] + [B, C, H, W] --> [B, H*W, C] + """ + for name in x: + if "FPN1_" in name: + continue + ### predict on each features point + if len(x[name].size()) == 4: + B, C, H, W = x[name].size() + logit = x[name].view(B, C, H*W) + elif len(x[name].size()) == 3: + logit = x[name].transpose(1, 2).contiguous() + model_name = name.replace("FPN1_", "") + logits[name] = getattr(self, "fpn_classifier_up_" + model_name)(logit) + logits[name] = logits[name].transpose(1, 2).contiguous() # transpose + + def forward(self, x: torch.Tensor): + + logits = {} + + x = self.forward_backbone(x) + + if self.use_fpn: + x = self.fpn_down(x) + # print([name for name in x]) + self.fpn_predict_down(x, logits) + x = self.fpn_up(x) + self.fpn_predict_up(x, logits) + + if self.use_selection: + selects = self.selector(x, logits) + + if self.use_combiner: + comb_outs = self.combiner(selects) + logits['comb_outs'] = comb_outs + return logits + + if self.use_selection or self.fpn: + return logits + + ### original backbone (only predict final selected layer) + for name in x: + hs = x[name] + + if len(hs.size()) == 4: + hs = F.adaptive_avg_pool2d(hs, (1, 1)) + hs = hs.flatten(1) + else: + hs = hs.mean(1) + out = self.classifier(hs) + logits['ori_out'] = logits + + return diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3edea8ec7fc7a65415557bdd090045a23877f813 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +gdown +gradio==4.25.0 +torch +torchvision +timm +opencv-python +numpy +pillow +requests diff --git a/timm/__init__.py b/timm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04ec7e51b858e528009bbeea6c75af5985aef202 --- /dev/null +++ b/timm/__init__.py @@ -0,0 +1,4 @@ +from .version import __version__ +from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \ + is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \ + get_model_default_value, is_model_pretrained diff --git a/timm/__pycache__/__init__.cpython-310.pyc b/timm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63ffd2dc6695a62a50044cd568199c5b97519f26 Binary files /dev/null and b/timm/__pycache__/__init__.cpython-310.pyc differ diff --git a/timm/__pycache__/__init__.cpython-311.pyc b/timm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f328b07a7c7a6e8849f116bab548c45650b43684 Binary files /dev/null and b/timm/__pycache__/__init__.cpython-311.pyc differ diff --git a/timm/__pycache__/__init__.cpython-313.pyc b/timm/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d33799447f4a9ffb7ad24140203105612360e60 Binary files /dev/null and b/timm/__pycache__/__init__.cpython-313.pyc differ diff --git a/timm/__pycache__/__init__.cpython-38.pyc b/timm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7023d11c41022fcc7fb381ad9f06ed5696c57eb5 Binary files /dev/null and b/timm/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/__pycache__/version.cpython-310.pyc b/timm/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..813f545096b90cb9ca7d0780f96d6b6c356e779d Binary files /dev/null and b/timm/__pycache__/version.cpython-310.pyc differ diff --git a/timm/__pycache__/version.cpython-311.pyc b/timm/__pycache__/version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d0a1be7f36f4fc6bc2cb3a3459ae564be4c4680 Binary files /dev/null and b/timm/__pycache__/version.cpython-311.pyc differ diff --git a/timm/__pycache__/version.cpython-313.pyc b/timm/__pycache__/version.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63adef5f633eb0f5c596578ab7b2970632cd284a Binary files /dev/null and b/timm/__pycache__/version.cpython-313.pyc differ diff --git a/timm/__pycache__/version.cpython-38.pyc b/timm/__pycache__/version.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..853f15d7ee641ce47081b21c7cfe1e1f81773729 Binary files /dev/null and b/timm/__pycache__/version.cpython-38.pyc differ diff --git a/timm/data/__init__.py b/timm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d3cb2b4d7e823aabb1d55781149579eeb94b024 --- /dev/null +++ b/timm/data/__init__.py @@ -0,0 +1,12 @@ +from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ + rand_augment_transform, auto_augment_transform +from .config import resolve_data_config +from .constants import * +from .dataset import ImageDataset, IterableImageDataset, AugMixDataset +from .dataset_factory import create_dataset +from .loader import create_loader +from .mixup import Mixup, FastCollateMixup +from .parsers import create_parser +from .real_labels import RealLabelsImagenet +from .transforms import * +from .transforms_factory import create_transform \ No newline at end of file diff --git a/timm/data/__pycache__/__init__.cpython-310.pyc b/timm/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f57ce0e03c8c902a0052dde4fb1499347120be44 Binary files /dev/null and b/timm/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/timm/data/__pycache__/__init__.cpython-311.pyc b/timm/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6856aceb0075363ed4f4473b6e230e5e415e49ac Binary files /dev/null and b/timm/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/timm/data/__pycache__/__init__.cpython-313.pyc b/timm/data/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46267dcdca39f18497a44ffb0040e8ae64ce0c4f Binary files /dev/null and b/timm/data/__pycache__/__init__.cpython-313.pyc differ diff --git a/timm/data/__pycache__/__init__.cpython-38.pyc b/timm/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0202237a12deced6a7acc75b6576f038edc5756a Binary files /dev/null and b/timm/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/data/__pycache__/auto_augment.cpython-310.pyc b/timm/data/__pycache__/auto_augment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f08652dbe654ebfbed04d2c6bbd9b926706c9da7 Binary files /dev/null and b/timm/data/__pycache__/auto_augment.cpython-310.pyc differ diff --git a/timm/data/__pycache__/auto_augment.cpython-311.pyc b/timm/data/__pycache__/auto_augment.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb27d9d9e6b38344e6002e21061a49b83eace58e Binary files /dev/null and b/timm/data/__pycache__/auto_augment.cpython-311.pyc differ diff --git a/timm/data/__pycache__/auto_augment.cpython-313.pyc b/timm/data/__pycache__/auto_augment.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55f48ba13643a08e46faba64331c427e87d8c09d Binary files /dev/null and b/timm/data/__pycache__/auto_augment.cpython-313.pyc differ diff --git a/timm/data/__pycache__/auto_augment.cpython-38.pyc b/timm/data/__pycache__/auto_augment.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ce32b9a540bd2913174a46f79f879e64b361aef Binary files /dev/null and b/timm/data/__pycache__/auto_augment.cpython-38.pyc differ diff --git a/timm/data/__pycache__/config.cpython-310.pyc b/timm/data/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c47488b22271a133a18c68dd6df04caa22796feb Binary files /dev/null and b/timm/data/__pycache__/config.cpython-310.pyc differ diff --git a/timm/data/__pycache__/config.cpython-311.pyc b/timm/data/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fc928eed60f13f0bb22568baedeabeeeed838d4 Binary files /dev/null and b/timm/data/__pycache__/config.cpython-311.pyc differ diff --git a/timm/data/__pycache__/config.cpython-313.pyc b/timm/data/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5891f63b2d701d544530ce220fbab91a9b35a3f8 Binary files /dev/null and b/timm/data/__pycache__/config.cpython-313.pyc differ diff --git a/timm/data/__pycache__/config.cpython-38.pyc b/timm/data/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da6ba81b22dae11f00e02ff60e003300b728988a Binary files /dev/null and b/timm/data/__pycache__/config.cpython-38.pyc differ diff --git a/timm/data/__pycache__/constants.cpython-310.pyc b/timm/data/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1748baf0f7521179d960a0232234d7740bc6ac4 Binary files /dev/null and b/timm/data/__pycache__/constants.cpython-310.pyc differ diff --git a/timm/data/__pycache__/constants.cpython-311.pyc b/timm/data/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00afeb758f163ebd37527182687d5c20c68905b9 Binary files /dev/null and b/timm/data/__pycache__/constants.cpython-311.pyc differ diff --git a/timm/data/__pycache__/constants.cpython-313.pyc b/timm/data/__pycache__/constants.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9614d867481d315a2771d257fc15d5074cba4da8 Binary files /dev/null and b/timm/data/__pycache__/constants.cpython-313.pyc differ diff --git a/timm/data/__pycache__/constants.cpython-38.pyc b/timm/data/__pycache__/constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e731a1e2e0b7ea2abbff904a8901e8106b71aa9c Binary files /dev/null and b/timm/data/__pycache__/constants.cpython-38.pyc differ diff --git a/timm/data/__pycache__/dataset.cpython-310.pyc b/timm/data/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95973ea6ac0796ddecbed9836d658181d09b0cec Binary files /dev/null and b/timm/data/__pycache__/dataset.cpython-310.pyc differ diff --git a/timm/data/__pycache__/dataset.cpython-311.pyc b/timm/data/__pycache__/dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92f9f8323fa273611e596e226d2520915899d5eb Binary files /dev/null and b/timm/data/__pycache__/dataset.cpython-311.pyc differ diff --git a/timm/data/__pycache__/dataset.cpython-313.pyc b/timm/data/__pycache__/dataset.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aca20c1850d0419817abfa8a082c1d5dc090b5d6 Binary files /dev/null and b/timm/data/__pycache__/dataset.cpython-313.pyc differ diff --git a/timm/data/__pycache__/dataset.cpython-38.pyc b/timm/data/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d04841ba5bdf487b3602ceea84ed791c85b0088 Binary files /dev/null and b/timm/data/__pycache__/dataset.cpython-38.pyc differ diff --git a/timm/data/__pycache__/dataset_factory.cpython-310.pyc b/timm/data/__pycache__/dataset_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..933a00da1d70b111924880b9ac6407be7aa8b7c0 Binary files /dev/null and b/timm/data/__pycache__/dataset_factory.cpython-310.pyc differ diff --git a/timm/data/__pycache__/dataset_factory.cpython-311.pyc b/timm/data/__pycache__/dataset_factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c0d1d6f54812fad18b20acd02725251783a4a0d Binary files /dev/null and b/timm/data/__pycache__/dataset_factory.cpython-311.pyc differ diff --git a/timm/data/__pycache__/dataset_factory.cpython-313.pyc b/timm/data/__pycache__/dataset_factory.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..654a0ea72303a3fe04660fccdb4e1e847268858b Binary files /dev/null and b/timm/data/__pycache__/dataset_factory.cpython-313.pyc differ diff --git a/timm/data/__pycache__/dataset_factory.cpython-38.pyc b/timm/data/__pycache__/dataset_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d570ca98602a60eb19c890ff9eef2682ecbb7d6 Binary files /dev/null and b/timm/data/__pycache__/dataset_factory.cpython-38.pyc differ diff --git a/timm/data/__pycache__/distributed_sampler.cpython-310.pyc b/timm/data/__pycache__/distributed_sampler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14c01aa4699129d6213b4ccd78110f09ca65e86e Binary files /dev/null and b/timm/data/__pycache__/distributed_sampler.cpython-310.pyc differ diff --git a/timm/data/__pycache__/distributed_sampler.cpython-311.pyc b/timm/data/__pycache__/distributed_sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b5edf80960b96ee28a8562d7dc29281d81e3da4 Binary files /dev/null and b/timm/data/__pycache__/distributed_sampler.cpython-311.pyc differ diff --git a/timm/data/__pycache__/distributed_sampler.cpython-313.pyc b/timm/data/__pycache__/distributed_sampler.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2144f927f79640411be77f22bb8350cea38f4d1 Binary files /dev/null and b/timm/data/__pycache__/distributed_sampler.cpython-313.pyc differ diff --git a/timm/data/__pycache__/distributed_sampler.cpython-38.pyc b/timm/data/__pycache__/distributed_sampler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9292be99c6921e1572225179c7cc6827d8fdaf36 Binary files /dev/null and b/timm/data/__pycache__/distributed_sampler.cpython-38.pyc differ diff --git a/timm/data/__pycache__/loader.cpython-310.pyc b/timm/data/__pycache__/loader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82c5a44f4a114109d0bfaf5620c31bbd03e0cffc Binary files /dev/null and b/timm/data/__pycache__/loader.cpython-310.pyc differ diff --git a/timm/data/__pycache__/loader.cpython-311.pyc b/timm/data/__pycache__/loader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05c6bc53f2cb5ca492c84730bfc5dde54d279370 Binary files /dev/null and b/timm/data/__pycache__/loader.cpython-311.pyc differ diff --git a/timm/data/__pycache__/loader.cpython-313.pyc b/timm/data/__pycache__/loader.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd05fa6f642875bb4c7fba83bb94071d1ac2f28c Binary files /dev/null and b/timm/data/__pycache__/loader.cpython-313.pyc differ diff --git a/timm/data/__pycache__/loader.cpython-38.pyc b/timm/data/__pycache__/loader.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb45374c0bade6175fb3cb09315803f8f870ca9d Binary files /dev/null and b/timm/data/__pycache__/loader.cpython-38.pyc differ diff --git a/timm/data/__pycache__/mixup.cpython-310.pyc b/timm/data/__pycache__/mixup.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70322cd601d0343966d8e5ed91196b6258afce63 Binary files /dev/null and b/timm/data/__pycache__/mixup.cpython-310.pyc differ diff --git a/timm/data/__pycache__/mixup.cpython-311.pyc b/timm/data/__pycache__/mixup.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8781d9d4a15823f8c8be303d0a7efde175f4bb01 Binary files /dev/null and b/timm/data/__pycache__/mixup.cpython-311.pyc differ diff --git a/timm/data/__pycache__/mixup.cpython-313.pyc b/timm/data/__pycache__/mixup.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001a5b2b4962a2f8945ebefe3436d867d0f30857 Binary files /dev/null and b/timm/data/__pycache__/mixup.cpython-313.pyc differ diff --git a/timm/data/__pycache__/mixup.cpython-38.pyc b/timm/data/__pycache__/mixup.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ec494fb6ab52d0865fab81920559e42101465be Binary files /dev/null and b/timm/data/__pycache__/mixup.cpython-38.pyc differ diff --git a/timm/data/__pycache__/random_erasing.cpython-310.pyc b/timm/data/__pycache__/random_erasing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44e76c19c743587a32c6a5cfe5b0c947f37f4727 Binary files /dev/null and b/timm/data/__pycache__/random_erasing.cpython-310.pyc differ diff --git a/timm/data/__pycache__/random_erasing.cpython-311.pyc b/timm/data/__pycache__/random_erasing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..432d8008fc3e2695306c0d162d9dabdcc2cc89e4 Binary files /dev/null and b/timm/data/__pycache__/random_erasing.cpython-311.pyc differ diff --git a/timm/data/__pycache__/random_erasing.cpython-313.pyc b/timm/data/__pycache__/random_erasing.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d28a12b5bf8d6f28f54347898c260869d89fd1f Binary files /dev/null and b/timm/data/__pycache__/random_erasing.cpython-313.pyc differ diff --git a/timm/data/__pycache__/random_erasing.cpython-38.pyc b/timm/data/__pycache__/random_erasing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08f000882f67a599f194db651d59d60a4a200461 Binary files /dev/null and b/timm/data/__pycache__/random_erasing.cpython-38.pyc differ diff --git a/timm/data/__pycache__/real_labels.cpython-310.pyc b/timm/data/__pycache__/real_labels.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..608edb9c7e2d2789ec57f4007b2ef2d65b7fdba5 Binary files /dev/null and b/timm/data/__pycache__/real_labels.cpython-310.pyc differ diff --git a/timm/data/__pycache__/real_labels.cpython-311.pyc b/timm/data/__pycache__/real_labels.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bc14c7530cec669c9a7996731a1427934115bea Binary files /dev/null and b/timm/data/__pycache__/real_labels.cpython-311.pyc differ diff --git a/timm/data/__pycache__/real_labels.cpython-313.pyc b/timm/data/__pycache__/real_labels.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28a11cf8efb42a27639fb9d1eb1b5cd50f9f11b5 Binary files /dev/null and b/timm/data/__pycache__/real_labels.cpython-313.pyc differ diff --git a/timm/data/__pycache__/real_labels.cpython-38.pyc b/timm/data/__pycache__/real_labels.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04df396104a0d76909140c9d1e19763c151c3779 Binary files /dev/null and b/timm/data/__pycache__/real_labels.cpython-38.pyc differ diff --git a/timm/data/__pycache__/tf_preprocessing.cpython-38.pyc b/timm/data/__pycache__/tf_preprocessing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..854fcba64e87c60ff0c8c38ba10ae75da0771490 Binary files /dev/null and b/timm/data/__pycache__/tf_preprocessing.cpython-38.pyc differ diff --git a/timm/data/__pycache__/transforms.cpython-310.pyc b/timm/data/__pycache__/transforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b25133c0f18d6e789d1047473502d58585200443 Binary files /dev/null and b/timm/data/__pycache__/transforms.cpython-310.pyc differ diff --git a/timm/data/__pycache__/transforms.cpython-311.pyc b/timm/data/__pycache__/transforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5478091605507847283f586d0d61190218da0ae0 Binary files /dev/null and b/timm/data/__pycache__/transforms.cpython-311.pyc differ diff --git a/timm/data/__pycache__/transforms.cpython-313.pyc b/timm/data/__pycache__/transforms.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d46a8a0ca9c42ccc57ccd3fe1f86d6b83b67c6f7 Binary files /dev/null and b/timm/data/__pycache__/transforms.cpython-313.pyc differ diff --git a/timm/data/__pycache__/transforms.cpython-38.pyc b/timm/data/__pycache__/transforms.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3737fcaba8e3146e122d5318d671c48eab3194e9 Binary files /dev/null and b/timm/data/__pycache__/transforms.cpython-38.pyc differ diff --git a/timm/data/__pycache__/transforms_factory.cpython-310.pyc b/timm/data/__pycache__/transforms_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f45162490bc43bc8e0c9097ff6e7e2ee78cdc154 Binary files /dev/null and b/timm/data/__pycache__/transforms_factory.cpython-310.pyc differ diff --git a/timm/data/__pycache__/transforms_factory.cpython-311.pyc b/timm/data/__pycache__/transforms_factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf4dadd24a224654d018ee7a37ebfded4077d4f8 Binary files /dev/null and b/timm/data/__pycache__/transforms_factory.cpython-311.pyc differ diff --git a/timm/data/__pycache__/transforms_factory.cpython-313.pyc b/timm/data/__pycache__/transforms_factory.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0469ded69d0c9dbfd2f251e905159e6ccfedfdc6 Binary files /dev/null and b/timm/data/__pycache__/transforms_factory.cpython-313.pyc differ diff --git a/timm/data/__pycache__/transforms_factory.cpython-38.pyc b/timm/data/__pycache__/transforms_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe3c847fb2efd91507e68f8f095acd8135c67868 Binary files /dev/null and b/timm/data/__pycache__/transforms_factory.cpython-38.pyc differ diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..7cbd2dee0a2b2594bf2d229d5c440a268c2e752d --- /dev/null +++ b/timm/data/auto_augment.py @@ -0,0 +1,822 @@ +""" AutoAugment, RandAugment, and AugMix for PyTorch + +This code implements the searched ImageNet policies with various tweaks and improvements and +does not include any of the search code. + +AA and RA Implementation adapted from: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py + +AugMix adapted from: + https://github.com/google-research/augmix + +Papers: + AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 + Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import random +import math +import re +from PIL import Image, ImageOps, ImageEnhance, ImageChops +import PIL +import numpy as np + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. + +_HPARAMS_DEFAULT = dict( + translate_const=250, + img_mean=_FILL, +) + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return level, + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return (level / _MAX_LEVEL) * 1.8 + 0.1, + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * .9 + level = 1.0 + _randomly_negate(level) + return level, + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return level, + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams['translate_const'] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return level, + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get('translate_pct', 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return level, + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4), + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return 4 - _posterize_level_to_arg(level, hparams)[0], + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 4) + 4, + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return int((level / _MAX_LEVEL) * 256), + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return 256 - _solarize_level_to_arg(level, _hparams)[0], + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return int((level / _MAX_LEVEL) * 110), + + +LEVEL_TO_ARG = { + 'AutoContrast': None, + 'Equalize': None, + 'Invert': None, + 'Rotate': _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + 'Posterize': _posterize_level_to_arg, + 'PosterizeIncreasing': _posterize_increasing_level_to_arg, + 'PosterizeOriginal': _posterize_original_level_to_arg, + 'Solarize': _solarize_level_to_arg, + 'SolarizeIncreasing': _solarize_increasing_level_to_arg, + 'SolarizeAdd': _solarize_add_level_to_arg, + 'Color': _enhance_level_to_arg, + 'ColorIncreasing': _enhance_increasing_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'ContrastIncreasing': _enhance_increasing_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'BrightnessIncreasing': _enhance_increasing_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'SharpnessIncreasing': _enhance_increasing_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': _translate_abs_level_to_arg, + 'TranslateY': _translate_abs_level_to_arg, + 'TranslateXRel': _translate_rel_level_to_arg, + 'TranslateYRel': _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + 'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Rotate': rotate, + 'Posterize': posterize, + 'PosterizeIncreasing': posterize, + 'PosterizeOriginal': posterize, + 'Solarize': solarize, + 'SolarizeIncreasing': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'ColorIncreasing': color, + 'Contrast': contrast, + 'ContrastIncreasing': contrast, + 'Brightness': brightness, + 'BrightnessIncreasing': brightness, + 'Sharpness': sharpness, + 'SharpnessIncreasing': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, +} + + +class AugmentOp: + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = dict( + fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, + resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, + ) + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + # If magnitude_std is inf, we sample magnitude from a uniform distribution + self.magnitude_std = self.hparams.get('magnitude_std', 0) + + def __call__(self, img): + if self.prob < 1.0 and random.random() > self.prob: + return img + magnitude = self.magnitude + if self.magnitude_std: + if self.magnitude_std == float('inf'): + magnitude = random.uniform(0, magnitude) + elif self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() + return self.aug_fn(img, *level_args, **self.kwargs) + + +def auto_augment_policy_v0(hparams): + # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference. + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_v0r(hparams): + # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used + # in Google research implementation (number of bits discarded increases with magnitude) + policy = [ + [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)], + [('Color', 0.4, 9), ('Equalize', 0.6, 3)], + [('Color', 0.4, 1), ('Rotate', 0.6, 8)], + [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)], + [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)], + [('Color', 0.2, 0), ('Equalize', 0.8, 8)], + [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)], + [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)], + [('Color', 0.6, 1), ('Equalize', 1.0, 2)], + [('Invert', 0.4, 9), ('Rotate', 0.6, 0)], + [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)], + [('Color', 0.4, 7), ('Equalize', 0.6, 0)], + [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)], + [('Solarize', 0.6, 8), ('Color', 0.6, 9)], + [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)], + [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)], + [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)], + [('ShearY', 0.8, 0), ('Color', 0.6, 4)], + [('Color', 1.0, 0), ('Rotate', 0.6, 2)], + [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)], + [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)], + [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)], + [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)], + [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)], + [('Color', 0.8, 6), ('Rotate', 0.4, 5)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_original(hparams): + # ImageNet policy from https://arxiv.org/abs/1805.09501 + policy = [ + [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], + [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)], + [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], + [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)], + [('Rotate', 0.8, 8), ('Color', 0.4, 0)], + [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], + [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Rotate', 0.8, 8), ('Color', 1.0, 2)], + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], + [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy_originalr(hparams): + # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation + policy = [ + [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)], + [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)], + [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)], + [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)], + [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)], + [('Rotate', 0.8, 8), ('Color', 0.4, 0)], + [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)], + [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Rotate', 0.8, 8), ('Color', 1.0, 2)], + [('Color', 0.8, 8), ('Solarize', 0.8, 7)], + [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)], + [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)], + [('Color', 0.4, 0), ('Equalize', 0.6, 3)], + [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)], + [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)], + [('Invert', 0.6, 4), ('Equalize', 1.0, 8)], + [('Color', 0.6, 4), ('Contrast', 1.0, 8)], + [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)], + ] + pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy] + return pc + + +def auto_augment_policy(name='v0', hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + if name == 'original': + return auto_augment_policy_original(hparams) + elif name == 'originalr': + return auto_augment_policy_originalr(hparams) + elif name == 'v0': + return auto_augment_policy_v0(hparams) + elif name == 'v0r': + return auto_augment_policy_v0r(hparams) + else: + assert False, 'Unknown AA policy (%s)' % name + + +class AutoAugment: + + def __init__(self, policy): + self.policy = policy + + def __call__(self, img): + sub_policy = random.choice(self.policy) + for op in sub_policy: + img = op(img) + return img + + +def auto_augment_transform(config_str, hparams): + """ + Create a AutoAugment transform + + :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr'). + The remaining sections, not order sepecific determine + 'mstd' - float std deviation of magnitude noise applied + Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5 + + :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme + + :return: A PyTorch compatible Transform + """ + config = config_str.split('-') + policy_name = config[0] + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + else: + assert False, 'Unknown AutoAugment config section' + aa_policy = auto_augment_policy(policy_name, hparams=hparams) + return AutoAugment(aa_policy) + + +_RAND_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'Posterize', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # NOTE I've implement this as random erasing separately +] + + +_RAND_INCREASING_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'SolarizeAdd', + 'ColorIncreasing', + 'ContrastIncreasing', + 'BrightnessIncreasing', + 'SharpnessIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # NOTE I've implement this as random erasing separately +] + + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. +_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'Posterize': 0, + 'Invert': 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [AugmentOp( + name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + Create a RandAugment transform + + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split('-') + assert config[0] == 'rand' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'inc': + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == 'm': + magnitude = int(val) + elif key == 'n': + num_layers = int(val) + elif key == 'w': + weight_idx = int(val) + else: + assert False, 'Unknown RandAugment config section' + ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms) + choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) + + +_AUGMIX_TRANSFORMS = [ + 'AutoContrast', + 'ColorIncreasing', # not in paper + 'ContrastIncreasing', # not in paper + 'BrightnessIncreasing', # not in paper + 'SharpnessIncreasing', # not in paper + 'Equalize', + 'Rotate', + 'PosterizeIncreasing', + 'SolarizeIncreasing', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', +] + + +def augmix_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _AUGMIX_TRANSFORMS + return [AugmentOp( + name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class AugMixAugment: + """ AugMix Transform + Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + """ + def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False): + self.ops = ops + self.alpha = alpha + self.width = width + self.depth = depth + self.blended = blended # blended mode is faster but not well tested + + def _calc_blended_weights(self, ws, m): + ws = ws * m + cump = 1. + rws = [] + for w in ws[::-1]: + alpha = w / cump + cump *= (1 - alpha) + rws.append(alpha) + return np.array(rws[::-1], dtype=np.float32) + + def _apply_blended(self, img, mixing_weights, m): + # This is my first crack and implementing a slightly faster mixed augmentation. Instead + # of accumulating the mix for each chain in a Numpy array and then blending with original, + # it recomputes the blending coefficients and applies one PIL image blend per chain. + # TODO the results appear in the right ballpark but they differ by more than rounding. + img_orig = img.copy() + ws = self._calc_blended_weights(mixing_weights, m) + for w in ws: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img_orig # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + img = Image.blend(img, img_aug, w) + return img + + def _apply_basic(self, img, mixing_weights, m): + # This is a literal adaptation of the paper/official implementation without normalizations and + # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the + # typical augmentation transforms, could use a GPU / Kornia implementation. + img_shape = img.size[0], img.size[1], len(img.getbands()) + mixed = np.zeros(img_shape, dtype=np.float32) + for mw in mixing_weights: + depth = self.depth if self.depth > 0 else np.random.randint(1, 4) + ops = np.random.choice(self.ops, depth, replace=True) + img_aug = img # no ops are in-place, deep copy not necessary + for op in ops: + img_aug = op(img_aug) + mixed += mw * np.asarray(img_aug, dtype=np.float32) + np.clip(mixed, 0, 255., out=mixed) + mixed = Image.fromarray(mixed.astype(np.uint8)) + return Image.blend(img, mixed, m) + + def __call__(self, img): + mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width)) + m = np.float32(np.random.beta(self.alpha, self.alpha)) + if self.blended: + mixed = self._apply_blended(img, mixing_weights, m) + else: + mixed = self._apply_basic(img, mixing_weights, m) + return mixed + + +def augment_and_mix_transform(config_str, hparams): + """ Create AugMix PyTorch transform + + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude (severity) of augmentation mix (default: 3) + 'w' - integer width of augmentation chain (default: 3) + 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1) + 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0) + 'mstd' - float std deviation of magnitude noise applied (default: 0) + Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2 + + :param hparams: Other hparams (kwargs) for the Augmentation transforms + + :return: A PyTorch compatible Transform + """ + magnitude = 3 + width = 3 + depth = -1 + alpha = 1. + blended = False + hparams['magnitude_std'] = float('inf') + config = config_str.split('-') + assert config[0] == 'augmix' + config = config[1:] + for c in config: + cs = re.split(r'(\d.*)', c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == 'mstd': + # noise param injected via hparams for now + hparams.setdefault('magnitude_std', float(val)) + elif key == 'm': + magnitude = int(val) + elif key == 'w': + width = int(val) + elif key == 'd': + depth = int(val) + elif key == 'a': + alpha = float(val) + elif key == 'b': + blended = bool(val) + else: + assert False, 'Unknown AugMix config section' + ops = augmix_ops(magnitude=magnitude, hparams=hparams) + return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended) diff --git a/timm/data/config.py b/timm/data/config.py new file mode 100644 index 0000000000000000000000000000000000000000..38f5689a707f5602e38cb717ed6115f26a0d7ea2 --- /dev/null +++ b/timm/data/config.py @@ -0,0 +1,78 @@ +import logging +from .constants import * + + +_logger = logging.getLogger(__name__) + + +def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False): + new_config = {} + default_cfg = default_cfg + if not default_cfg and model is not None and hasattr(model, 'default_cfg'): + default_cfg = model.default_cfg + + # Resolve input/image size + in_chans = 3 + if 'chans' in args and args['chans'] is not None: + in_chans = args['chans'] + + input_size = (in_chans, 224, 224) + if 'input_size' in args and args['input_size'] is not None: + assert isinstance(args['input_size'], (tuple, list)) + assert len(args['input_size']) == 3 + input_size = tuple(args['input_size']) + in_chans = input_size[0] # input_size overrides in_chans + elif 'img_size' in args and args['img_size'] is not None: + assert isinstance(args['img_size'], int) + input_size = (in_chans, args['img_size'], args['img_size']) + else: + if use_test_size and 'test_input_size' in default_cfg: + input_size = default_cfg['test_input_size'] + elif 'input_size' in default_cfg: + input_size = default_cfg['input_size'] + new_config['input_size'] = input_size + + # resolve interpolation method + new_config['interpolation'] = 'bicubic' + if 'interpolation' in args and args['interpolation']: + new_config['interpolation'] = args['interpolation'] + elif 'interpolation' in default_cfg: + new_config['interpolation'] = default_cfg['interpolation'] + + # resolve dataset + model mean for normalization + new_config['mean'] = IMAGENET_DEFAULT_MEAN + if 'mean' in args and args['mean'] is not None: + mean = tuple(args['mean']) + if len(mean) == 1: + mean = tuple(list(mean) * in_chans) + else: + assert len(mean) == in_chans + new_config['mean'] = mean + elif 'mean' in default_cfg: + new_config['mean'] = default_cfg['mean'] + + # resolve dataset + model std deviation for normalization + new_config['std'] = IMAGENET_DEFAULT_STD + if 'std' in args and args['std'] is not None: + std = tuple(args['std']) + if len(std) == 1: + std = tuple(list(std) * in_chans) + else: + assert len(std) == in_chans + new_config['std'] = std + elif 'std' in default_cfg: + new_config['std'] = default_cfg['std'] + + # resolve default crop percentage + new_config['crop_pct'] = DEFAULT_CROP_PCT + if 'crop_pct' in args and args['crop_pct'] is not None: + new_config['crop_pct'] = args['crop_pct'] + elif 'crop_pct' in default_cfg: + new_config['crop_pct'] = default_cfg['crop_pct'] + + if verbose: + _logger.info('Data processing configuration for current model + dataset:') + for n, v in new_config.items(): + _logger.info('\t%s: %s' % (n, str(v))) + + return new_config diff --git a/timm/data/constants.py b/timm/data/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d4a01b0316989a3f5142167f1e384b098bc930 --- /dev/null +++ b/timm/data/constants.py @@ -0,0 +1,7 @@ +DEFAULT_CROP_PCT = 0.875 +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) +IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) +IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) diff --git a/timm/data/dataset.py b/timm/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e719f3f6d7db178eb29fd902e85b64ac5ec09dd8 --- /dev/null +++ b/timm/data/dataset.py @@ -0,0 +1,146 @@ +""" Quick n Simple Image Folder, Tarfile based DataSet + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.utils.data as data +import os +import torch +import logging + +from PIL import Image + +from .parsers import create_parser + +_logger = logging.getLogger(__name__) + + +_ERROR_RETRY = 50 + + +class ImageDataset(data.Dataset): + + def __init__( + self, + root, + parser=None, + class_map='', + load_bytes=False, + transform=None, + ): + if parser is None or isinstance(parser, str): + parser = create_parser(parser or '', root=root, class_map=class_map) + self.parser = parser + self.load_bytes = load_bytes + self.transform = transform + self._consecutive_errors = 0 + + def __getitem__(self, index): + img, target = self.parser[index] + try: + img = img.read() if self.load_bytes else Image.open(img).convert('RGB') + except Exception as e: + _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}') + self._consecutive_errors += 1 + if self._consecutive_errors < _ERROR_RETRY: + return self.__getitem__((index + 1) % len(self.parser)) + else: + raise e + self._consecutive_errors = 0 + if self.transform is not None: + img = self.transform(img) + if target is None: + target = torch.tensor(-1, dtype=torch.long) + return img, target + + def __len__(self): + return len(self.parser) + + def filename(self, index, basename=False, absolute=False): + return self.parser.filename(index, basename, absolute) + + def filenames(self, basename=False, absolute=False): + return self.parser.filenames(basename, absolute) + + +class IterableImageDataset(data.IterableDataset): + + def __init__( + self, + root, + parser=None, + split='train', + is_training=False, + batch_size=None, + class_map='', + load_bytes=False, + repeats=0, + transform=None, + ): + assert parser is not None + if isinstance(parser, str): + self.parser = create_parser( + parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats) + else: + self.parser = parser + self.transform = transform + self._consecutive_errors = 0 + + def __iter__(self): + for img, target in self.parser: + if self.transform is not None: + img = self.transform(img) + if target is None: + target = torch.tensor(-1, dtype=torch.long) + yield img, target + + def __len__(self): + if hasattr(self.parser, '__len__'): + return len(self.parser) + else: + return 0 + + def filename(self, index, basename=False, absolute=False): + assert False, 'Filename lookup by index not supported, use filenames().' + + def filenames(self, basename=False, absolute=False): + return self.parser.filenames(basename, absolute) + + +class AugMixDataset(torch.utils.data.Dataset): + """Dataset wrapper to perform AugMix or other clean/augmentation mixes""" + + def __init__(self, dataset, num_splits=2): + self.augmentation = None + self.normalize = None + self.dataset = dataset + if self.dataset.transform is not None: + self._set_transforms(self.dataset.transform) + self.num_splits = num_splits + + def _set_transforms(self, x): + assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms' + self.dataset.transform = x[0] + self.augmentation = x[1] + self.normalize = x[2] + + @property + def transform(self): + return self.dataset.transform + + @transform.setter + def transform(self, x): + self._set_transforms(x) + + def _normalize(self, x): + return x if self.normalize is None else self.normalize(x) + + def __getitem__(self, i): + x, y = self.dataset[i] # all splits share the same dataset base transform + x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split) + # run the full augmentation on the remaining splits + for _ in range(self.num_splits - 1): + x_list.append(self._normalize(self.augmentation(x))) + return tuple(x_list), y + + def __len__(self): + return len(self.dataset) diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc99d5c2c19b480a30cad74dacccceff24df61e --- /dev/null +++ b/timm/data/dataset_factory.py @@ -0,0 +1,30 @@ +import os + +from .dataset import IterableImageDataset, ImageDataset + + +def _search_split(root, split): + # look for sub-folder with name of split in root and use that if it exists + split_name = split.split('[')[0] + try_root = os.path.join(root, split_name) + if os.path.exists(try_root): + return try_root + if split_name == 'validation': + try_root = os.path.join(root, 'val') + if os.path.exists(try_root): + return try_root + return root + + +def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs): + name = name.lower() + if name.startswith('tfds'): + ds = IterableImageDataset( + root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs) + else: + # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future + kwargs.pop('repeats', 0) # FIXME currently only Iterable dataset support the repeat multiplier + if search_split and os.path.isdir(root): + root = _search_split(root, split) + ds = ImageDataset(root, parser=name, **kwargs) + return ds diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..9506a8805dc3cec25498cd32d7c7476b1b372f8a --- /dev/null +++ b/timm/data/distributed_sampler.py @@ -0,0 +1,51 @@ +import math +import torch +from torch.utils.data import Sampler +import torch.distributed as dist + + +class OrderedDistributedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[:(self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples diff --git a/timm/data/loader.py b/timm/data/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..76144669090aca1e962d75bfeab66aaf923e7ec5 --- /dev/null +++ b/timm/data/loader.py @@ -0,0 +1,262 @@ +""" Loader Factory, Fast Collate, CUDA Prefetcher + +Prefetcher and Fast Collate inspired by NVIDIA APEX example at +https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch.utils.data +import numpy as np + +from .transforms_factory import create_transform +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .distributed_sampler import OrderedDistributedSampler +from .random_erasing import RandomErasing +from .mixup import FastCollateMixup + + +def fast_collate(batch): + """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" + assert isinstance(batch[0], tuple) + batch_size = len(batch) + if isinstance(batch[0][0], tuple): + # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position + # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position + inner_tuple_size = len(batch[0][0]) + flattened_batch_size = batch_size * inner_tuple_size + targets = torch.zeros(flattened_batch_size, dtype=torch.int64) + tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length + for j in range(inner_tuple_size): + targets[i + j * batch_size] = batch[i][1] + tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) + return tensor, targets + elif isinstance(batch[0][0], np.ndarray): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i] += torch.from_numpy(batch[i][0]) + return tensor, targets + elif isinstance(batch[0][0], torch.Tensor): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i].copy_(batch[i][0]) + return tensor, targets + else: + assert False + + +class PrefetchLoader: + + def __init__(self, + loader, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + fp16=False, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0): + self.loader = loader + self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) + self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) + self.fp16 = fp16 + if fp16: + self.mean = self.mean.half() + self.std = self.std.half() + if re_prob > 0.: + self.random_erasing = RandomErasing( + probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) + else: + self.random_erasing = None + + def __iter__(self): + stream = torch.cuda.Stream() + first = True + + for next_input, next_target in self.loader: + with torch.cuda.stream(stream): + next_input = next_input.cuda(non_blocking=True) + next_target = next_target.cuda(non_blocking=True) + if self.fp16: + next_input = next_input.half().sub_(self.mean).div_(self.std) + else: + next_input = next_input.float().sub_(self.mean).div_(self.std) + if self.random_erasing is not None: + next_input = self.random_erasing(next_input) + + if not first: + yield input, target + else: + first = False + + torch.cuda.current_stream().wait_stream(stream) + input = next_input + target = next_target + + yield input, target + + def __len__(self): + return len(self.loader) + + @property + def sampler(self): + return self.loader.sampler + + @property + def dataset(self): + return self.loader.dataset + + @property + def mixup_enabled(self): + if isinstance(self.loader.collate_fn, FastCollateMixup): + return self.loader.collate_fn.mixup_enabled + else: + return False + + @mixup_enabled.setter + def mixup_enabled(self, x): + if isinstance(self.loader.collate_fn, FastCollateMixup): + self.loader.collate_fn.mixup_enabled = x + + +def create_loader( + dataset, + input_size, + batch_size, + is_training=False, + use_prefetcher=True, + no_aug=False, + re_prob=0., + re_mode='const', + re_count=1, + re_split=False, + scale=None, + ratio=None, + hflip=0.5, + vflip=0., + color_jitter=0.4, + auto_augment=None, + num_aug_splits=0, + interpolation='bilinear', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_workers=1, + distributed=False, + crop_pct=None, + collate_fn=None, + pin_memory=False, + fp16=False, + tf_preprocessing=False, + use_multi_epochs_loader=False, + persistent_workers=True, +): + re_num_splits = 0 + if re_split: + # apply RE to second half of batch if no aug split otherwise line up with aug split + re_num_splits = num_aug_splits or 2 + dataset.transform = create_transform( + input_size, + is_training=is_training, + use_prefetcher=use_prefetcher, + no_aug=no_aug, + scale=scale, + ratio=ratio, + hflip=hflip, + vflip=vflip, + color_jitter=color_jitter, + auto_augment=auto_augment, + interpolation=interpolation, + mean=mean, + std=std, + crop_pct=crop_pct, + tf_preprocessing=tf_preprocessing, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, + separate=num_aug_splits > 0, + ) + + sampler = None + if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): + if is_training: + sampler = torch.utils.data.distributed.DistributedSampler(dataset) + else: + # This will add extra duplicate entries to result in equal num + # of samples per-process, will slightly alter validation results + sampler = OrderedDistributedSampler(dataset) + + if collate_fn is None: + collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate + + loader_class = torch.utils.data.DataLoader + + if use_multi_epochs_loader: + loader_class = MultiEpochsDataLoader + + loader_args = dict( + batch_size=batch_size, + shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training, + num_workers=num_workers, + sampler=sampler, + collate_fn=collate_fn, + pin_memory=pin_memory, + drop_last=is_training, + persistent_workers=persistent_workers) + try: + loader = loader_class(dataset, **loader_args) + except TypeError as e: + loader_args.pop('persistent_workers') # only in Pytorch 1.7+ + loader = loader_class(dataset, **loader_args) + if use_prefetcher: + prefetch_re_prob = re_prob if is_training and not no_aug else 0. + loader = PrefetchLoader( + loader, + mean=mean, + std=std, + fp16=fp16, + re_prob=prefetch_re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits + ) + + return loader + + +class MultiEpochsDataLoader(torch.utils.data.DataLoader): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._DataLoader__initialized = False + self.batch_sampler = _RepeatSampler(self.batch_sampler) + self._DataLoader__initialized = True + self.iterator = super().__iter__() + + def __len__(self): + return len(self.batch_sampler.sampler) + + def __iter__(self): + for i in range(len(self)): + yield next(self.iterator) + + +class _RepeatSampler(object): + """ Sampler that repeats forever. + + Args: + sampler (Sampler) + """ + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) diff --git a/timm/data/mixup.py b/timm/data/mixup.py new file mode 100644 index 0000000000000000000000000000000000000000..38477548a070a1a338ed18ddc74cdaf5050f84be --- /dev/null +++ b/timm/data/mixup.py @@ -0,0 +1,316 @@ +""" Mixup and Cutmix + +Papers: +mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) + +CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) + +Code Reference: +CutMix: https://github.com/clovaai/CutMix-PyTorch + +Hacked together by / Copyright 2020 Ross Wightman +""" +import numpy as np +import torch + + +def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): + x = x.long().view(-1, 1) + return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) + + +def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): + off_value = smoothing / num_classes + on_value = 1. - smoothing + off_value + y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) + y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) + return y1 * lam + y2 * (1. - lam) + + +def rand_bbox(img_shape, lam, margin=0., count=None): + """ Standard CutMix bounding-box + Generates a random square bbox based on lambda value. This impl includes + support for enforcing a border margin as percent of bbox dimensions. + + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) + count (int): Number of bbox to generate + """ + ratio = np.sqrt(1 - lam) + img_h, img_w = img_shape[-2:] + cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) + margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) + cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) + cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) + yl = np.clip(cy - cut_h // 2, 0, img_h) + yh = np.clip(cy + cut_h // 2, 0, img_h) + xl = np.clip(cx - cut_w // 2, 0, img_w) + xh = np.clip(cx + cut_w // 2, 0, img_w) + return yl, yh, xl, xh + + +def rand_bbox_minmax(img_shape, minmax, count=None): + """ Min-Max CutMix bounding-box + Inspired by Darknet cutmix impl, generates a random rectangular bbox + based on min/max percent values applied to each dimension of the input image. + + Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. + + Args: + img_shape (tuple): Image shape as tuple + minmax (tuple or list): Min and max bbox ratios (as percent of image size) + count (int): Number of bbox to generate + """ + assert len(minmax) == 2 + img_h, img_w = img_shape[-2:] + cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) + cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) + yl = np.random.randint(0, img_h - cut_h, size=count) + xl = np.random.randint(0, img_w - cut_w, size=count) + yu = yl + cut_h + xu = xl + cut_w + return yl, yu, xl, xu + + +def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): + """ Generate bbox and apply lambda correction. + """ + if ratio_minmax is not None: + yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) + else: + yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) + if correct_lam or ratio_minmax is not None: + bbox_area = (yu - yl) * (xu - xl) + lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) + return (yl, yu, xl, xu), lam + + +class Mixup: + """ Mixup/Cutmix that applies different params to each element or whole batch + + Args: + mixup_alpha (float): mixup alpha value, mixup is active if > 0. + cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. + cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. + prob (float): probability of applying mixup or cutmix per batch or element + switch_prob (float): probability of switching to cutmix instead of mixup when both are active + mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) + correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders + label_smoothing (float): apply label smoothing to the mixed target tensor + num_classes (int): number of classes for target + """ + def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, + mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): + self.mixup_alpha = mixup_alpha + self.cutmix_alpha = cutmix_alpha + self.cutmix_minmax = cutmix_minmax + if self.cutmix_minmax is not None: + assert len(self.cutmix_minmax) == 2 + # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe + self.cutmix_alpha = 1.0 + self.mix_prob = prob + self.switch_prob = switch_prob + self.label_smoothing = label_smoothing + self.num_classes = num_classes + self.mode = mode + self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix + self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) + + def _params_per_elem(self, batch_size): + lam = np.ones(batch_size, dtype=np.float32) + use_cutmix = np.zeros(batch_size, dtype=np.bool) + if self.mixup_enabled: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand(batch_size) < self.switch_prob + lam_mix = np.where( + use_cutmix, + np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), + np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) + elif self.cutmix_alpha > 0.: + use_cutmix = np.ones(batch_size, dtype=np.bool) + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." + lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) + return lam, use_cutmix + + def _params_per_batch(self): + lam = 1. + use_cutmix = False + if self.mixup_enabled and np.random.rand() < self.mix_prob: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand() < self.switch_prob + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ + np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.cutmix_alpha > 0.: + use_cutmix = True + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." + lam = float(lam_mix) + return lam, use_cutmix + + def _mix_elem(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_pair(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size // 2): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + x[j] = x[j] * lam + x_orig[i] * (1 - lam) + lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_batch(self, x): + lam, use_cutmix = self._params_per_batch() + if lam == 1.: + return 1. + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] + else: + x_flipped = x.flip(0).mul_(1. - lam) + x.mul_(lam).add_(x_flipped) + return lam + + def __call__(self, x, target): + assert len(x) % 2 == 0, 'Batch size should be even when using this' + if self.mode == 'elem': + lam = self._mix_elem(x) + elif self.mode == 'pair': + lam = self._mix_pair(x) + else: + lam = self._mix_batch(x) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing) + return x, target + + +class FastCollateMixup(Mixup): + """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch + + A Mixup impl that's performed while collating the batches. + """ + + def _mix_elem_collate(self, output, batch, half=False): + batch_size = len(batch) + num_elem = batch_size // 2 if half else batch_size + assert len(output) == num_elem + lam_batch, use_cutmix = self._params_per_elem(num_elem) + for i in range(num_elem): + j = batch_size - i - 1 + lam = lam_batch[i] + mixed = batch[i][0] + if lam != 1.: + if use_cutmix[i]: + if not half: + mixed = mixed.copy() + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.rint(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + if half: + lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) + return torch.tensor(lam_batch).unsqueeze(1) + + def _mix_pair_collate(self, output, batch): + batch_size = len(batch) + lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + for i in range(batch_size // 2): + j = batch_size - i - 1 + lam = lam_batch[i] + mixed_i = batch[i][0] + mixed_j = batch[j][0] + assert 0 <= lam <= 1.0 + if lam < 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + patch_i = mixed_i[:, yl:yh, xl:xh].copy() + mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] + mixed_j[:, yl:yh, xl:xh] = patch_i + lam_batch[i] = lam + else: + mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) + mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) + mixed_i = mixed_temp + np.rint(mixed_j, out=mixed_j) + np.rint(mixed_i, out=mixed_i) + output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) + output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) + lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) + return torch.tensor(lam_batch).unsqueeze(1) + + def _mix_batch_collate(self, output, batch): + batch_size = len(batch) + lam, use_cutmix = self._params_per_batch() + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + for i in range(batch_size): + j = batch_size - i - 1 + mixed = batch[i][0] + if lam != 1.: + if use_cutmix: + mixed = mixed.copy() # don't want to modify the original while iterating + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] + else: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.rint(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + return lam + + def __call__(self, batch, _=None): + batch_size = len(batch) + assert batch_size % 2 == 0, 'Batch size should be even when using this' + half = 'half' in self.mode + if half: + batch_size //= 2 + output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + if self.mode == 'elem' or self.mode == 'half': + lam = self._mix_elem_collate(output, batch, half=half) + elif self.mode == 'pair': + lam = self._mix_pair_collate(output, batch) + else: + lam = self._mix_batch_collate(output, batch) + target = torch.tensor([b[1] for b in batch], dtype=torch.int64) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') + target = target[:batch_size] + return output, target + diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb44e3714eff75028e15214e0e65bf2afebd86c --- /dev/null +++ b/timm/data/parsers/__init__.py @@ -0,0 +1 @@ +from .parser_factory import create_parser diff --git a/timm/data/parsers/__pycache__/__init__.cpython-310.pyc b/timm/data/parsers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2016cb4c181c337c2f5f7b9556ea882c1958bff Binary files /dev/null and b/timm/data/parsers/__pycache__/__init__.cpython-310.pyc differ diff --git a/timm/data/parsers/__pycache__/__init__.cpython-311.pyc b/timm/data/parsers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..836d4d745d712bea8aff38631a59aa5ee887b62d Binary files /dev/null and b/timm/data/parsers/__pycache__/__init__.cpython-311.pyc differ diff --git a/timm/data/parsers/__pycache__/__init__.cpython-313.pyc b/timm/data/parsers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a2aa651001919b848891e1db1ef8d12631c722 Binary files /dev/null and b/timm/data/parsers/__pycache__/__init__.cpython-313.pyc differ diff --git a/timm/data/parsers/__pycache__/__init__.cpython-38.pyc b/timm/data/parsers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fb8a4d417a30fcc7550377fb0a860d881d06712 Binary files /dev/null and b/timm/data/parsers/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/data/parsers/__pycache__/class_map.cpython-310.pyc b/timm/data/parsers/__pycache__/class_map.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84b320896d882722f95ff1ba04e8f254c1cd2e96 Binary files /dev/null and b/timm/data/parsers/__pycache__/class_map.cpython-310.pyc differ diff --git a/timm/data/parsers/__pycache__/class_map.cpython-311.pyc b/timm/data/parsers/__pycache__/class_map.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f36c20d7dba8ce861ff4fb800c7d8198abefce30 Binary files /dev/null and b/timm/data/parsers/__pycache__/class_map.cpython-311.pyc differ diff --git a/timm/data/parsers/__pycache__/class_map.cpython-313.pyc b/timm/data/parsers/__pycache__/class_map.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9902f8b0c3dd9a7e79c730df660faac8315c8c2 Binary files /dev/null and b/timm/data/parsers/__pycache__/class_map.cpython-313.pyc differ diff --git a/timm/data/parsers/__pycache__/class_map.cpython-38.pyc b/timm/data/parsers/__pycache__/class_map.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6788b256a806ddb90821e30f28dbe3e1ad741291 Binary files /dev/null and b/timm/data/parsers/__pycache__/class_map.cpython-38.pyc differ diff --git a/timm/data/parsers/__pycache__/constants.cpython-310.pyc b/timm/data/parsers/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20e2498c9a0ffdc6fff5544d22426b0e273cba70 Binary files /dev/null and b/timm/data/parsers/__pycache__/constants.cpython-310.pyc differ diff --git a/timm/data/parsers/__pycache__/constants.cpython-311.pyc b/timm/data/parsers/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..246b54e22f7bb7f3f1145c594960764b944124ce Binary files /dev/null and b/timm/data/parsers/__pycache__/constants.cpython-311.pyc differ diff --git a/timm/data/parsers/__pycache__/constants.cpython-313.pyc b/timm/data/parsers/__pycache__/constants.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cff96eea9c6c67427e6a60703577b67e95e150d7 Binary files /dev/null and b/timm/data/parsers/__pycache__/constants.cpython-313.pyc differ diff --git a/timm/data/parsers/__pycache__/constants.cpython-38.pyc b/timm/data/parsers/__pycache__/constants.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd66e553be3c38bf2df0e9e251cdcc9bdf573da0 Binary files /dev/null and b/timm/data/parsers/__pycache__/constants.cpython-38.pyc differ diff --git a/timm/data/parsers/__pycache__/parser.cpython-310.pyc b/timm/data/parsers/__pycache__/parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6cb054906bc8874425aee5226ceb81c061d19de Binary files /dev/null and b/timm/data/parsers/__pycache__/parser.cpython-310.pyc differ diff --git a/timm/data/parsers/__pycache__/parser.cpython-311.pyc b/timm/data/parsers/__pycache__/parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3273724204f106afa1546a954260fde2f892fac6 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser.cpython-311.pyc differ diff --git a/timm/data/parsers/__pycache__/parser.cpython-313.pyc b/timm/data/parsers/__pycache__/parser.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fa08ab47645fa545808a9d238f1b8572d7ab30e Binary files /dev/null and b/timm/data/parsers/__pycache__/parser.cpython-313.pyc differ diff --git a/timm/data/parsers/__pycache__/parser.cpython-38.pyc b/timm/data/parsers/__pycache__/parser.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b221f49fcb60af2b13141ab2ccb4adcafdac2f3 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser.cpython-38.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc b/timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62174315a7d73a06575745fa3a56078dd359d568 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_factory.cpython-310.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_factory.cpython-311.pyc b/timm/data/parsers/__pycache__/parser_factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c92b788314c341bce753ad9ef72b89104becfb88 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_factory.cpython-311.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_factory.cpython-313.pyc b/timm/data/parsers/__pycache__/parser_factory.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63ae370c67c16addc39daccc004b2308458bc142 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_factory.cpython-313.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_factory.cpython-38.pyc b/timm/data/parsers/__pycache__/parser_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..386e4a6d7e21a59fc5c4f02928677d92b3cba010 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_factory.cpython-38.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc b/timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7640d62ca2c3d4e1e16a49f4652ca03e71040752 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_folder.cpython-310.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_folder.cpython-311.pyc b/timm/data/parsers/__pycache__/parser_image_folder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f593b57f8cce96e1334154912d85e61eba00d7f8 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_folder.cpython-311.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_folder.cpython-313.pyc b/timm/data/parsers/__pycache__/parser_image_folder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e07b701216b00e9f3ffa5c9285237e319380ae3 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_folder.cpython-313.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_folder.cpython-38.pyc b/timm/data/parsers/__pycache__/parser_image_folder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cb5e8d25966a10b2b103001d4382e4469f888e1 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_folder.cpython-38.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b47bdf145f274b5dd5efcf8820374233f2d1fc Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-310.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-311.pyc b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e2aae3b15124bdc07293a071417fa4346822d72 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-311.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-313.pyc b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05aa1e273d8f5b38117d651ef45a6a935f667c8d Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-313.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-38.pyc b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01dde3cbf20e6fde16c36c7e91bdc2fffe1bda92 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-38.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc b/timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a73afa0fd5dc6f87db42add026fb1fc0b80e7991 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_tar.cpython-310.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_tar.cpython-311.pyc b/timm/data/parsers/__pycache__/parser_image_tar.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e7ff55b1d10f50d8c1814ae12d51da318cd7390 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_tar.cpython-311.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_tar.cpython-313.pyc b/timm/data/parsers/__pycache__/parser_image_tar.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..992559cbcfd10284240412d3824d223468af1f65 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_tar.cpython-313.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_image_tar.cpython-38.pyc b/timm/data/parsers/__pycache__/parser_image_tar.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f14c9584a5b67345b42ccd8185ca997aeba3891 Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_tar.cpython-38.pyc differ diff --git a/timm/data/parsers/__pycache__/parser_tfds.cpython-38.pyc b/timm/data/parsers/__pycache__/parser_tfds.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aef63f2d77310ed2cb1654d8b6db04f5da8a184b Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_tfds.cpython-38.pyc differ diff --git a/timm/data/parsers/class_map.py b/timm/data/parsers/class_map.py new file mode 100644 index 0000000000000000000000000000000000000000..9ef4d1fab4cb126c7737e6888420af76abed19bf --- /dev/null +++ b/timm/data/parsers/class_map.py @@ -0,0 +1,16 @@ +import os + + +def load_class_map(filename, root=''): + class_map_path = filename + if not os.path.exists(class_map_path): + class_map_path = os.path.join(root, filename) + assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename + class_map_ext = os.path.splitext(filename)[-1].lower() + if class_map_ext == '.txt': + with open(class_map_path) as f: + class_to_idx = {v.strip(): k for k, v in enumerate(f)} + else: + assert False, 'Unsupported class map extension' + return class_to_idx + diff --git a/timm/data/parsers/constants.py b/timm/data/parsers/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ba484e729b7ac976b2cedaa43be1c3b308eeeb --- /dev/null +++ b/timm/data/parsers/constants.py @@ -0,0 +1 @@ +IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg') diff --git a/timm/data/parsers/parser.py b/timm/data/parsers/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..76ab6d18283644702424d0ff2af5832d6d6dd3b7 --- /dev/null +++ b/timm/data/parsers/parser.py @@ -0,0 +1,17 @@ +from abc import abstractmethod + + +class Parser: + def __init__(self): + pass + + @abstractmethod + def _filename(self, index, basename=False, absolute=False): + pass + + def filename(self, index, basename=False, absolute=False): + return self._filename(index, basename=basename, absolute=absolute) + + def filenames(self, basename=False, absolute=False): + return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))] + diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..419ffe899b476233dba84b6cb8d0851801da27a5 --- /dev/null +++ b/timm/data/parsers/parser_factory.py @@ -0,0 +1,29 @@ +import os + +from .parser_image_folder import ParserImageFolder +from .parser_image_tar import ParserImageTar +from .parser_image_in_tar import ParserImageInTar + + +def create_parser(name, root, split='train', **kwargs): + name = name.lower() + name = name.split('/', 2) + prefix = '' + if len(name) > 1: + prefix = name[0] + name = name[-1] + + # FIXME improve the selection right now just tfds prefix or fallback path, will need options to + # explicitly select other options shortly + if prefix == 'tfds': + from .parser_tfds import ParserTfds # defer tensorflow import + parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs) + else: + assert os.path.exists(root) + # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder + # FIXME support split here, in parser? + if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar': + parser = ParserImageInTar(root, **kwargs) + else: + parser = ParserImageFolder(root, **kwargs) + return parser diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..ed349009a4caed92e290e05637eca20e46cc275b --- /dev/null +++ b/timm/data/parsers/parser_image_folder.py @@ -0,0 +1,69 @@ +""" A dataset parser that reads images from folders + +Folders are scannerd recursively to find image files. Labels are based +on the folder hierarchy, just leaf folders by default. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os + +from timm.utils.misc import natural_key + +from .parser import Parser +from .class_map import load_class_map +from .constants import IMG_EXTENSIONS + + +def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True): + labels = [] + filenames = [] + for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True): + rel_path = os.path.relpath(root, folder) if (root != folder) else '' + label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_') + for f in files: + base, ext = os.path.splitext(f) + if ext.lower() in types: + filenames.append(os.path.join(root, f)) + labels.append(label) + if class_to_idx is None: + # building class index + unique_labels = set(labels) + sorted_labels = list(sorted(unique_labels, key=natural_key)) + class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx] + if sort: + images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0])) + return images_and_targets, class_to_idx + + +class ParserImageFolder(Parser): + + def __init__( + self, + root, + class_map=''): + super().__init__() + + self.root = root + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx) + if len(self.samples) == 0: + raise RuntimeError( + f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}') + + def __getitem__(self, index): + path, target = self.samples[index] + return open(path, 'rb'), target + + def __len__(self): + return len(self.samples) + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0] + if basename: + filename = os.path.basename(filename) + elif not absolute: + filename = os.path.relpath(filename, self.root) + return filename diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/parsers/parser_image_in_tar.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ada962ca96eaa7d770014ba130c6de5b36b6ec --- /dev/null +++ b/timm/data/parsers/parser_image_in_tar.py @@ -0,0 +1,222 @@ +""" A dataset parser that reads tarfile based datasets + +This parser can read and extract image samples from: +* a single tar of image files +* a folder of multiple tarfiles containing imagefiles +* a tar of tars containing image files + +Labels are based on the combined folder and/or tar name structure. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os +import tarfile +import pickle +import logging +import numpy as np +from glob import glob +from typing import List, Dict + +from timm.utils.misc import natural_key + +from .parser import Parser +from .class_map import load_class_map +from .constants import IMG_EXTENSIONS + + +_logger = logging.getLogger(__name__) +CACHE_FILENAME_SUFFIX = '_tarinfos.pickle' + + +class TarState: + + def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None): + self.tf: tarfile.TarFile = tf + self.ti: tarfile.TarInfo = ti + self.children: Dict[str, TarState] = {} # child states (tars within tars) + + def reset(self): + self.tf = None + + +def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS): + sample_count = 0 + for i, ti in enumerate(tf): + if not ti.isfile(): + continue + dirname, basename = os.path.split(ti.path) + name, ext = os.path.splitext(basename) + ext = ext.lower() + if ext == '.tar': + with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf: + child_info = dict( + name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[]) + sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions) + _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.') + parent_info['children'].append(child_info) + elif ext in extensions: + parent_info['samples'].append(ti) + sample_count += 1 + return sample_count + + +def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True): + root_is_tar = False + if os.path.isfile(root): + assert os.path.splitext(root)[-1].lower() == '.tar' + tar_filenames = [root] + root, root_name = os.path.split(root) + root_name = os.path.splitext(root_name)[0] + root_is_tar = True + else: + root_name = root.strip(os.path.sep).split(os.path.sep)[-1] + tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True) + num_tars = len(tar_filenames) + tar_bytes = sum([os.path.getsize(f) for f in tar_filenames]) + assert num_tars, f'No .tar files found at specified path ({root}).' + + _logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...') + info = dict(tartrees=[]) + cache_path = '' + if cache_tarinfo is None: + cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB + if cache_tarinfo: + cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX + cache_path = os.path.join(root, cache_filename) + if os.path.exists(cache_path): + _logger.info(f'Reading tar info from cache file {cache_path}.') + with open(cache_path, 'rb') as pf: + info = pickle.load(pf) + assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles" + else: + for i, fn in enumerate(tar_filenames): + path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0] + with tarfile.open(fn, mode='r|') as tf: # tarinfo scans done in streaming mode + parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[]) + num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions) + num_children = len(parent_info["children"]) + _logger.debug( + f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.') + info['tartrees'].append(parent_info) + if cache_path: + _logger.info(f'Writing tar info to cache file {cache_path}.') + with open(cache_path, 'wb') as pf: + pickle.dump(info, pf) + + samples = [] + labels = [] + build_class_map = False + if class_name_to_idx is None: + build_class_map = True + + # Flatten tartree info into lists of samples and targets w/ targets based on label id via + # class map arg or from unique paths. + # NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children + # this covers my current use cases and keeps things a little easier to test for now. + tarfiles = [] + + def _label_from_paths(*path, leaf_only=True): + path = os.path.join(*path).strip(os.path.sep) + return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_') + + def _add_samples(info, fn): + added = 0 + for s in info['samples']: + label = _label_from_paths(info['path'], os.path.dirname(s.path)) + if not build_class_map and label not in class_name_to_idx: + continue + samples.append((s, fn, info['ti'])) + labels.append(label) + added += 1 + return added + + _logger.info(f'Collecting samples and building tar states.') + for parent_info in info['tartrees']: + # if tartree has children, we assume all samples are at the child level + tar_name = None if root_is_tar else parent_info['name'] + tar_state = TarState() + parent_added = 0 + for child_info in parent_info['children']: + child_added = _add_samples(child_info, fn=tar_name) + if child_added: + tar_state.children[child_info['name']] = TarState(ti=child_info['ti']) + parent_added += child_added + parent_added += _add_samples(parent_info, fn=tar_name) + if parent_added: + tarfiles.append((tar_name, tar_state)) + del info + + if build_class_map: + # build class index + sorted_labels = list(sorted(set(labels), key=natural_key)) + class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + + _logger.info(f'Mapping targets and sorting samples.') + samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx] + if sort: + samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path)) + samples, targets = zip(*samples_and_targets) + samples = np.array(samples) + targets = np.array(targets) + _logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.') + return samples, targets, class_name_to_idx, tarfiles + + +class ParserImageInTar(Parser): + """ Multi-tarfile dataset parser where there is one .tar file per class + """ + + def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None): + super().__init__() + + class_name_to_idx = None + if class_map: + class_name_to_idx = load_class_map(class_map, root) + self.root = root + self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos( + self.root, + class_name_to_idx=class_name_to_idx, + cache_tarinfo=cache_tarinfo, + extensions=IMG_EXTENSIONS) + self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()} + if len(tarfiles) == 1 and tarfiles[0][0] is None: + self.root_is_tar = True + self.tar_state = tarfiles[0][1] + else: + self.root_is_tar = False + self.tar_state = dict(tarfiles) + self.cache_tarfiles = cache_tarfiles + + def __len__(self): + return len(self.samples) + + def __getitem__(self, index): + sample = self.samples[index] + target = self.targets[index] + sample_ti, parent_fn, child_ti = sample + parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root + + tf = None + cache_state = None + if self.cache_tarfiles: + cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn] + tf = cache_state.tf + if tf is None: + tf = tarfile.open(parent_abs) + if self.cache_tarfiles: + cache_state.tf = tf + if child_ti is not None: + ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None + if ctf is None: + ctf = tarfile.open(fileobj=tf.extractfile(child_ti)) + if self.cache_tarfiles: + cache_state.children[child_ti.name].tf = ctf + tf = ctf + + return tf.extractfile(sample_ti), target + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0].name + if basename: + filename = os.path.basename(filename) + return filename diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py new file mode 100644 index 0000000000000000000000000000000000000000..467537f479873bbc09fa2b576cdbde9d2a956e7b --- /dev/null +++ b/timm/data/parsers/parser_image_tar.py @@ -0,0 +1,72 @@ +""" A dataset parser that reads single tarfile based datasets + +This parser can read datasets consisting if a single tarfile containing images. +I am planning to deprecated it in favour of ParerImageInTar. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os +import tarfile + +from .parser import Parser +from .class_map import load_class_map +from .constants import IMG_EXTENSIONS +from timm.utils.misc import natural_key + + +def extract_tarinfo(tarfile, class_to_idx=None, sort=True): + files = [] + labels = [] + for ti in tarfile.getmembers(): + if not ti.isfile(): + continue + dirname, basename = os.path.split(ti.path) + label = os.path.basename(dirname) + ext = os.path.splitext(basename)[1] + if ext.lower() in IMG_EXTENSIONS: + files.append(ti) + labels.append(label) + if class_to_idx is None: + unique_labels = set(labels) + sorted_labels = list(sorted(unique_labels, key=natural_key)) + class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)} + tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx] + if sort: + tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path)) + return tarinfo_and_targets, class_to_idx + + +class ParserImageTar(Parser): + """ Single tarfile dataset where classes are mapped to folders within tar + NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can + operate on folders of tars or tars in tars. + """ + def __init__(self, root, class_map=''): + super().__init__() + + class_to_idx = None + if class_map: + class_to_idx = load_class_map(class_map, root) + assert os.path.isfile(root) + self.root = root + + with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later + self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx) + self.imgs = self.samples + self.tarfile = None # lazy init in __getitem__ + + def __getitem__(self, index): + if self.tarfile is None: + self.tarfile = tarfile.open(self.root) + tarinfo, target = self.samples[index] + fileobj = self.tarfile.extractfile(tarinfo) + return fileobj, target + + def __len__(self): + return len(self.samples) + + def _filename(self, index, basename=False, absolute=False): + filename = self.samples[index][0].name + if basename: + filename = os.path.basename(filename) + return filename diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff90b09f373a314ea24d0f84cf62c3dfcc02731 --- /dev/null +++ b/timm/data/parsers/parser_tfds.py @@ -0,0 +1,223 @@ +""" Dataset parser interface that wraps TFDS datasets + +Wraps many (most?) TFDS image-classification datasets +from https://github.com/tensorflow/datasets +https://www.tensorflow.org/datasets/catalog/overview#image_classification + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os +import io +import math +import torch +import torch.distributed as dist +from PIL import Image + +try: + import tensorflow as tf + tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) + import tensorflow_datasets as tfds +except ImportError as e: + print(e) + print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") + exit(1) +from .parser import Parser + + +MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities +SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue +PREFETCH_SIZE = 2048 # samples to prefetch + + +def even_split_indices(split, n, num_samples): + partitions = [round(i * num_samples / n) for i in range(n + 1)] + return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)] + + +class ParserTfds(Parser): + """ Wrap Tensorflow Datasets for use in PyTorch + + There several things to be aware of: + * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of + dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last + https://github.com/pytorch/pytorch/issues/33413 + * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch + from each worker could be a different size. For training this is worked around by option above, for + validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced + across replicas are of same size. This will slightly alter the results, distributed validation will not be + 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse + since there are up to N * J extra samples with IterableDatasets. + * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of + replicas and dataloader workers you can use. For really small datasets that only contain a few shards + you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the + benefit of distributed training or fast dataloading should be much less for small datasets. + * This wrapper is currently configured to return individual, decompressed image samples from the TFDS + dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible + to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream + components. + + """ + def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0): + super().__init__() + self.root = root + self.split = split + self.shuffle = shuffle + self.is_training = is_training + if self.is_training: + assert batch_size is not None,\ + "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" + self.batch_size = batch_size + self.repeats = repeats + self.subsplit = None + + self.builder = tfds.builder(name, data_dir=root) + # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call + # download_and_prepare() by default here as it's caused issues generating unwanted paths. + self.num_samples = self.builder.info.splits[split].num_examples + self.ds = None # initialized lazily on each dataloader worker process + + self.worker_info = None + self.dist_rank = 0 + self.dist_num_replicas = 1 + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + self.dist_rank = dist.get_rank() + self.dist_num_replicas = dist.get_world_size() + + def _lazy_init(self): + """ Lazily initialize the dataset. + + This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that + will be using the dataset instance. The __init__ method is called on the main process, + this will be called in a dataloader worker process. + + NOTE: There will be problems if you try to re-use this dataset across different loader/worker + instances once it has been initialized. Do not call any dataset methods that can call _lazy_init + before it is passed to dataloader. + """ + worker_info = torch.utils.data.get_worker_info() + + # setup input context to split dataset across distributed processes + split = self.split + num_workers = 1 + if worker_info is not None: + self.worker_info = worker_info + num_workers = worker_info.num_workers + global_num_workers = self.dist_num_replicas * num_workers + worker_id = worker_info.id + + # FIXME I need to spend more time figuring out the best way to distribute/split data across + # combo of distributed replicas + dataloader worker processes + """ + InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. + My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True) + between the splits each iteration, but that understanding could be wrong. + Possible split options include: + * InputContext for both distributed & worker processes (current) + * InputContext for distributed and sub-splits for worker processes + * sub-splits for both + """ + # split_size = self.num_samples // num_workers + # start = worker_id * split_size + # if worker_id == num_workers - 1: + # split = split + '[{}:]'.format(start) + # else: + # split = split + '[{}:{}]'.format(start, start + split_size) + if not self.is_training and '[' not in self.split: + # If not training, and split doesn't define a subsplit, manually split the dataset + # for more even samples / worker + self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[ + self.dist_rank * num_workers + worker_id] + + if self.subsplit is None: + input_context = tf.distribute.InputContext( + num_input_pipelines=self.dist_num_replicas * num_workers, + input_pipeline_id=self.dist_rank * num_workers + worker_id, + num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? + ) + else: + input_context = None + + read_config = tfds.ReadConfig( + shuffle_seed=42, + shuffle_reshuffle_each_iteration=True, + input_context=input_context) + ds = self.builder.as_dataset( + split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) + # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers + options = tf.data.Options() + options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + options.experimental_threading.max_intra_op_parallelism = 1 + ds = ds.with_options(options) + if self.is_training or self.repeats > 1: + # to prevent excessive drop_last batch behaviour w/ IterableDatasets + # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading + ds = ds.repeat() # allow wrap around and break iteration manually + if self.shuffle: + ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0) + ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) + self.ds = tfds.as_numpy(ds) + + def __iter__(self): + if self.ds is None: + self._lazy_init() + # compute a rounded up sample count that is used to: + # 1. make batches even cross workers & replicas in distributed validation. + # This adds extra samples and will slightly alter validation results. + # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size + # batches are produced (underlying tfds iter wraps around) + target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines) + if self.is_training: + # round up to nearest batch_size per worker-replica + target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size + sample_count = 0 + for sample in self.ds: + img = Image.fromarray(sample['image'], mode='RGB') + yield img, sample['label'] + sample_count += 1 + if self.is_training and sample_count >= target_sample_count: + # Need to break out of loop when repeat() is enabled for training w/ oversampling + # this results in extra samples per epoch but seems more desirable than dropping + # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) + break + if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count: + # Validation batch padding only done for distributed training where results are reduced across nodes. + # For single process case, it won't matter if workers return different batch sizes. + # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this + # approach is not optimal + yield img, sample['label'] # yield prev sample again + sample_count += 1 + + @property + def _num_workers(self): + return 1 if self.worker_info is None else self.worker_info.num_workers + + @property + def _num_pipelines(self): + return self._num_workers * self.dist_num_replicas + + def __len__(self): + # this is just an estimate and does not factor in extra samples added to pad batches based on + # complete worker & replica info (not available until init in dataloader). + return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) + + def _filename(self, index, basename=False, absolute=False): + assert False, "Not supported" # no random access to samples + + def filenames(self, basename=False, absolute=False): + """ Return all filenames in dataset, overrides base""" + if self.ds is None: + self._lazy_init() + names = [] + for sample in self.ds: + if len(names) > self.num_samples: + break # safety for ds.repeat() case + if 'file_name' in sample: + name = sample['file_name'] + elif 'filename' in sample: + name = sample['filename'] + elif 'id' in sample: + name = sample['id'] + else: + assert False, "No supported name field present" + names.append(name) + return names diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py new file mode 100644 index 0000000000000000000000000000000000000000..78967d105dd77b56a3ccefb6ff1838a8058c0384 --- /dev/null +++ b/timm/data/random_erasing.py @@ -0,0 +1,97 @@ +""" Random Erasing (Cutout) + +Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 +Copyright Zhun Zhong & Liang Zheng + +Hacked together by / Copyright 2020 Ross Wightman +""" +import random +import math +import torch + + +def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """ Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. + """ + + def __init__( + self, + probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, + mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + if mode == 'rand': + self.rand_color = True # per block random normal + elif mode == 'pixel': + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == 'const' + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = self.min_count if self.min_count == self.max_count else \ + random.randint(self.min_count, self.max_count) + for _ in range(count): + for attempt in range(10): + target_area = random.uniform(self.min_area, self.max_area) * area / count + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, self.rand_color, (chan, h, w), + dtype=dtype, device=self.device) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/timm/data/real_labels.py b/timm/data/real_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..939c34867e7915ce3e4cc7da04a5bc1653ec4f2c --- /dev/null +++ b/timm/data/real_labels.py @@ -0,0 +1,42 @@ +""" Real labels evaluator for ImageNet +Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159 +Based on Numpy example at https://github.com/google-research/reassessed-imagenet + +Hacked together by / Copyright 2020 Ross Wightman +""" +import os +import json +import numpy as np + + +class RealLabelsImagenet: + + def __init__(self, filenames, real_json='real.json', topk=(1, 5)): + with open(real_json) as real_labels: + real_labels = json.load(real_labels) + real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)} + self.real_labels = real_labels + self.filenames = filenames + assert len(self.filenames) == len(self.real_labels) + self.topk = topk + self.is_correct = {k: [] for k in topk} + self.sample_idx = 0 + + def add_result(self, output): + maxk = max(self.topk) + _, pred_batch = output.topk(maxk, 1, True, True) + pred_batch = pred_batch.cpu().numpy() + for pred in pred_batch: + filename = self.filenames[self.sample_idx] + filename = os.path.basename(filename) + if self.real_labels[filename]: + for k in self.topk: + self.is_correct[k].append( + any([p in self.real_labels[filename] for p in pred[:k]])) + self.sample_idx += 1 + + def get_accuracy(self, k=None): + if k is None: + return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk} + else: + return float(np.mean(self.is_correct[k])) * 100 diff --git a/timm/data/tf_preprocessing.py b/timm/data/tf_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..44b4a3af7372c6865b1cdddda0a8da0ccc6b93a0 --- /dev/null +++ b/timm/data/tf_preprocessing.py @@ -0,0 +1,232 @@ +""" Tensorflow Preprocessing Adapter + +Allows use of Tensorflow preprocessing pipeline in PyTorch Transform + +Copyright of original Tensorflow code below. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ImageNet preprocessing for MnasNet.""" +import tensorflow as tf +import numpy as np + +IMAGE_SIZE = 224 +CROP_PADDING = 32 + + +def distorted_bounding_box_crop(image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100, + scope=None): + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. + + Args: + image_bytes: `Tensor` of binary image data. + bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` + where each coordinate is [0, 1) and the coordinates are arranged + as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole + image. + min_object_covered: An optional `float`. Defaults to `0.1`. The cropped + area of the image must contain at least this fraction of any bounding + box supplied. + aspect_ratio_range: An optional list of `float`s. The cropped area of the + image must have an aspect ratio = width / height within this range. + area_range: An optional list of `float`s. The cropped area of the image + must contain a fraction of the supplied image within in this range. + max_attempts: An optional `int`. Number of attempts at generating a cropped + region of the image of the specified constraints. After `max_attempts` + failures, return the entire image. + scope: Optional `str` for name scope. + Returns: + cropped image `Tensor` + """ + with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]): + shape = tf.image.extract_jpeg_shape(image_bytes) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + + return image + + +def _at_least_x_are_equal(a, b, x): + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _decode_and_random_crop(image_bytes, image_size, resize_method): + """Make a random crop of image_size.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = distorted_bounding_box_crop( + image_bytes, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(3. / 4, 4. / 3.), + area_range=(0.08, 1.0), + max_attempts=10, + scope=None) + original_shape = tf.image.extract_jpeg_shape(image_bytes) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond( + bad, + lambda: _decode_and_center_crop(image_bytes, image_size), + lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0]) + + return image + + +def _decode_and_center_crop(image_bytes, image_size, resize_method): + """Crops to center of image with padding then scales image_size.""" + shape = tf.image.extract_jpeg_shape(image_bytes) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + CROP_PADDING)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), + tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + crop_window = tf.stack([offset_height, offset_width, + padded_center_crop_size, padded_center_crop_size]) + image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) + image = tf.image.resize([image], [image_size, image_size], resize_method)[0] + + return image + + +def _flip(image): + """Random horizontal image flip.""" + image = tf.image.random_flip_left_right(image) + return image + + +def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + interpolation: image interpolation method + + Returns: + A preprocessed image `Tensor`. + """ + resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR + image = _decode_and_random_crop(image_bytes, image_size, resize_method) + image = _flip(image) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'): + """Preprocesses the given image for evaluation. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + interpolation: image interpolation method + + Returns: + A preprocessed image `Tensor`. + """ + resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR + image = _decode_and_center_crop(image_bytes, image_size, resize_method) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype( + image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) + return image + + +def preprocess_image(image_bytes, + is_training=False, + use_bfloat16=False, + image_size=IMAGE_SIZE, + interpolation='bicubic'): + """Preprocesses the given image. + + Args: + image_bytes: `Tensor` representing an image binary of arbitrary size. + is_training: `bool` for whether the preprocessing is for training. + use_bfloat16: `bool` for whether to use bfloat16. + image_size: image size. + interpolation: image interpolation method + + Returns: + A preprocessed image `Tensor` with value range of [0, 255]. + """ + if is_training: + return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation) + else: + return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation) + + +class TfPreprocessTransform: + + def __init__(self, is_training=False, size=224, interpolation='bicubic'): + self.is_training = is_training + self.size = size[0] if isinstance(size, tuple) else size + self.interpolation = interpolation + self._image_bytes = None + self.process_image = self._build_tf_graph() + self.sess = None + + def _build_tf_graph(self): + with tf.device('/cpu:0'): + self._image_bytes = tf.placeholder( + shape=[], + dtype=tf.string, + ) + img = preprocess_image( + self._image_bytes, self.is_training, False, self.size, self.interpolation) + return img + + def __call__(self, image_bytes): + if self.sess is None: + self.sess = tf.Session() + img = self.sess.run(self.process_image, feed_dict={self._image_bytes: image_bytes}) + img = img.round().clip(0, 255).astype(np.uint8) + if img.ndim < 3: + img = np.expand_dims(img, axis=-1) + img = np.rollaxis(img, 2) # HWC to CHW + return img diff --git a/timm/data/transforms.py b/timm/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..4220304f664d861cad64760a6cbe05dfafdf4fe6 --- /dev/null +++ b/timm/data/transforms.py @@ -0,0 +1,158 @@ +import torch +import torchvision.transforms.functional as F +from PIL import Image +import warnings +import math +import random +import numpy as np + + +class ToNumpy: + + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return np_img + + +class ToTensor: + + def __init__(self, dtype=torch.float32): + self.dtype = dtype + + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return torch.from_numpy(np_img).to(dtype=self.dtype) + + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + # default bilinear, do we want to allow nearest? + return Image.BILINEAR + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), + interpolation='bilinear'): + if isinstance(size, (list, tuple)): + self.size = tuple(size) + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..df6e0de0338554bf410065a61810e69b167738fe --- /dev/null +++ b/timm/data/transforms_factory.py @@ -0,0 +1,236 @@ +""" Transforms Factory +Factory methods for building image transforms for use with TIMM (PyTorch Image Models) + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math + +import torch +from torchvision import transforms + +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform +from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor +from timm.data.random_erasing import RandomErasing + + +def transforms_noaug_train( + img_size=224, + interpolation='bilinear', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, +): + if interpolation == 'random': + # random interpolation not supported with no-aug + interpolation = 'bilinear' + tfl = [ + transforms.Resize(img_size, _pil_interp(interpolation)), + transforms.CenterCrop(img_size) + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + return transforms.Compose(tfl) + + +def transforms_imagenet_train( + img_size=224, + scale=None, + ratio=None, + hflip=0.5, + vflip=0., + color_jitter=0.4, + auto_augment=None, + interpolation='random', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, + separate=False, +): + """ + If separate==True, the transforms are returned as a tuple of 3 separate transforms + for use in a mixing dataset that passes + * all data through the first (primary) transform, called the 'clean' data + * a portion of the data through the secondary transform + * normalizes and converts the branches above with the third, final transform + """ + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range + primary_tfl = [ + RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)] + if hflip > 0.: + primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] + if vflip > 0.: + primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] + + secondary_tfl = [] + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, (tuple, list)): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = dict( + translate_const=int(img_size_min * 0.45), + img_mean=tuple([min(255, round(255 * x)) for x in mean]), + ) + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] + elif auto_augment.startswith('augmix'): + aa_params['translate_pct'] = 0.3 + secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] + else: + secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] + elif color_jitter is not None: + # color jitter is enabled when not using AA + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter),) * 3 + secondary_tfl += [transforms.ColorJitter(*color_jitter)] + + final_tfl = [] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + final_tfl += [ToNumpy()] + else: + final_tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + if re_prob > 0.: + final_tfl.append( + RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu')) + + if separate: + return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) + else: + return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) + + +def transforms_imagenet_eval( + img_size=224, + crop_pct=None, + interpolation='bilinear', + use_prefetcher=False, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD): + crop_pct = crop_pct or DEFAULT_CROP_PCT + + if isinstance(img_size, (tuple, list)): + assert len(img_size) == 2 + if img_size[-1] == img_size[-2]: + # fall-back to older behaviour so Resize scales to shortest edge if target is square + scale_size = int(math.floor(img_size[0] / crop_pct)) + else: + scale_size = tuple([int(x / crop_pct) for x in img_size]) + else: + scale_size = int(math.floor(img_size / crop_pct)) + + tfl = [ + transforms.Resize(scale_size, _pil_interp(interpolation)), + transforms.CenterCrop(img_size), + ] + if use_prefetcher: + # prefetcher and collate will handle tensor conversion and norm + tfl += [ToNumpy()] + else: + tfl += [ + transforms.ToTensor(), + transforms.Normalize( + mean=torch.tensor(mean), + std=torch.tensor(std)) + ] + + return transforms.Compose(tfl) + + +def create_transform( + input_size, + is_training=False, + use_prefetcher=False, + no_aug=False, + scale=None, + ratio=None, + hflip=0.5, + vflip=0., + color_jitter=0.4, + auto_augment=None, + interpolation='bilinear', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0, + crop_pct=None, + tf_preprocessing=False, + separate=False): + + if isinstance(input_size, (tuple, list)): + img_size = input_size[-2:] + else: + img_size = input_size + + if tf_preprocessing and use_prefetcher: + assert not separate, "Separate transforms not supported for TF preprocessing" + from timm.data.tf_preprocessing import TfPreprocessTransform + transform = TfPreprocessTransform( + is_training=is_training, size=img_size, interpolation=interpolation) + else: + if is_training and no_aug: + assert not separate, "Cannot perform split augmentation with no_aug" + transform = transforms_noaug_train( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std) + elif is_training: + transform = transforms_imagenet_train( + img_size, + scale=scale, + ratio=ratio, + hflip=hflip, + vflip=vflip, + color_jitter=color_jitter, + auto_augment=auto_augment, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + re_prob=re_prob, + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits, + separate=separate) + else: + assert not separate, "Separate transforms not supported for validation preprocessing" + transform = transforms_imagenet_eval( + img_size, + interpolation=interpolation, + use_prefetcher=use_prefetcher, + mean=mean, + std=std, + crop_pct=crop_pct) + + return transform diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..28a686ce896f4335dcf074717b52077b43c237d7 --- /dev/null +++ b/timm/loss/__init__.py @@ -0,0 +1,3 @@ +from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from .jsd import JsdCrossEntropy +from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel \ No newline at end of file diff --git a/timm/loss/__pycache__/__init__.cpython-38.pyc b/timm/loss/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..884db30f1f87bf00f602f7f848e05e04a59d964d Binary files /dev/null and b/timm/loss/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/loss/__pycache__/asymmetric_loss.cpython-38.pyc b/timm/loss/__pycache__/asymmetric_loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f87e5fccc58ff9f5be51a0ed6168b8312e9678d1 Binary files /dev/null and b/timm/loss/__pycache__/asymmetric_loss.cpython-38.pyc differ diff --git a/timm/loss/__pycache__/binary_cross_entropy.cpython-38.pyc b/timm/loss/__pycache__/binary_cross_entropy.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f241fcc18aed1fd24fa04e14a3c2eee4e8ce494b Binary files /dev/null and b/timm/loss/__pycache__/binary_cross_entropy.cpython-38.pyc differ diff --git a/timm/loss/__pycache__/cross_entropy.cpython-38.pyc b/timm/loss/__pycache__/cross_entropy.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9b62c830a32c6337912ab9d1a6a6005104f5a0f Binary files /dev/null and b/timm/loss/__pycache__/cross_entropy.cpython-38.pyc differ diff --git a/timm/loss/__pycache__/jsd.cpython-38.pyc b/timm/loss/__pycache__/jsd.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77b73f17a561353d05ab81a288a633afa6daf856 Binary files /dev/null and b/timm/loss/__pycache__/jsd.cpython-38.pyc differ diff --git a/timm/loss/asymmetric_loss.py b/timm/loss/asymmetric_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a8b10f9c797c2cb3b2652302717b592dada216f3 --- /dev/null +++ b/timm/loss/asymmetric_loss.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn + + +class AsymmetricLossMultiLabel(nn.Module): + def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): + super(AsymmetricLossMultiLabel, self).__init__() + + self.gamma_neg = gamma_neg + self.gamma_pos = gamma_pos + self.clip = clip + self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss + self.eps = eps + + def forward(self, x, y): + """" + Parameters + ---------- + x: input logits + y: targets (multi-label binarized vector) + """ + + # Calculating Probabilities + x_sigmoid = torch.sigmoid(x) + xs_pos = x_sigmoid + xs_neg = 1 - x_sigmoid + + # Asymmetric Clipping + if self.clip is not None and self.clip > 0: + xs_neg = (xs_neg + self.clip).clamp(max=1) + + # Basic CE calculation + los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) + los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) + loss = los_pos + los_neg + + # Asymmetric Focusing + if self.gamma_neg > 0 or self.gamma_pos > 0: + if self.disable_torch_grad_focal_loss: + torch._C.set_grad_enabled(False) + pt0 = xs_pos * y + pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p + pt = pt0 + pt1 + one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) + one_sided_w = torch.pow(1 - pt, one_sided_gamma) + if self.disable_torch_grad_focal_loss: + torch._C.set_grad_enabled(True) + loss *= one_sided_w + + return -loss.sum() + + +class AsymmetricLossSingleLabel(nn.Module): + def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): + super(AsymmetricLossSingleLabel, self).__init__() + + self.eps = eps + self.logsoftmax = nn.LogSoftmax(dim=-1) + self.targets_classes = [] # prevent gpu repeated memory allocation + self.gamma_pos = gamma_pos + self.gamma_neg = gamma_neg + self.reduction = reduction + + def forward(self, inputs, target, reduction=None): + """" + Parameters + ---------- + x: input logits + y: targets (1-hot vector) + """ + + num_classes = inputs.size()[-1] + log_preds = self.logsoftmax(inputs) + self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) + + # ASL weights + targets = self.targets_classes + anti_targets = 1 - targets + xs_pos = torch.exp(log_preds) + xs_neg = 1 - xs_pos + xs_pos = xs_pos * targets + xs_neg = xs_neg * anti_targets + asymmetric_w = torch.pow(1 - xs_pos - xs_neg, + self.gamma_pos * targets + self.gamma_neg * anti_targets) + log_preds = log_preds * asymmetric_w + + if self.eps > 0: # label smoothing + self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes) + + # loss calculation + loss = - self.targets_classes.mul(log_preds) + + loss = loss.sum(dim=-1) + if self.reduction == 'mean': + loss = loss.mean() + + return loss diff --git a/timm/loss/binary_cross_entropy.py b/timm/loss/binary_cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..ed76c1e8e004ca9a7255cf3650e322e6525c0577 --- /dev/null +++ b/timm/loss/binary_cross_entropy.py @@ -0,0 +1,47 @@ +""" Binary Cross Entropy w/ a few extras + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BinaryCrossEntropy(nn.Module): + """ BCE with optional one-hot from dense targets, label smoothing, thresholding + NOTE for experiments comparing CE to BCE /w label smoothing, may remove + """ + def __init__( + self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None, + reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None): + super(BinaryCrossEntropy, self).__init__() + assert 0. <= smoothing < 1.0 + self.smoothing = smoothing + self.target_threshold = target_threshold + self.reduction = reduction + self.register_buffer('weight', weight) + self.register_buffer('pos_weight', pos_weight) + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + assert x.shape[0] == target.shape[0] + if target.shape != x.shape: + # NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse + num_classes = x.shape[-1] + # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ + off_value = self.smoothing / num_classes + on_value = 1. - self.smoothing + off_value + target = target.long().view(-1, 1) + target = torch.full( + (target.size()[0], num_classes), + off_value, + device=x.device, dtype=x.dtype).scatter_(1, target, on_value) + if self.target_threshold is not None: + # Make target 0, or 1 if threshold set + target = target.gt(self.target_threshold).to(dtype=target.dtype) + return F.binary_cross_entropy_with_logits( + x, target, + self.weight, + pos_weight=self.pos_weight, + reduction=self.reduction) diff --git a/timm/loss/cross_entropy.py b/timm/loss/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..60bef646cc6c31fd734f234346dbc4255def6622 --- /dev/null +++ b/timm/loss/cross_entropy.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LabelSmoothingCrossEntropy(nn.Module): + """ + NLL loss with label smoothing. + """ + def __init__(self, smoothing=0.1): + """ + Constructor for the LabelSmoothing module. + :param smoothing: label smoothing factor + """ + super(LabelSmoothingCrossEntropy, self).__init__() + assert smoothing < 1.0 + self.smoothing = smoothing + self.confidence = 1. - smoothing + + def forward(self, x, target): + logprobs = F.log_softmax(x, dim=-1) + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() + + +class SoftTargetCrossEntropy(nn.Module): + + def __init__(self): + super(SoftTargetCrossEntropy, self).__init__() + + def forward(self, x, target): + loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) + return loss.mean() diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py new file mode 100644 index 0000000000000000000000000000000000000000..dd64e156c23d27aa03817a587ae367e8175fc126 --- /dev/null +++ b/timm/loss/jsd.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .cross_entropy import LabelSmoothingCrossEntropy + + +class JsdCrossEntropy(nn.Module): + """ Jensen-Shannon Divergence + Cross-Entropy Loss + + Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + + Hacked together by / Copyright 2020 Ross Wightman + """ + def __init__(self, num_splits=3, alpha=12, smoothing=0.1): + super().__init__() + self.num_splits = num_splits + self.alpha = alpha + if smoothing is not None and smoothing > 0: + self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing) + else: + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def __call__(self, output, target): + split_size = output.shape[0] // self.num_splits + assert split_size * self.num_splits == output.shape[0] + logits_split = torch.split(output, split_size) + + # Cross-entropy is only computed on clean images + loss = self.cross_entropy_loss(logits_split[0], target[:split_size]) + probs = [F.softmax(logits, dim=1) for logits in logits_split] + + # Clamp mixture distribution to avoid exploding KL divergence + logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log() + loss += self.alpha * sum([F.kl_div( + logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs) + return loss diff --git a/timm/models/__init__.py b/timm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06217e185741edae2b4f22a40a51d04465e63ab7 --- /dev/null +++ b/timm/models/__init__.py @@ -0,0 +1,53 @@ +from .byoanet import * +from .byobnet import * +from .cait import * +from .coat import * +from .convit import * +from .cspnet import * +from .densenet import * +from .dla import * +from .dpn import * +from .efficientnet import * +from .ghostnet import * +from .gluon_resnet import * +from .gluon_xception import * +from .hardcorenas import * +from .hrnet import * +from .inception_resnet_v2 import * +from .inception_v3 import * +from .inception_v4 import * +from .levit import * +from .mlp_mixer import * +from .mobilenetv3 import * +from .nasnet import * +from .nfnet import * +from .pit import * +from .pnasnet import * +from .regnet import * +from .res2net import * +from .resnest import * +from .resnet import * +from .resnetv2 import * +from .rexnet import * +from .selecsls import * +from .senet import * +from .sknet import * +from .swin_transformer import * +from .tnt import * +from .tresnet import * +from .vgg import * +from .visformer import * +from .vision_transformer import * +from .vision_transformer_hybrid import * +from .vovnet import * +from .xception import * +from .xception_aligned import * +from .twins import * + +from .factory import create_model, split_model_name, safe_model_name +from .helpers import load_checkpoint, resume_checkpoint, model_parameters +from .layers import TestTimePoolHead, apply_test_time_pool +from .layers import convert_splitbn_model +from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit +from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ + has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained diff --git a/timm/models/__pycache__/__init__.cpython-310.pyc b/timm/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d42b15ee4512dc661ca3ce6ffa6d89c239ec596 Binary files /dev/null and b/timm/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/timm/models/__pycache__/__init__.cpython-311.pyc b/timm/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aa9e873d9836c96ed30b904993cb04dfd5558d9 Binary files /dev/null and b/timm/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/timm/models/__pycache__/__init__.cpython-313.pyc b/timm/models/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d79beb8a774c05c2b70c2f1cffdf2483bde003fc Binary files /dev/null and b/timm/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/timm/models/__pycache__/__init__.cpython-38.pyc b/timm/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfebaa65e2066eb1304cfb021e77e78eb4cc84e8 Binary files /dev/null and b/timm/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/models/__pycache__/beit.cpython-38.pyc b/timm/models/__pycache__/beit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6484aad693d1cf1e181483f466f0646c24dfe05e Binary files /dev/null and b/timm/models/__pycache__/beit.cpython-38.pyc differ diff --git a/timm/models/__pycache__/byoanet.cpython-310.pyc b/timm/models/__pycache__/byoanet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..550fe01daffe0650a50df4bf767cfd0c61a75812 Binary files /dev/null and b/timm/models/__pycache__/byoanet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/byoanet.cpython-311.pyc b/timm/models/__pycache__/byoanet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c32979c94c0dd0d8cdf8a56d255b5437242fc887 Binary files /dev/null and b/timm/models/__pycache__/byoanet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/byoanet.cpython-313.pyc b/timm/models/__pycache__/byoanet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca8c17870203033184ce5aa16d55766461c3efcb Binary files /dev/null and b/timm/models/__pycache__/byoanet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/byoanet.cpython-38.pyc b/timm/models/__pycache__/byoanet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8862fe3a411fc02b46ec7cf83d251af484de4886 Binary files /dev/null and b/timm/models/__pycache__/byoanet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/byobnet.cpython-310.pyc b/timm/models/__pycache__/byobnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c944a0280f366763372c80ef6e3d1fe4be70013a Binary files /dev/null and b/timm/models/__pycache__/byobnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/byobnet.cpython-311.pyc b/timm/models/__pycache__/byobnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a999982182905f502e6046e5fe136642538761f9 Binary files /dev/null and b/timm/models/__pycache__/byobnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/byobnet.cpython-313.pyc b/timm/models/__pycache__/byobnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2da73b35d0261e516ca1286864f5d7b36382784 Binary files /dev/null and b/timm/models/__pycache__/byobnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/byobnet.cpython-38.pyc b/timm/models/__pycache__/byobnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15a5717748fbc3bef03a572f47703d8f3c9f9607 Binary files /dev/null and b/timm/models/__pycache__/byobnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/cait.cpython-310.pyc b/timm/models/__pycache__/cait.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c682ff9825e6767a9b999ba50e7f2a60d8157f9 Binary files /dev/null and b/timm/models/__pycache__/cait.cpython-310.pyc differ diff --git a/timm/models/__pycache__/cait.cpython-311.pyc b/timm/models/__pycache__/cait.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c7cea1f7d15443a729b53c0611ea6094ad457cf Binary files /dev/null and b/timm/models/__pycache__/cait.cpython-311.pyc differ diff --git a/timm/models/__pycache__/cait.cpython-313.pyc b/timm/models/__pycache__/cait.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0807538ed76481e6cccd4f1367b3a599779e8d74 Binary files /dev/null and b/timm/models/__pycache__/cait.cpython-313.pyc differ diff --git a/timm/models/__pycache__/cait.cpython-38.pyc b/timm/models/__pycache__/cait.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43c40f3eee4c8ebfcdf9bb3b157ad5a1a81b0996 Binary files /dev/null and b/timm/models/__pycache__/cait.cpython-38.pyc differ diff --git a/timm/models/__pycache__/coat.cpython-310.pyc b/timm/models/__pycache__/coat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78cd71092f0179e4fa9c93d7b242f261c2079eda Binary files /dev/null and b/timm/models/__pycache__/coat.cpython-310.pyc differ diff --git a/timm/models/__pycache__/coat.cpython-311.pyc b/timm/models/__pycache__/coat.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1fecd26262a0e92793d1fa7af57e8e4c5f21937 Binary files /dev/null and b/timm/models/__pycache__/coat.cpython-311.pyc differ diff --git a/timm/models/__pycache__/coat.cpython-313.pyc b/timm/models/__pycache__/coat.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c017694b08dcef1f107a01f3df4340bf52cf61c3 Binary files /dev/null and b/timm/models/__pycache__/coat.cpython-313.pyc differ diff --git a/timm/models/__pycache__/coat.cpython-38.pyc b/timm/models/__pycache__/coat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3785c8e500501830bde882c3c34c81e2e823fe4 Binary files /dev/null and b/timm/models/__pycache__/coat.cpython-38.pyc differ diff --git a/timm/models/__pycache__/convit.cpython-310.pyc b/timm/models/__pycache__/convit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35dc051440b6f1f1b768b8370da317b6cae25552 Binary files /dev/null and b/timm/models/__pycache__/convit.cpython-310.pyc differ diff --git a/timm/models/__pycache__/convit.cpython-311.pyc b/timm/models/__pycache__/convit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..655476a20893b1ce54544ef400796876c2790f3f Binary files /dev/null and b/timm/models/__pycache__/convit.cpython-311.pyc differ diff --git a/timm/models/__pycache__/convit.cpython-313.pyc b/timm/models/__pycache__/convit.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..647b374e65915c4ab9426df95898b87cc5e58cec Binary files /dev/null and b/timm/models/__pycache__/convit.cpython-313.pyc differ diff --git a/timm/models/__pycache__/convit.cpython-38.pyc b/timm/models/__pycache__/convit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50ecb8939d7334dfe7c0ddcc1be965e982a33bd8 Binary files /dev/null and b/timm/models/__pycache__/convit.cpython-38.pyc differ diff --git a/timm/models/__pycache__/convmixer.cpython-38.pyc b/timm/models/__pycache__/convmixer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..760f3615c6d0d96628a37e049df5da24b0044276 Binary files /dev/null and b/timm/models/__pycache__/convmixer.cpython-38.pyc differ diff --git a/timm/models/__pycache__/convnext.cpython-38.pyc b/timm/models/__pycache__/convnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4380aa4687ea6109b25b2d03d65981e5ddb4826d Binary files /dev/null and b/timm/models/__pycache__/convnext.cpython-38.pyc differ diff --git a/timm/models/__pycache__/crossvit.cpython-38.pyc b/timm/models/__pycache__/crossvit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6f200eafff55e402aa1b1cee5478c7eb4adcd4f Binary files /dev/null and b/timm/models/__pycache__/crossvit.cpython-38.pyc differ diff --git a/timm/models/__pycache__/cspnet.cpython-310.pyc b/timm/models/__pycache__/cspnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e6b345f39da692a92c031bff75fe2b6751dfab4 Binary files /dev/null and b/timm/models/__pycache__/cspnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/cspnet.cpython-311.pyc b/timm/models/__pycache__/cspnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2d857fdf9e20fb84d049196daad9a4e220372b8 Binary files /dev/null and b/timm/models/__pycache__/cspnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/cspnet.cpython-313.pyc b/timm/models/__pycache__/cspnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dfd5ccf7ceded69b1918a4ff67cc487a599fd8c Binary files /dev/null and b/timm/models/__pycache__/cspnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/cspnet.cpython-38.pyc b/timm/models/__pycache__/cspnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ab394566c4eb55d8c4cbb98e0fe359cac6cc1ad Binary files /dev/null and b/timm/models/__pycache__/cspnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/densenet.cpython-310.pyc b/timm/models/__pycache__/densenet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5650c240f73532670bae616dc1bd384aa36c9990 Binary files /dev/null and b/timm/models/__pycache__/densenet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/densenet.cpython-311.pyc b/timm/models/__pycache__/densenet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c56bdc049fd765d54e093bb218c993ce7e70a23 Binary files /dev/null and b/timm/models/__pycache__/densenet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/densenet.cpython-313.pyc b/timm/models/__pycache__/densenet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e669f8f1cddb4d74e09b64f32ce1a0035b1312 Binary files /dev/null and b/timm/models/__pycache__/densenet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/densenet.cpython-38.pyc b/timm/models/__pycache__/densenet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7004ffafa81090abc846fa8cf8b9294fb99c55b5 Binary files /dev/null and b/timm/models/__pycache__/densenet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/dla.cpython-310.pyc b/timm/models/__pycache__/dla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0706f556c4eb7655878dc6033135230aa14745e Binary files /dev/null and b/timm/models/__pycache__/dla.cpython-310.pyc differ diff --git a/timm/models/__pycache__/dla.cpython-311.pyc b/timm/models/__pycache__/dla.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed73ce04140b26d5eac7345314cd87eb86f70e0d Binary files /dev/null and b/timm/models/__pycache__/dla.cpython-311.pyc differ diff --git a/timm/models/__pycache__/dla.cpython-313.pyc b/timm/models/__pycache__/dla.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b34229738befe68a79f4228ca3cd361f3cb0af Binary files /dev/null and b/timm/models/__pycache__/dla.cpython-313.pyc differ diff --git a/timm/models/__pycache__/dla.cpython-38.pyc b/timm/models/__pycache__/dla.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e6c8e4fc508ef770611dc5ec31e6468228d1ae9 Binary files /dev/null and b/timm/models/__pycache__/dla.cpython-38.pyc differ diff --git a/timm/models/__pycache__/dpn.cpython-310.pyc b/timm/models/__pycache__/dpn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3861d4ba9c1942ec28355585f7cc0e8112d79f89 Binary files /dev/null and b/timm/models/__pycache__/dpn.cpython-310.pyc differ diff --git a/timm/models/__pycache__/dpn.cpython-311.pyc b/timm/models/__pycache__/dpn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d3ae527154f57e4c427b78692f7ecdee3361814 Binary files /dev/null and b/timm/models/__pycache__/dpn.cpython-311.pyc differ diff --git a/timm/models/__pycache__/dpn.cpython-313.pyc b/timm/models/__pycache__/dpn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d789d7556c12df888079904a7f032a0dcf8d1ce Binary files /dev/null and b/timm/models/__pycache__/dpn.cpython-313.pyc differ diff --git a/timm/models/__pycache__/dpn.cpython-38.pyc b/timm/models/__pycache__/dpn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5abb2e3228b9be1c24b93059515fdde4da21702c Binary files /dev/null and b/timm/models/__pycache__/dpn.cpython-38.pyc differ diff --git a/timm/models/__pycache__/efficientnet.cpython-310.pyc b/timm/models/__pycache__/efficientnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da6fce46430f71d1cc33dd34a13eb270bef44ea1 Binary files /dev/null and b/timm/models/__pycache__/efficientnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/efficientnet.cpython-311.pyc b/timm/models/__pycache__/efficientnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6da72da16b8ee45b5823abd63cbbc90053a86ba1 Binary files /dev/null and b/timm/models/__pycache__/efficientnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/efficientnet.cpython-313.pyc b/timm/models/__pycache__/efficientnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29397a41962bb2e9b58605290e4703b2d87dd6e1 Binary files /dev/null and b/timm/models/__pycache__/efficientnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/efficientnet.cpython-38.pyc b/timm/models/__pycache__/efficientnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34489d71ab2978f8ca1f1b97558fe94cf39f40d8 Binary files /dev/null and b/timm/models/__pycache__/efficientnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/efficientnet_blocks.cpython-310.pyc b/timm/models/__pycache__/efficientnet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9c9e7c95bd7f979e92939fcb26d8d7efad3fcb2 Binary files /dev/null and b/timm/models/__pycache__/efficientnet_blocks.cpython-310.pyc differ diff --git a/timm/models/__pycache__/efficientnet_blocks.cpython-311.pyc b/timm/models/__pycache__/efficientnet_blocks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..440f541cd09adbf0b513816cf9b325e763f3f248 Binary files /dev/null and b/timm/models/__pycache__/efficientnet_blocks.cpython-311.pyc differ diff --git a/timm/models/__pycache__/efficientnet_blocks.cpython-313.pyc b/timm/models/__pycache__/efficientnet_blocks.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..570be505f104acd76f26790e11e9dcdb7af98d0c Binary files /dev/null and b/timm/models/__pycache__/efficientnet_blocks.cpython-313.pyc differ diff --git a/timm/models/__pycache__/efficientnet_blocks.cpython-38.pyc b/timm/models/__pycache__/efficientnet_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7f66e50c157b373162603d88cba1492013ca76a Binary files /dev/null and b/timm/models/__pycache__/efficientnet_blocks.cpython-38.pyc differ diff --git a/timm/models/__pycache__/efficientnet_builder.cpython-310.pyc b/timm/models/__pycache__/efficientnet_builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0633c39cfae3af859a75bd0414d29286cc36609d Binary files /dev/null and b/timm/models/__pycache__/efficientnet_builder.cpython-310.pyc differ diff --git a/timm/models/__pycache__/efficientnet_builder.cpython-311.pyc b/timm/models/__pycache__/efficientnet_builder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8497643b8bb0f966dc33694015ed13bf30304864 Binary files /dev/null and b/timm/models/__pycache__/efficientnet_builder.cpython-311.pyc differ diff --git a/timm/models/__pycache__/efficientnet_builder.cpython-313.pyc b/timm/models/__pycache__/efficientnet_builder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cce004c28e53af8379132b9c15c7c5841fc58dfe Binary files /dev/null and b/timm/models/__pycache__/efficientnet_builder.cpython-313.pyc differ diff --git a/timm/models/__pycache__/efficientnet_builder.cpython-38.pyc b/timm/models/__pycache__/efficientnet_builder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b4c5da9a27715b5548a9a3cb879ceaf9f9e3f5f Binary files /dev/null and b/timm/models/__pycache__/efficientnet_builder.cpython-38.pyc differ diff --git a/timm/models/__pycache__/factory.cpython-310.pyc b/timm/models/__pycache__/factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89117b9950d89b286a6740258d66ac01c48c31ca Binary files /dev/null and b/timm/models/__pycache__/factory.cpython-310.pyc differ diff --git a/timm/models/__pycache__/factory.cpython-311.pyc b/timm/models/__pycache__/factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f864084187fa888a905fa7865d2e22da367f4e81 Binary files /dev/null and b/timm/models/__pycache__/factory.cpython-311.pyc differ diff --git a/timm/models/__pycache__/factory.cpython-313.pyc b/timm/models/__pycache__/factory.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02585f08ea69078a43ab4f15738624ba1de3cba4 Binary files /dev/null and b/timm/models/__pycache__/factory.cpython-313.pyc differ diff --git a/timm/models/__pycache__/factory.cpython-38.pyc b/timm/models/__pycache__/factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..748b9e7f6c02bc8fd2aab58fc3209c5e1321ac49 Binary files /dev/null and b/timm/models/__pycache__/factory.cpython-38.pyc differ diff --git a/timm/models/__pycache__/features.cpython-310.pyc b/timm/models/__pycache__/features.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3f00b9e59cc4154b12a7b6f25eeb00dfbcee8b6 Binary files /dev/null and b/timm/models/__pycache__/features.cpython-310.pyc differ diff --git a/timm/models/__pycache__/features.cpython-311.pyc b/timm/models/__pycache__/features.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87268635d9c18cc26ecb8b2d3eef9958eb640d5d Binary files /dev/null and b/timm/models/__pycache__/features.cpython-311.pyc differ diff --git a/timm/models/__pycache__/features.cpython-313.pyc b/timm/models/__pycache__/features.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c92c0377036ffd0b8baf8b774881a2a794ef3043 Binary files /dev/null and b/timm/models/__pycache__/features.cpython-313.pyc differ diff --git a/timm/models/__pycache__/features.cpython-38.pyc b/timm/models/__pycache__/features.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fcba43ee14fa0744439c79f08c470cc7f76a664 Binary files /dev/null and b/timm/models/__pycache__/features.cpython-38.pyc differ diff --git a/timm/models/__pycache__/fx_features.cpython-38.pyc b/timm/models/__pycache__/fx_features.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5a6ebfa1a9926d71178d1b79def6891eeed62fb Binary files /dev/null and b/timm/models/__pycache__/fx_features.cpython-38.pyc differ diff --git a/timm/models/__pycache__/ghostnet.cpython-310.pyc b/timm/models/__pycache__/ghostnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bd88fba3bd854b2432084786e18dab19167c726 Binary files /dev/null and b/timm/models/__pycache__/ghostnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/ghostnet.cpython-311.pyc b/timm/models/__pycache__/ghostnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88383c44aed1fc5e8822cd4ca897e8d53361f1f6 Binary files /dev/null and b/timm/models/__pycache__/ghostnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/ghostnet.cpython-313.pyc b/timm/models/__pycache__/ghostnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dfe5cd3d0bc09591391b732fad065c4cd7fe805 Binary files /dev/null and b/timm/models/__pycache__/ghostnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/ghostnet.cpython-38.pyc b/timm/models/__pycache__/ghostnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7983a461f8047ec7a87a0468d73dfb00d85bdea8 Binary files /dev/null and b/timm/models/__pycache__/ghostnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/gluon_resnet.cpython-310.pyc b/timm/models/__pycache__/gluon_resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68ce2f52c108c41222a05237aa5a2fbd0765f9ee Binary files /dev/null and b/timm/models/__pycache__/gluon_resnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/gluon_resnet.cpython-311.pyc b/timm/models/__pycache__/gluon_resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45ebc4d13ae621e0ef602e193271e74ed481a243 Binary files /dev/null and b/timm/models/__pycache__/gluon_resnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/gluon_resnet.cpython-313.pyc b/timm/models/__pycache__/gluon_resnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..570501545fe9946d275f87b8f1b588dce1d1d62e Binary files /dev/null and b/timm/models/__pycache__/gluon_resnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/gluon_resnet.cpython-38.pyc b/timm/models/__pycache__/gluon_resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9d21005a8dc3179b4e8e04e7b0327399a3e5ce6 Binary files /dev/null and b/timm/models/__pycache__/gluon_resnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/gluon_xception.cpython-310.pyc b/timm/models/__pycache__/gluon_xception.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80da75995fe342efa9bcdea0844488ea2e7e9b33 Binary files /dev/null and b/timm/models/__pycache__/gluon_xception.cpython-310.pyc differ diff --git a/timm/models/__pycache__/gluon_xception.cpython-311.pyc b/timm/models/__pycache__/gluon_xception.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6beba9042fc86b1f5120db17dbbb70a9f33dcfea Binary files /dev/null and b/timm/models/__pycache__/gluon_xception.cpython-311.pyc differ diff --git a/timm/models/__pycache__/gluon_xception.cpython-313.pyc b/timm/models/__pycache__/gluon_xception.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83f2ed6037474a7958386829bc7dc15dd7994e4d Binary files /dev/null and b/timm/models/__pycache__/gluon_xception.cpython-313.pyc differ diff --git a/timm/models/__pycache__/gluon_xception.cpython-38.pyc b/timm/models/__pycache__/gluon_xception.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52b0e8ea611ea4c94d4818a6e9a663b5ef6194e9 Binary files /dev/null and b/timm/models/__pycache__/gluon_xception.cpython-38.pyc differ diff --git a/timm/models/__pycache__/hardcorenas.cpython-310.pyc b/timm/models/__pycache__/hardcorenas.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6c52b09f2a4791845be30d3ec4d9b4b567f384d Binary files /dev/null and b/timm/models/__pycache__/hardcorenas.cpython-310.pyc differ diff --git a/timm/models/__pycache__/hardcorenas.cpython-311.pyc b/timm/models/__pycache__/hardcorenas.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b29c3305a76f5533697717dabb9a1c53d8c25e2 Binary files /dev/null and b/timm/models/__pycache__/hardcorenas.cpython-311.pyc differ diff --git a/timm/models/__pycache__/hardcorenas.cpython-313.pyc b/timm/models/__pycache__/hardcorenas.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4779fc5831e4d3e41f918599a0b2a9751ecad76 Binary files /dev/null and b/timm/models/__pycache__/hardcorenas.cpython-313.pyc differ diff --git a/timm/models/__pycache__/hardcorenas.cpython-38.pyc b/timm/models/__pycache__/hardcorenas.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8eb0e9a0f7d47ddd2a6fc3c90d555f2d09627726 Binary files /dev/null and b/timm/models/__pycache__/hardcorenas.cpython-38.pyc differ diff --git a/timm/models/__pycache__/helpers.cpython-310.pyc b/timm/models/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a5dab3af6316171796835cfd49c3703eadb8d62 Binary files /dev/null and b/timm/models/__pycache__/helpers.cpython-310.pyc differ diff --git a/timm/models/__pycache__/helpers.cpython-311.pyc b/timm/models/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..441ad1863e9a00b3cab73f6be3fc09c1ae88c3fa Binary files /dev/null and b/timm/models/__pycache__/helpers.cpython-311.pyc differ diff --git a/timm/models/__pycache__/helpers.cpython-313.pyc b/timm/models/__pycache__/helpers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9d2b553faa7d440c7ffbc3611bacf05dd504200 Binary files /dev/null and b/timm/models/__pycache__/helpers.cpython-313.pyc differ diff --git a/timm/models/__pycache__/helpers.cpython-38.pyc b/timm/models/__pycache__/helpers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d733ac6c18ccca502c9685b1c4376265a0822742 Binary files /dev/null and b/timm/models/__pycache__/helpers.cpython-38.pyc differ diff --git a/timm/models/__pycache__/hrnet.cpython-310.pyc b/timm/models/__pycache__/hrnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7955b381084e266067330f2c6d6f0025850d205b Binary files /dev/null and b/timm/models/__pycache__/hrnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/hrnet.cpython-311.pyc b/timm/models/__pycache__/hrnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d578b1b03fd2700bcf9bec6ccc52a5d1db81775 Binary files /dev/null and b/timm/models/__pycache__/hrnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/hrnet.cpython-313.pyc b/timm/models/__pycache__/hrnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26aef8921909f3653c3ca423a44993599b473b3d Binary files /dev/null and b/timm/models/__pycache__/hrnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/hrnet.cpython-38.pyc b/timm/models/__pycache__/hrnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f4601db34cf457e1f5c6c17618c38955cb880ce Binary files /dev/null and b/timm/models/__pycache__/hrnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/hub.cpython-310.pyc b/timm/models/__pycache__/hub.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec0ae55eeebb0f7d1eab5cafa824b7fc6577eb7c Binary files /dev/null and b/timm/models/__pycache__/hub.cpython-310.pyc differ diff --git a/timm/models/__pycache__/hub.cpython-311.pyc b/timm/models/__pycache__/hub.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78b1e3c2ba89e49dbf87c53df478c2490544804d Binary files /dev/null and b/timm/models/__pycache__/hub.cpython-311.pyc differ diff --git a/timm/models/__pycache__/hub.cpython-313.pyc b/timm/models/__pycache__/hub.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef71f14cce6e309bbf53d7478f3e4b3bfe57944c Binary files /dev/null and b/timm/models/__pycache__/hub.cpython-313.pyc differ diff --git a/timm/models/__pycache__/hub.cpython-38.pyc b/timm/models/__pycache__/hub.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bd25550edd9f4acddb02e9eecf798c6c515d9c7 Binary files /dev/null and b/timm/models/__pycache__/hub.cpython-38.pyc differ diff --git a/timm/models/__pycache__/inception_resnet_v2.cpython-310.pyc b/timm/models/__pycache__/inception_resnet_v2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2e6a3e68a67445bf4bc078351c8bb4269ef58b4 Binary files /dev/null and b/timm/models/__pycache__/inception_resnet_v2.cpython-310.pyc differ diff --git a/timm/models/__pycache__/inception_resnet_v2.cpython-311.pyc b/timm/models/__pycache__/inception_resnet_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92fbfed80b2042ee24cf7a8a84ec2acce5c42965 Binary files /dev/null and b/timm/models/__pycache__/inception_resnet_v2.cpython-311.pyc differ diff --git a/timm/models/__pycache__/inception_resnet_v2.cpython-313.pyc b/timm/models/__pycache__/inception_resnet_v2.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26799aa52e5374008f4ae1ccab7e9d4fbd4c1faa Binary files /dev/null and b/timm/models/__pycache__/inception_resnet_v2.cpython-313.pyc differ diff --git a/timm/models/__pycache__/inception_resnet_v2.cpython-38.pyc b/timm/models/__pycache__/inception_resnet_v2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a00be05a0655a4287707ae1c8c898123b4fc09ab Binary files /dev/null and b/timm/models/__pycache__/inception_resnet_v2.cpython-38.pyc differ diff --git a/timm/models/__pycache__/inception_v3.cpython-310.pyc b/timm/models/__pycache__/inception_v3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4f104aae13612be7d1a067b979ca70f358618b5 Binary files /dev/null and b/timm/models/__pycache__/inception_v3.cpython-310.pyc differ diff --git a/timm/models/__pycache__/inception_v3.cpython-311.pyc b/timm/models/__pycache__/inception_v3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fb530d98bc4cf534d3dacd25a5b11617a7e910f Binary files /dev/null and b/timm/models/__pycache__/inception_v3.cpython-311.pyc differ diff --git a/timm/models/__pycache__/inception_v3.cpython-313.pyc b/timm/models/__pycache__/inception_v3.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16a742bbecaf75be3544b090d17134f651caf466 Binary files /dev/null and b/timm/models/__pycache__/inception_v3.cpython-313.pyc differ diff --git a/timm/models/__pycache__/inception_v3.cpython-38.pyc b/timm/models/__pycache__/inception_v3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..763c5d587e97fda15c3c757db31f0aceeb33e676 Binary files /dev/null and b/timm/models/__pycache__/inception_v3.cpython-38.pyc differ diff --git a/timm/models/__pycache__/inception_v4.cpython-310.pyc b/timm/models/__pycache__/inception_v4.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..058ed5874ffe5a76cdee457350b5c0882d018f1c Binary files /dev/null and b/timm/models/__pycache__/inception_v4.cpython-310.pyc differ diff --git a/timm/models/__pycache__/inception_v4.cpython-311.pyc b/timm/models/__pycache__/inception_v4.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1b69a5b39fb356a2de0bff085728637d422ddea Binary files /dev/null and b/timm/models/__pycache__/inception_v4.cpython-311.pyc differ diff --git a/timm/models/__pycache__/inception_v4.cpython-313.pyc b/timm/models/__pycache__/inception_v4.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34cf9b84031d86b6b1f58b7382ede285d142874a Binary files /dev/null and b/timm/models/__pycache__/inception_v4.cpython-313.pyc differ diff --git a/timm/models/__pycache__/inception_v4.cpython-38.pyc b/timm/models/__pycache__/inception_v4.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f326dd4e5abe56aca90c60792f5bc4b999946e7 Binary files /dev/null and b/timm/models/__pycache__/inception_v4.cpython-38.pyc differ diff --git a/timm/models/__pycache__/levit.cpython-310.pyc b/timm/models/__pycache__/levit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2837c24d8220183cdb90906284ccd8d944d172ac Binary files /dev/null and b/timm/models/__pycache__/levit.cpython-310.pyc differ diff --git a/timm/models/__pycache__/levit.cpython-311.pyc b/timm/models/__pycache__/levit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e80be955c1b43617b41aca23d8c8bfe6f4d6b8f0 Binary files /dev/null and b/timm/models/__pycache__/levit.cpython-311.pyc differ diff --git a/timm/models/__pycache__/levit.cpython-313.pyc b/timm/models/__pycache__/levit.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5358fe6b9671eb545c29deb8f7eb1f5cf4d0b6fb Binary files /dev/null and b/timm/models/__pycache__/levit.cpython-313.pyc differ diff --git a/timm/models/__pycache__/levit.cpython-38.pyc b/timm/models/__pycache__/levit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e3ebb762ef6ad99d516350037a95587d4403a28 Binary files /dev/null and b/timm/models/__pycache__/levit.cpython-38.pyc differ diff --git a/timm/models/__pycache__/mlp_mixer.cpython-310.pyc b/timm/models/__pycache__/mlp_mixer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..076ef413ed3dc9d756c30de6bdb41f1cb932e85d Binary files /dev/null and b/timm/models/__pycache__/mlp_mixer.cpython-310.pyc differ diff --git a/timm/models/__pycache__/mlp_mixer.cpython-311.pyc b/timm/models/__pycache__/mlp_mixer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..860dca9c01b91278817ca78f1948695f9b59391c Binary files /dev/null and b/timm/models/__pycache__/mlp_mixer.cpython-311.pyc differ diff --git a/timm/models/__pycache__/mlp_mixer.cpython-313.pyc b/timm/models/__pycache__/mlp_mixer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..346da2d45c1dc78f404699947e299f4d4418a38d Binary files /dev/null and b/timm/models/__pycache__/mlp_mixer.cpython-313.pyc differ diff --git a/timm/models/__pycache__/mlp_mixer.cpython-38.pyc b/timm/models/__pycache__/mlp_mixer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69f06584dce1293e7ca26ae3ebcfb40250d09571 Binary files /dev/null and b/timm/models/__pycache__/mlp_mixer.cpython-38.pyc differ diff --git a/timm/models/__pycache__/mobilenetv3.cpython-310.pyc b/timm/models/__pycache__/mobilenetv3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4139d9e576f9f0d7b128fc6b7a7d9370bf166679 Binary files /dev/null and b/timm/models/__pycache__/mobilenetv3.cpython-310.pyc differ diff --git a/timm/models/__pycache__/mobilenetv3.cpython-311.pyc b/timm/models/__pycache__/mobilenetv3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93d7cd58804694c0616df66ec6a8e158a44fa181 Binary files /dev/null and b/timm/models/__pycache__/mobilenetv3.cpython-311.pyc differ diff --git a/timm/models/__pycache__/mobilenetv3.cpython-313.pyc b/timm/models/__pycache__/mobilenetv3.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e3ea17a43adbb1593f7887a5504da2bcfcfd0cb Binary files /dev/null and b/timm/models/__pycache__/mobilenetv3.cpython-313.pyc differ diff --git a/timm/models/__pycache__/mobilenetv3.cpython-38.pyc b/timm/models/__pycache__/mobilenetv3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b436e5be058a16f486e80352bb6b2df263bf4d0 Binary files /dev/null and b/timm/models/__pycache__/mobilenetv3.cpython-38.pyc differ diff --git a/timm/models/__pycache__/nasnet.cpython-310.pyc b/timm/models/__pycache__/nasnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff2772451b82d10ae693a72258b793557f5fd19d Binary files /dev/null and b/timm/models/__pycache__/nasnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/nasnet.cpython-311.pyc b/timm/models/__pycache__/nasnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a0f93c0a7a451e8d804879176bd0458b8f2bdd Binary files /dev/null and b/timm/models/__pycache__/nasnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/nasnet.cpython-313.pyc b/timm/models/__pycache__/nasnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a657bae1b076cc420c9d5b83d86c18b4c7ac2cc0 Binary files /dev/null and b/timm/models/__pycache__/nasnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/nasnet.cpython-38.pyc b/timm/models/__pycache__/nasnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0497574e435e45a701d0be86983fb574fe615e60 Binary files /dev/null and b/timm/models/__pycache__/nasnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/nest.cpython-38.pyc b/timm/models/__pycache__/nest.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b19b8383bf6050aa5cdd702f1c0e9195d9ecd27 Binary files /dev/null and b/timm/models/__pycache__/nest.cpython-38.pyc differ diff --git a/timm/models/__pycache__/nfnet.cpython-310.pyc b/timm/models/__pycache__/nfnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b4ce85be788ff024a7dcced2dc675ddbfe0ff66 Binary files /dev/null and b/timm/models/__pycache__/nfnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/nfnet.cpython-311.pyc b/timm/models/__pycache__/nfnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3de170238166ee72938520f635cb5fa055137781 Binary files /dev/null and b/timm/models/__pycache__/nfnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/nfnet.cpython-313.pyc b/timm/models/__pycache__/nfnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6991cf6a06b5a7555b9d432e6fa8a0a628f5a0f8 Binary files /dev/null and b/timm/models/__pycache__/nfnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/nfnet.cpython-38.pyc b/timm/models/__pycache__/nfnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..239a5489ab30d5843d68046194bd1b84993197f0 Binary files /dev/null and b/timm/models/__pycache__/nfnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/pit.cpython-310.pyc b/timm/models/__pycache__/pit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab67df84babcf4c5d60ed34a59d6bd589195167b Binary files /dev/null and b/timm/models/__pycache__/pit.cpython-310.pyc differ diff --git a/timm/models/__pycache__/pit.cpython-311.pyc b/timm/models/__pycache__/pit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e045d183ae0f7dd8cae75c6c7b80d57cae443bda Binary files /dev/null and b/timm/models/__pycache__/pit.cpython-311.pyc differ diff --git a/timm/models/__pycache__/pit.cpython-313.pyc b/timm/models/__pycache__/pit.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3cf660bc376f813f319271481907c9025909d95 Binary files /dev/null and b/timm/models/__pycache__/pit.cpython-313.pyc differ diff --git a/timm/models/__pycache__/pit.cpython-38.pyc b/timm/models/__pycache__/pit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87396dce44d9098d4abff94c5950470564b497a7 Binary files /dev/null and b/timm/models/__pycache__/pit.cpython-38.pyc differ diff --git a/timm/models/__pycache__/pnasnet.cpython-310.pyc b/timm/models/__pycache__/pnasnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfebcfb8d3b96dcead030f1307e501ef01d067da Binary files /dev/null and b/timm/models/__pycache__/pnasnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/pnasnet.cpython-311.pyc b/timm/models/__pycache__/pnasnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7645fad6197b7f36f508284861c54f2a3a7c8342 Binary files /dev/null and b/timm/models/__pycache__/pnasnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/pnasnet.cpython-313.pyc b/timm/models/__pycache__/pnasnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84df90bad94cb2d4f27177f3d38a4c6ea3704e04 Binary files /dev/null and b/timm/models/__pycache__/pnasnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/pnasnet.cpython-38.pyc b/timm/models/__pycache__/pnasnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3d9c461bb46cbb0bf37e96618f587e6d4fcb27b Binary files /dev/null and b/timm/models/__pycache__/pnasnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/registry.cpython-310.pyc b/timm/models/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab346bbc4e9eb499087d036f110643474bf6345b Binary files /dev/null and b/timm/models/__pycache__/registry.cpython-310.pyc differ diff --git a/timm/models/__pycache__/registry.cpython-311.pyc b/timm/models/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed8674f217f4fd48faaa3f7783057f4b9a0d666a Binary files /dev/null and b/timm/models/__pycache__/registry.cpython-311.pyc differ diff --git a/timm/models/__pycache__/registry.cpython-313.pyc b/timm/models/__pycache__/registry.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7daf09ca02da7944b45467bf3907a820847b81ff Binary files /dev/null and b/timm/models/__pycache__/registry.cpython-313.pyc differ diff --git a/timm/models/__pycache__/registry.cpython-38.pyc b/timm/models/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ae8634725a030f876846f1c71b32e0400758da9 Binary files /dev/null and b/timm/models/__pycache__/registry.cpython-38.pyc differ diff --git a/timm/models/__pycache__/regnet.cpython-310.pyc b/timm/models/__pycache__/regnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7f6539aea0f086965b0e54118b9c7bbfd226e5d Binary files /dev/null and b/timm/models/__pycache__/regnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/regnet.cpython-311.pyc b/timm/models/__pycache__/regnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b7248bd4113f0caa67c79822a1c095faec2209a Binary files /dev/null and b/timm/models/__pycache__/regnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/regnet.cpython-313.pyc b/timm/models/__pycache__/regnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de66ee6f74ae9804ad3d437055a6b4705f4adc47 Binary files /dev/null and b/timm/models/__pycache__/regnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/regnet.cpython-38.pyc b/timm/models/__pycache__/regnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..084c3fdca3be07cd3afaba0f1b7f330364a39644 Binary files /dev/null and b/timm/models/__pycache__/regnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/res2net.cpython-310.pyc b/timm/models/__pycache__/res2net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..391c8d030dad1fae00a9da9444aca67cb5281a4c Binary files /dev/null and b/timm/models/__pycache__/res2net.cpython-310.pyc differ diff --git a/timm/models/__pycache__/res2net.cpython-311.pyc b/timm/models/__pycache__/res2net.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af6e4c2411deb4c408e9037a06b65e0b4f04b8d4 Binary files /dev/null and b/timm/models/__pycache__/res2net.cpython-311.pyc differ diff --git a/timm/models/__pycache__/res2net.cpython-313.pyc b/timm/models/__pycache__/res2net.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20a318a8c61d9b4c848028a2486d9f6ccf17587f Binary files /dev/null and b/timm/models/__pycache__/res2net.cpython-313.pyc differ diff --git a/timm/models/__pycache__/res2net.cpython-38.pyc b/timm/models/__pycache__/res2net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..211bcfc37c21dd8b74b032f8a8b637ad1f55ea35 Binary files /dev/null and b/timm/models/__pycache__/res2net.cpython-38.pyc differ diff --git a/timm/models/__pycache__/resnest.cpython-310.pyc b/timm/models/__pycache__/resnest.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..523c9e5e25139ec194f6359724dd7223f87984a9 Binary files /dev/null and b/timm/models/__pycache__/resnest.cpython-310.pyc differ diff --git a/timm/models/__pycache__/resnest.cpython-311.pyc b/timm/models/__pycache__/resnest.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b31df616638899cc9577f3f111562707c9fc977 Binary files /dev/null and b/timm/models/__pycache__/resnest.cpython-311.pyc differ diff --git a/timm/models/__pycache__/resnest.cpython-313.pyc b/timm/models/__pycache__/resnest.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156da836bd0d4d33fa1c6eb4dbfdbc0d27e6905a Binary files /dev/null and b/timm/models/__pycache__/resnest.cpython-313.pyc differ diff --git a/timm/models/__pycache__/resnest.cpython-38.pyc b/timm/models/__pycache__/resnest.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5be652b4f14e1f0792b4edd33b8a51de1a252f91 Binary files /dev/null and b/timm/models/__pycache__/resnest.cpython-38.pyc differ diff --git a/timm/models/__pycache__/resnet.cpython-310.pyc b/timm/models/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..809120af2ec7a995f085ce4d3b7e92a23c1be0d6 Binary files /dev/null and b/timm/models/__pycache__/resnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/resnet.cpython-311.pyc b/timm/models/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e00638d3a5b2f40e456f563485af8b39556322b9 Binary files /dev/null and b/timm/models/__pycache__/resnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/resnet.cpython-313.pyc b/timm/models/__pycache__/resnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c11f2faf6d639e3583cd0750db007f003d0f2475 Binary files /dev/null and b/timm/models/__pycache__/resnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/resnet.cpython-38.pyc b/timm/models/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c50ab45e9460ac7155cfcbac35968d8d4115c5e4 Binary files /dev/null and b/timm/models/__pycache__/resnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/resnetv2.cpython-310.pyc b/timm/models/__pycache__/resnetv2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60483d2efa2bb23e23acd64d4c67be576f3beddd Binary files /dev/null and b/timm/models/__pycache__/resnetv2.cpython-310.pyc differ diff --git a/timm/models/__pycache__/resnetv2.cpython-311.pyc b/timm/models/__pycache__/resnetv2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c492b3286c9c1170ea5fb3e567111ba735a6b2a Binary files /dev/null and b/timm/models/__pycache__/resnetv2.cpython-311.pyc differ diff --git a/timm/models/__pycache__/resnetv2.cpython-313.pyc b/timm/models/__pycache__/resnetv2.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdb98a529793b52ebfeba9ad2db603a57134be17 Binary files /dev/null and b/timm/models/__pycache__/resnetv2.cpython-313.pyc differ diff --git a/timm/models/__pycache__/resnetv2.cpython-38.pyc b/timm/models/__pycache__/resnetv2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1f9a940949e47a840ac1cfbd0922459c12b8f42 Binary files /dev/null and b/timm/models/__pycache__/resnetv2.cpython-38.pyc differ diff --git a/timm/models/__pycache__/rexnet.cpython-310.pyc b/timm/models/__pycache__/rexnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2944f69953038ac1a16733da0c3dfc29e1a9b966 Binary files /dev/null and b/timm/models/__pycache__/rexnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/rexnet.cpython-311.pyc b/timm/models/__pycache__/rexnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0623fc39f99a6556d8f0c9ba0237e4f00375aff Binary files /dev/null and b/timm/models/__pycache__/rexnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/rexnet.cpython-313.pyc b/timm/models/__pycache__/rexnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71cdf0b6e15b8f0bc2f232eee2f0efbba06b2214 Binary files /dev/null and b/timm/models/__pycache__/rexnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/rexnet.cpython-38.pyc b/timm/models/__pycache__/rexnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e902be82cac0228f45d03f2b82d41893f50b4e9 Binary files /dev/null and b/timm/models/__pycache__/rexnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/selecsls.cpython-310.pyc b/timm/models/__pycache__/selecsls.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cc858e19e87d6f18013fab10f6eba2a8a8c1fa7 Binary files /dev/null and b/timm/models/__pycache__/selecsls.cpython-310.pyc differ diff --git a/timm/models/__pycache__/selecsls.cpython-311.pyc b/timm/models/__pycache__/selecsls.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89a38dba53f725116b92aa5f9e105edc853e60b8 Binary files /dev/null and b/timm/models/__pycache__/selecsls.cpython-311.pyc differ diff --git a/timm/models/__pycache__/selecsls.cpython-313.pyc b/timm/models/__pycache__/selecsls.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31963a718832c6288270c9cd1f93300c80f707f8 Binary files /dev/null and b/timm/models/__pycache__/selecsls.cpython-313.pyc differ diff --git a/timm/models/__pycache__/selecsls.cpython-38.pyc b/timm/models/__pycache__/selecsls.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9796617ba361395b92ccd8cce74ba8b873003fc Binary files /dev/null and b/timm/models/__pycache__/selecsls.cpython-38.pyc differ diff --git a/timm/models/__pycache__/senet.cpython-310.pyc b/timm/models/__pycache__/senet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..979b892203f1a684588d72aef0001646f4a5156a Binary files /dev/null and b/timm/models/__pycache__/senet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/senet.cpython-311.pyc b/timm/models/__pycache__/senet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a47c20f07d9fd6b03d3c33ee668cde25540aeafd Binary files /dev/null and b/timm/models/__pycache__/senet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/senet.cpython-313.pyc b/timm/models/__pycache__/senet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25a639bbe747978b77a0da877d596d88b8c3119c Binary files /dev/null and b/timm/models/__pycache__/senet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/senet.cpython-38.pyc b/timm/models/__pycache__/senet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9d2249f72c101529ed7ed1cb4855db265919b22 Binary files /dev/null and b/timm/models/__pycache__/senet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/sknet.cpython-310.pyc b/timm/models/__pycache__/sknet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c917b1a750243334153730097bf9a3413dd39c31 Binary files /dev/null and b/timm/models/__pycache__/sknet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/sknet.cpython-311.pyc b/timm/models/__pycache__/sknet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05ab5e619b8daffb8cc865b3134325a5debe2bc7 Binary files /dev/null and b/timm/models/__pycache__/sknet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/sknet.cpython-313.pyc b/timm/models/__pycache__/sknet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30f77ddb2051c6be0865058bbada0bbc34c473c4 Binary files /dev/null and b/timm/models/__pycache__/sknet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/sknet.cpython-38.pyc b/timm/models/__pycache__/sknet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d92eef6c2a76fb694af978c761477011587da0e3 Binary files /dev/null and b/timm/models/__pycache__/sknet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/swin_transformer.cpython-310.pyc b/timm/models/__pycache__/swin_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7475fd61dc4350c9db89cfada505331ba148e3c Binary files /dev/null and b/timm/models/__pycache__/swin_transformer.cpython-310.pyc differ diff --git a/timm/models/__pycache__/swin_transformer.cpython-311.pyc b/timm/models/__pycache__/swin_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..668f638dedab5ed77cb8792a3170560f65658828 Binary files /dev/null and b/timm/models/__pycache__/swin_transformer.cpython-311.pyc differ diff --git a/timm/models/__pycache__/swin_transformer.cpython-313.pyc b/timm/models/__pycache__/swin_transformer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe5cea710c48cc05019aa49b55cd104760e45ff9 Binary files /dev/null and b/timm/models/__pycache__/swin_transformer.cpython-313.pyc differ diff --git a/timm/models/__pycache__/swin_transformer.cpython-38.pyc b/timm/models/__pycache__/swin_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d75fba7aa22f618acc3d931c452fdc7de2a868e Binary files /dev/null and b/timm/models/__pycache__/swin_transformer.cpython-38.pyc differ diff --git a/timm/models/__pycache__/tnt.cpython-310.pyc b/timm/models/__pycache__/tnt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fca1d107ff9d6c03f130e775969cc54433ca7f8 Binary files /dev/null and b/timm/models/__pycache__/tnt.cpython-310.pyc differ diff --git a/timm/models/__pycache__/tnt.cpython-311.pyc b/timm/models/__pycache__/tnt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bc65d3f9b1725e23ae9d4191ced571fb88c1960 Binary files /dev/null and b/timm/models/__pycache__/tnt.cpython-311.pyc differ diff --git a/timm/models/__pycache__/tnt.cpython-313.pyc b/timm/models/__pycache__/tnt.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9321a5bd4b73c56680ea452f460007050126f64d Binary files /dev/null and b/timm/models/__pycache__/tnt.cpython-313.pyc differ diff --git a/timm/models/__pycache__/tnt.cpython-38.pyc b/timm/models/__pycache__/tnt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9377b5638c5ba6e1355bc6f96386009f94873c68 Binary files /dev/null and b/timm/models/__pycache__/tnt.cpython-38.pyc differ diff --git a/timm/models/__pycache__/tresnet.cpython-310.pyc b/timm/models/__pycache__/tresnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a94f31dec1ebc0296176e5e748002dbe645e46b Binary files /dev/null and b/timm/models/__pycache__/tresnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/tresnet.cpython-311.pyc b/timm/models/__pycache__/tresnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa80a343885c9138ebed3da5c01d01d01b5434d1 Binary files /dev/null and b/timm/models/__pycache__/tresnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/tresnet.cpython-313.pyc b/timm/models/__pycache__/tresnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f51783cff6172d9d2de78dad91a9f16a322598e Binary files /dev/null and b/timm/models/__pycache__/tresnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/tresnet.cpython-38.pyc b/timm/models/__pycache__/tresnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27e7c10c02b9810879b56b909b7371653161bd2a Binary files /dev/null and b/timm/models/__pycache__/tresnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/twins.cpython-310.pyc b/timm/models/__pycache__/twins.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d3525c677b27e3edf0424f73385672c873bacc9 Binary files /dev/null and b/timm/models/__pycache__/twins.cpython-310.pyc differ diff --git a/timm/models/__pycache__/twins.cpython-311.pyc b/timm/models/__pycache__/twins.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c37707cc0e4c9112404edadcbf294fba5d79ca29 Binary files /dev/null and b/timm/models/__pycache__/twins.cpython-311.pyc differ diff --git a/timm/models/__pycache__/twins.cpython-313.pyc b/timm/models/__pycache__/twins.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6e31e33a59fa6378e702826a38fadd920eec6b9 Binary files /dev/null and b/timm/models/__pycache__/twins.cpython-313.pyc differ diff --git a/timm/models/__pycache__/twins.cpython-38.pyc b/timm/models/__pycache__/twins.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a994fab65ff5c244556446b3949d66fd8e135fbd Binary files /dev/null and b/timm/models/__pycache__/twins.cpython-38.pyc differ diff --git a/timm/models/__pycache__/vgg.cpython-310.pyc b/timm/models/__pycache__/vgg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..530c562977522e45229b1d4307e816a061c0f70e Binary files /dev/null and b/timm/models/__pycache__/vgg.cpython-310.pyc differ diff --git a/timm/models/__pycache__/vgg.cpython-311.pyc b/timm/models/__pycache__/vgg.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bacb07492d261ee801acbccf7a7cd6698938f6e6 Binary files /dev/null and b/timm/models/__pycache__/vgg.cpython-311.pyc differ diff --git a/timm/models/__pycache__/vgg.cpython-313.pyc b/timm/models/__pycache__/vgg.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b225e60e635c8c6bd4eef66957ec5d211c0493ea Binary files /dev/null and b/timm/models/__pycache__/vgg.cpython-313.pyc differ diff --git a/timm/models/__pycache__/vgg.cpython-38.pyc b/timm/models/__pycache__/vgg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43849432db74a092f55954e490258ff06b1a6bc4 Binary files /dev/null and b/timm/models/__pycache__/vgg.cpython-38.pyc differ diff --git a/timm/models/__pycache__/visformer.cpython-310.pyc b/timm/models/__pycache__/visformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1753dd42233a820df9f6fb595e66ecf5712d349f Binary files /dev/null and b/timm/models/__pycache__/visformer.cpython-310.pyc differ diff --git a/timm/models/__pycache__/visformer.cpython-311.pyc b/timm/models/__pycache__/visformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39a19192345bb4c1c01438a10f06a28f68a997c3 Binary files /dev/null and b/timm/models/__pycache__/visformer.cpython-311.pyc differ diff --git a/timm/models/__pycache__/visformer.cpython-313.pyc b/timm/models/__pycache__/visformer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1859e297407869a39c3ae38de620a8f043350404 Binary files /dev/null and b/timm/models/__pycache__/visformer.cpython-313.pyc differ diff --git a/timm/models/__pycache__/visformer.cpython-38.pyc b/timm/models/__pycache__/visformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..702493e28f3d9380849fbacc394b262c6c69f7a0 Binary files /dev/null and b/timm/models/__pycache__/visformer.cpython-38.pyc differ diff --git a/timm/models/__pycache__/vision_transformer.cpython-310.pyc b/timm/models/__pycache__/vision_transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..608d8597de428891cb0c3c5b369260d85a33b6e9 Binary files /dev/null and b/timm/models/__pycache__/vision_transformer.cpython-310.pyc differ diff --git a/timm/models/__pycache__/vision_transformer.cpython-311.pyc b/timm/models/__pycache__/vision_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86aff7811d857dad98e1a748cce0975d57734bd3 Binary files /dev/null and b/timm/models/__pycache__/vision_transformer.cpython-311.pyc differ diff --git a/timm/models/__pycache__/vision_transformer.cpython-313.pyc b/timm/models/__pycache__/vision_transformer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..984f7f2e9487aa64498f7ad97d8caa5f60cde1a2 Binary files /dev/null and b/timm/models/__pycache__/vision_transformer.cpython-313.pyc differ diff --git a/timm/models/__pycache__/vision_transformer.cpython-38.pyc b/timm/models/__pycache__/vision_transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a521e80b9ab344099bd339b7f7f3eac51ae3abfb Binary files /dev/null and b/timm/models/__pycache__/vision_transformer.cpython-38.pyc differ diff --git a/timm/models/__pycache__/vision_transformer_hybrid.cpython-310.pyc b/timm/models/__pycache__/vision_transformer_hybrid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1e6d03fa3582d7912ee3f33c5fd169be8c67e70 Binary files /dev/null and b/timm/models/__pycache__/vision_transformer_hybrid.cpython-310.pyc differ diff --git a/timm/models/__pycache__/vision_transformer_hybrid.cpython-311.pyc b/timm/models/__pycache__/vision_transformer_hybrid.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6453c11c2bfd4bf99e46919d262b89b74b8966e4 Binary files /dev/null and b/timm/models/__pycache__/vision_transformer_hybrid.cpython-311.pyc differ diff --git a/timm/models/__pycache__/vision_transformer_hybrid.cpython-313.pyc b/timm/models/__pycache__/vision_transformer_hybrid.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5797a1a10b55744ab3e3cf4d3bb2b46e70bd70cb Binary files /dev/null and b/timm/models/__pycache__/vision_transformer_hybrid.cpython-313.pyc differ diff --git a/timm/models/__pycache__/vision_transformer_hybrid.cpython-38.pyc b/timm/models/__pycache__/vision_transformer_hybrid.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..474079d87d7b8f478a4a66ea0d82a79390f75810 Binary files /dev/null and b/timm/models/__pycache__/vision_transformer_hybrid.cpython-38.pyc differ diff --git a/timm/models/__pycache__/vovnet.cpython-310.pyc b/timm/models/__pycache__/vovnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c360625cbbdafe52b011b329e6e8f03784795d2 Binary files /dev/null and b/timm/models/__pycache__/vovnet.cpython-310.pyc differ diff --git a/timm/models/__pycache__/vovnet.cpython-311.pyc b/timm/models/__pycache__/vovnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2f9e77557db74c8317971be402f5cb7ec79d461 Binary files /dev/null and b/timm/models/__pycache__/vovnet.cpython-311.pyc differ diff --git a/timm/models/__pycache__/vovnet.cpython-313.pyc b/timm/models/__pycache__/vovnet.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b412eb114710785a4bea6f5e941b09e77b5a1e85 Binary files /dev/null and b/timm/models/__pycache__/vovnet.cpython-313.pyc differ diff --git a/timm/models/__pycache__/vovnet.cpython-38.pyc b/timm/models/__pycache__/vovnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..399310504707be8da76dc9de567a87b52de54abb Binary files /dev/null and b/timm/models/__pycache__/vovnet.cpython-38.pyc differ diff --git a/timm/models/__pycache__/xception.cpython-310.pyc b/timm/models/__pycache__/xception.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..866dbc60ecdf7e8187d08023d51fb8e82b0be95f Binary files /dev/null and b/timm/models/__pycache__/xception.cpython-310.pyc differ diff --git a/timm/models/__pycache__/xception.cpython-311.pyc b/timm/models/__pycache__/xception.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bd3932b9d93cf6a37533610b63789bc0374534f Binary files /dev/null and b/timm/models/__pycache__/xception.cpython-311.pyc differ diff --git a/timm/models/__pycache__/xception.cpython-313.pyc b/timm/models/__pycache__/xception.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aa6d183951b5e7887eecd7989698814170777d7 Binary files /dev/null and b/timm/models/__pycache__/xception.cpython-313.pyc differ diff --git a/timm/models/__pycache__/xception.cpython-38.pyc b/timm/models/__pycache__/xception.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0e6ef361bced0e0a3fcf54a7a3d004d5569322c Binary files /dev/null and b/timm/models/__pycache__/xception.cpython-38.pyc differ diff --git a/timm/models/__pycache__/xception_aligned.cpython-310.pyc b/timm/models/__pycache__/xception_aligned.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dca0cda6fe175cc36e22f4dea4a086981e7ffdb2 Binary files /dev/null and b/timm/models/__pycache__/xception_aligned.cpython-310.pyc differ diff --git a/timm/models/__pycache__/xception_aligned.cpython-311.pyc b/timm/models/__pycache__/xception_aligned.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ac66759405719611d33f9ff1cc1402bd13abfc0 Binary files /dev/null and b/timm/models/__pycache__/xception_aligned.cpython-311.pyc differ diff --git a/timm/models/__pycache__/xception_aligned.cpython-313.pyc b/timm/models/__pycache__/xception_aligned.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4aee8dd628374d34b7b0cb465f31451d6b83696 Binary files /dev/null and b/timm/models/__pycache__/xception_aligned.cpython-313.pyc differ diff --git a/timm/models/__pycache__/xception_aligned.cpython-38.pyc b/timm/models/__pycache__/xception_aligned.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a48377c6020eed54b9476097b62346e3482e7b5 Binary files /dev/null and b/timm/models/__pycache__/xception_aligned.cpython-38.pyc differ diff --git a/timm/models/__pycache__/xcit.cpython-38.pyc b/timm/models/__pycache__/xcit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3c67b8b437c7b20c3e67932d0f4f384eaeb7aae Binary files /dev/null and b/timm/models/__pycache__/xcit.cpython-38.pyc differ diff --git a/timm/models/beit.py b/timm/models/beit.py new file mode 100644 index 0000000000000000000000000000000000000000..f644b657a01c1b10099df280382964a575aa3267 --- /dev/null +++ b/timm/models/beit.py @@ -0,0 +1,416 @@ +""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) + +Model from official source: https://github.com/microsoft/unilm/tree/master/beit + +At this point only the 1k fine-tuned classification weights and model configs have been added, +see original source above for pre-training models and procedure. + +Modifications by / Copyright 2021 Ross Wightman, original copyrights below +""" +# -------------------------------------------------------- +# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +# Github source: https://github.com/microsoft/unilm/tree/master/beit +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# By Hangbo Bao +# Based on timm and DeiT code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from .registry import register_model +from .vision_transformer import checkpoint_filter_fn + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'beit_base_patch16_224': _cfg( + url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), + 'beit_base_patch16_384': _cfg( + url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, + ), + 'beit_base_patch16_224_in22k': _cfg( + url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth', + num_classes=21841, + ), + 'beit_large_patch16_224': _cfg( + url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), + 'beit_large_patch16_384': _cfg( + url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, + ), + 'beit_large_patch16_512': _cfg( + url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, + ), + 'beit_large_patch16_224_in22k': _cfg( + url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth', + num_classes=21841, + ), +} + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None): + B, N, C = x.shape + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class Beit(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=None, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + use_mean_pooling=True, init_scale=0.001): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None) + for i in range(depth)]) + self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) + self.fix_init_weight() + if isinstance(self.head, nn.Linear): + trunc_normal_(self.head.weight, std=.02) + self.head.weight.data.mul_(init_scale) + self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias=rel_pos_bias) + + x = self.norm(x) + if self.fc_norm is not None: + t = x[:, 1:, :] + return self.fc_norm(t.mean(1)) + else: + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_beit(variant, pretrained=False, default_cfg=None, **kwargs): + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Beit models.') + + model = build_model_with_cfg( + Beit, variant, pretrained, + default_cfg=default_cfg, + # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def beit_base_patch16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) + model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def beit_base_patch16_384(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) + model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def beit_base_patch16_224_in22k(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs) + model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def beit_large_patch16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) + model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def beit_large_patch16_384(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) + model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def beit_large_patch16_512(pretrained=False, **kwargs): + model_kwargs = dict( + img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) + model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def beit_large_patch16_224_in22k(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs) + model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py new file mode 100644 index 0000000000000000000000000000000000000000..73c6811b9ce77aad8e11190e6a2d7599b1bb5c23 --- /dev/null +++ b/timm/models/byoanet.py @@ -0,0 +1,437 @@ +""" Bring-Your-Own-Attention Network + +A flexible network w/ dataclass based config for stacking NN blocks including +self-attention (or similar) layers. + +Currently used to implement experimential variants of: + * Bottleneck Transformers + * Lambda ResNets + * HaloNets + +Consider all of the models definitions here as experimental WIP and likely to change. + +Hacked together by / copyright Ross Wightman, 2021. +""" +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks +from .helpers import build_model_with_cfg +from .registry import register_model + +__all__ = [] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', + 'fixed_input_size': False, 'min_input_size': (3, 224, 224), + **kwargs + } + + +default_cfgs = { + # GPU-Efficient (ResNet) weights + 'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'botnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'eca_botnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + + 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'eca_halonext26ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + + 'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + 'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)), + 'eca_lambda_resnext26ts': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), + + 'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + + 'rednet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'rednet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), +} + + +model_cfgs = dict( + + botnet26t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + fixed_input_size=True, + self_attn_layer='bottleneck', + self_attn_kwargs=dict() + ), + botnet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='', + num_features=0, + fixed_input_size=True, + act_layer='silu', + self_attn_layer='bottleneck', + self_attn_kwargs=dict() + ), + eca_botnext26ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + fixed_input_size=True, + act_layer='silu', + attn_layer='eca', + self_attn_layer='bottleneck', + self_attn_kwargs=dict() + ), + + halonet_h1=ByoModelCfg( + blocks=( + ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), + ), + stem_chs=64, + stem_type='7x7', + stem_pool='maxpool', + num_features=0, + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=3), + ), + halonet_h1_c4c5=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0), + ByoBlockCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=3), + ), + halonet26t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res + ), + halonet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=2) + ), + eca_halonext26ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res + ), + + lambda_resnet26t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='lambda', + self_attn_kwargs=dict() + ), + lambda_resnet50t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + self_attn_layer='lambda', + self_attn_kwargs=dict() + ), + eca_lambda_resnext26ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='lambda', + self_attn_kwargs=dict() + ), + + swinnet26t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + fixed_input_size=True, + self_attn_layer='swin', + self_attn_kwargs=dict(win_size=8) + ), + swinnet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + fixed_input_size=True, + act_layer='silu', + self_attn_layer='swin', + self_attn_kwargs=dict(win_size=8) + ), + eca_swinnext26ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + fixed_input_size=True, + act_layer='silu', + attn_layer='eca', + self_attn_layer='swin', + self_attn_kwargs=dict(win_size=8) + ), + + + rednet26t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', # FIXME RedNet uses involution in middle of stem + stem_pool='maxpool', + num_features=0, + self_attn_layer='involution', + self_attn_kwargs=dict() + ), + rednet50ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + self_attn_layer='involution', + self_attn_kwargs=dict() + ), +) + + +def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + ByobNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def botnet26t_256(pretrained=False, **kwargs): + """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs) + + +@register_model +def botnet50ts_256(pretrained=False, **kwargs): + """ Bottleneck Transformer w/ ResNet50-T backbone. Bottleneck attn in final stage. + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs) + + +@register_model +def eca_botnext26ts_256(pretrained=False, **kwargs): + """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs) + + +@register_model +def halonet_h1(pretrained=False, **kwargs): + """ HaloNet-H1. Halo attention in all stages as per the paper. + + This runs very slowly, param count lower than paper --> something is wrong. + """ + return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs) + + +@register_model +def halonet_h1_c4c5(pretrained=False, **kwargs): + """ HaloNet-H1 config w/ attention in last two stages. + """ + return _create_byoanet('halonet_h1_c4c5', pretrained=pretrained, **kwargs) + + +@register_model +def halonet26t(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage + """ + return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs) + + +@register_model +def halonet50ts(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet50-t backbone, Hallo attention in final stage + """ + return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs) + + +@register_model +def eca_halonext26ts(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage + """ + return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs) + + +@register_model +def lambda_resnet26t(pretrained=False, **kwargs): + """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. + """ + return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs) + + +@register_model +def lambda_resnet50t(pretrained=False, **kwargs): + """ Lambda-ResNet-50T. Lambda layers in one C4 stage and all C5. + """ + return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def eca_lambda_resnext26ts(pretrained=False, **kwargs): + """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. + """ + return _create_byoanet('eca_lambda_resnext26ts', pretrained=pretrained, **kwargs) + + +@register_model +def swinnet26t_256(pretrained=False, **kwargs): + """ + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('swinnet26t_256', 'swinnet26t', pretrained=pretrained, **kwargs) + + +@register_model +def swinnet50ts_256(pretrained=False, **kwargs): + """ + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs) + + +@register_model +def eca_swinnext26ts_256(pretrained=False, **kwargs): + """ + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('eca_swinnext26ts_256', 'eca_swinnext26ts', pretrained=pretrained, **kwargs) + + +@register_model +def rednet26t(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('rednet26t', pretrained=pretrained, **kwargs) + + +@register_model +def rednet50ts(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('rednet50ts', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py new file mode 100644 index 0000000000000000000000000000000000000000..38ff6615ed1bd80603824388c808f020d5862571 --- /dev/null +++ b/timm/models/byobnet.py @@ -0,0 +1,1156 @@ +""" Bring-Your-Own-Blocks Network + +A flexible network w/ dataclass based config for stacking those NN blocks. + +This model is currently used to implement the following networks: + +GPU Efficient (ResNets) - gernet_l/m/s (original versions called genet, but this was already used (by SENet author)). +Paper: `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 +Code and weights: https://github.com/idstcv/GPU-Efficient-Networks, licensed Apache 2.0 + +RepVGG - repvgg_* +Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 +Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT + +In all cases the models have been modified to fit within the design of ByobNet. I've remapped +the original weights and verified accuracies. + +For GPU Efficient nets, I used the original names for the blocks since they were for the most part +the same as original residual blocks in ResNe(X)t, DarkNet, and other existing models. Note also some +changes introduced in RegNet were also present in the stem and bottleneck blocks for this model. + +A significant number of different network archs can be implemented here, including variants of the +above nets that include attention. + +Hacked together by / copyright Ross Wightman, 2021. +""" +import math +from dataclasses import dataclass, field, replace +from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ + create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple +from .registry import register_model + +__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + # GPU-Efficient (ResNet) weights + 'gernet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth'), + 'gernet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth'), + 'gernet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + + # RepVGG weights + 'repvgg_a2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_a2-c1ee6d2b.pth', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + 'repvgg_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b0-80ac3f1b.pth', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + 'repvgg_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b1-77ca2989.pth', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + 'repvgg_b1g4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b1g4-abde5d92.pth', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + 'repvgg_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b2-25b7494e.pth', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + 'repvgg_b2g4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b2g4-165a85f2.pth', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + 'repvgg_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3-199bc50d.pth', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + 'repvgg_b3g4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3g4-73c370bf.pth', + first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + + # experimental configs + 'resnet51q': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth', + first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8), + test_input_size=(3, 288, 288), crop_pct=1.0), + 'resnet61q': _cfg( + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + 'geresnet50t': _cfg( + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + 'gcresnet50t': _cfg( + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + + 'gcresnext26ts': _cfg( + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'), + 'bat_resnext26ts': _cfg( + first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic', + min_input_size=(3, 256, 256)), +} + + +@dataclass +class ByoBlockCfg: + type: Union[str, nn.Module] + d: int # block depth (number of block repeats in stage) + c: int # number of output channels for each block in stage + s: int = 2 # stride of stage (first block) + gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1 + br: float = 1. # bottleneck-ratio of blocks in stage + + # NOTE: these config items override the model cfgs that are applied to all blocks by default + attn_layer: Optional[str] = None + attn_kwargs: Optional[Dict[str, Any]] = None + self_attn_layer: Optional[str] = None + self_attn_kwargs: Optional[Dict[str, Any]] = None + block_kwargs: Optional[Dict[str, Any]] = None + + +@dataclass +class ByoModelCfg: + blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...] + downsample: str = 'conv1x1' + stem_type: str = '3x3' + stem_pool: Optional[str] = 'maxpool' + stem_chs: int = 32 + width_factor: float = 1.0 + num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0 + zero_init_last_bn: bool = True + fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation + + act_layer: str = 'relu' + norm_layer: str = 'batchnorm' + + # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there + attn_layer: Optional[str] = None + attn_kwargs: dict = field(default_factory=lambda: dict()) + self_attn_layer: Optional[str] = None + self_attn_kwargs: dict = field(default_factory=lambda: dict()) + block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict()) + + +def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): + c = (64, 128, 256, 512) + group_size = 0 + if groups > 0: + group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0 + bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)]) + return bcfg + + +def interleave_blocks( + types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs +) -> Tuple[ByoBlockCfg]: + """ interleave 2 block types in stack + """ + assert len(types) == 2 + if isinstance(every, int): + every = list(range(0 if first else every, d, every)) + if not every: + every = [d - 1] + set(every) + blocks = [] + for i in range(d): + block_type = types[1] if i in every else types[0] + blocks += [ByoBlockCfg(type=block_type, d=1, **kwargs)] + return tuple(blocks) + + +model_cfgs = dict( + gernet_l=ByoModelCfg( + blocks=( + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), + ), + stem_chs=32, + stem_pool=None, + num_features=2560, + ), + gernet_m=ByoModelCfg( + blocks=( + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), + ), + stem_chs=32, + stem_pool=None, + num_features=2560, + ), + gernet_s=ByoModelCfg( + blocks=( + ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), + ), + stem_chs=13, + stem_pool=None, + num_features=1920, + ), + + repvgg_a2=ByoModelCfg( + blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)), + stem_type='rep', + stem_chs=64, + ), + repvgg_b0=ByoModelCfg( + blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)), + stem_type='rep', + stem_chs=64, + ), + repvgg_b1=ByoModelCfg( + blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)), + stem_type='rep', + stem_chs=64, + ), + repvgg_b1g4=ByoModelCfg( + blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4), + stem_type='rep', + stem_chs=64, + ), + repvgg_b2=ByoModelCfg( + blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)), + stem_type='rep', + stem_chs=64, + ), + repvgg_b2g4=ByoModelCfg( + blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4), + stem_type='rep', + stem_chs=64, + ), + repvgg_b3=ByoModelCfg( + blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)), + stem_type='rep', + stem_chs=64, + ), + repvgg_b3g4=ByoModelCfg( + blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4), + stem_type='rep', + stem_chs=64, + ), + + # WARN: experimental, may vanish/change + resnet51q=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), + ), + stem_chs=128, + stem_type='quad2', + stem_pool=None, + num_features=2048, + act_layer='silu', + ), + + resnet61q=ByoModelCfg( + blocks=( + ByoBlockCfg(type='edge', d=1, c=256, s=1, gs=0, br=1.0, block_kwargs=dict()), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), + ), + stem_chs=128, + stem_type='quad', + stem_pool=None, + num_features=2048, + act_layer='silu', + block_kwargs=dict(extra_conv=True), + ), + + # WARN: experimental, may vanish/change + geresnet50t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool=None, + attn_layer='ge', + attn_kwargs=dict(extent=8, extra_params=True), + #attn_kwargs=dict(extent=8), + #block_kwargs=dict(attn_last=True) + ), + + # WARN: experimental, may vanish/change + gcresnet50t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool=None, + attn_layer='gc' + ), + + gcresnext26ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, gs=32, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='gc', + ), + + bat_resnext26ts=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='bat', + attn_kwargs=dict(block_size=8) + ), +) + + +@register_model +def gernet_l(pretrained=False, **kwargs): + """ GEResNet-Large (GENet-Large from official impl) + `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 + """ + return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs) + + +@register_model +def gernet_m(pretrained=False, **kwargs): + """ GEResNet-Medium (GENet-Normal from official impl) + `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 + """ + return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs) + + +@register_model +def gernet_s(pretrained=False, **kwargs): + """ EResNet-Small (GENet-Small from official impl) + `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090 + """ + return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_a2(pretrained=False, **kwargs): + """ RepVGG-A2 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b0(pretrained=False, **kwargs): + """ RepVGG-B0 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b1(pretrained=False, **kwargs): + """ RepVGG-B1 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b1g4(pretrained=False, **kwargs): + """ RepVGG-B1g4 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b2(pretrained=False, **kwargs): + """ RepVGG-B2 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b2g4(pretrained=False, **kwargs): + """ RepVGG-B2g4 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b3(pretrained=False, **kwargs): + """ RepVGG-B3 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs) + + +@register_model +def repvgg_b3g4(pretrained=False, **kwargs): + """ RepVGG-B3g4 + `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 + """ + return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs) + + +@register_model +def resnet51q(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('resnet51q', pretrained=pretrained, **kwargs) + + +@register_model +def resnet61q(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('resnet61q', pretrained=pretrained, **kwargs) + + +@register_model +def geresnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def gcresnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def gcresnext26ts(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('gcresnext26ts', pretrained=pretrained, **kwargs) + + +@register_model +def bat_resnext26ts(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('bat_resnext26ts', pretrained=pretrained, **kwargs) + + +def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]: + if not isinstance(stage_blocks_cfg, Sequence): + stage_blocks_cfg = (stage_blocks_cfg,) + block_cfgs = [] + for i, cfg in enumerate(stage_blocks_cfg): + block_cfgs += [replace(cfg, d=1) for _ in range(cfg.d)] + return block_cfgs + + +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +@dataclass +class LayerFn: + conv_norm_act: Callable = ConvBnAct + norm_act: Callable = BatchNormAct2d + act: Callable = nn.ReLU + attn: Optional[Callable] = None + self_attn: Optional[Callable] = None + + +class DownsampleAvg(nn.Module): + def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, layers: LayerFn = None): + """ AvgPool Downsampling as in 'D' ResNet variants.""" + super(DownsampleAvg, self).__init__() + layers = layers or LayerFn() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act) + + def forward(self, x): + return self.conv(self.pool(x)) + + +def create_downsample(downsample_type, layers: LayerFn, **kwargs): + if downsample_type == 'avg': + return DownsampleAvg(**kwargs) + else: + return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs) + + +class BasicBlock(nn.Module): + """ ResNet Basic Block - kxk + kxk + """ + + def __init__( + self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0, + downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None, + drop_path_rate=0.): + super(BasicBlock, self).__init__() + layers = layers or LayerFn() + mid_chs = make_divisible(out_chs * bottle_ratio) + groups = num_groups(group_size, mid_chs) + + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = create_downsample( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], + apply_act=False, layers=layers) + else: + self.shortcut = nn.Identity() + + self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0]) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) + self.conv2_kxk = layers.conv_norm_act( + mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.act = nn.Identity() if linear_out else layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn: bool = False): + if zero_init_last_bn: + nn.init.zeros_(self.conv2_kxk.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() + + def forward(self, x): + shortcut = self.shortcut(x) + + # residual path + x = self.conv1_kxk(x) + x = self.conv2_kxk(x) + x = self.attn(x) + x = self.drop_path(x) + + x = self.act(x + shortcut) + return x + + +class BottleneckBlock(nn.Module): + """ ResNet-like Bottleneck Block - 1x1 - kxk - 1x1 + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, + downsample='avg', attn_last=False, linear_out=False, extra_conv=False, layers: LayerFn = None, + drop_block=None, drop_path_rate=0.): + super(BottleneckBlock, self).__init__() + layers = layers or LayerFn() + mid_chs = make_divisible(out_chs * bottle_ratio) + groups = num_groups(group_size, mid_chs) + + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = create_downsample( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], + apply_act=False, layers=layers) + else: + self.shortcut = nn.Identity() + + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + self.conv2_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block) + self.conv2_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block) + if extra_conv: + self.conv2b_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block) + else: + self.conv2b_kxk = nn.Identity() + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) + self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.act = nn.Identity() if linear_out else layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn: bool = False): + if zero_init_last_bn: + nn.init.zeros_(self.conv3_1x1.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.conv1_1x1(x) + x = self.conv2_kxk(x) + x = self.conv2b_kxk(x) + x = self.attn(x) + x = self.conv3_1x1(x) + x = self.attn_last(x) + x = self.drop_path(x) + + x = self.act(x + shortcut) + return x + + +class DarkBlock(nn.Module): + """ DarkNet-like (1x1 + 3x3 w/ stride) block + + The GE-Net impl included a 1x1 + 3x3 block in their search space. It was not used in the feature models. + This block is pretty much a DarkNet block (also DenseNet) hence the name. Neither DarkNet or DenseNet + uses strides within the block (external 3x3 or maxpool downsampling is done in front of the block repeats). + + If one does want to use a lot of these blocks w/ stride, I'd recommend using the EdgeBlock (3x3 /w stride + 1x1) + for more optimal compute. + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, + downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None, + drop_path_rate=0.): + super(DarkBlock, self).__init__() + layers = layers or LayerFn() + mid_chs = make_divisible(out_chs * bottle_ratio) + groups = num_groups(group_size, mid_chs) + + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = create_downsample( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], + apply_act=False, layers=layers) + else: + self.shortcut = nn.Identity() + + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) + self.conv2_kxk = layers.conv_norm_act( + mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block, apply_act=False) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.act = nn.Identity() if linear_out else layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn: bool = False): + if zero_init_last_bn: + nn.init.zeros_(self.conv2_kxk.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.conv1_1x1(x) + x = self.attn(x) + x = self.conv2_kxk(x) + x = self.attn_last(x) + x = self.drop_path(x) + x = self.act(x + shortcut) + return x + + +class EdgeBlock(nn.Module): + """ EdgeResidual-like (3x3 + 1x1) block + + A two layer block like DarkBlock, but with the order of the 3x3 and 1x1 convs reversed. + Very similar to the EfficientNet Edge-Residual block but this block it ends with activations, is + intended to be used with either expansion or bottleneck contraction, and can use DW/group/non-grouped convs. + + FIXME is there a more common 3x3 + 1x1 conv block to name this after? + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, + downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None, + drop_block=None, drop_path_rate=0.): + super(EdgeBlock, self).__init__() + layers = layers or LayerFn() + mid_chs = make_divisible(out_chs * bottle_ratio) + groups = num_groups(group_size, mid_chs) + + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = create_downsample( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], + apply_act=False, layers=layers) + else: + self.shortcut = nn.Identity() + + self.conv1_kxk = layers.conv_norm_act( + in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) + self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.act = nn.Identity() if linear_out else layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn: bool = False): + if zero_init_last_bn: + nn.init.zeros_(self.conv2_1x1.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.conv1_kxk(x) + x = self.attn(x) + x = self.conv2_1x1(x) + x = self.attn_last(x) + x = self.drop_path(x) + x = self.act(x + shortcut) + return x + + +class RepVggBlock(nn.Module): + """ RepVGG Block. + + Adapted from impl at https://github.com/DingXiaoH/RepVGG + + This version does not currently support the deploy optimization. It is currently fixed in 'train' mode. + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, + downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + super(RepVggBlock, self).__init__() + layers = layers or LayerFn() + groups = num_groups(group_size, in_chs) + + use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1] + self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None + self.conv_kxk = layers.conv_norm_act( + in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block, apply_act=False) + self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False) + self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() + self.act = layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn: bool = False): + # NOTE this init overrides that base model init with specific changes for the block type + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + nn.init.normal_(m.weight, .1, .1) + nn.init.normal_(m.bias, 0, .1) + if hasattr(self.attn, 'reset_parameters'): + self.attn.reset_parameters() + + def forward(self, x): + if self.identity is None: + x = self.conv_1x1(x) + self.conv_kxk(x) + else: + identity = self.identity(x) + x = self.conv_1x1(x) + self.conv_kxk(x) + x = self.drop_path(x) # not in the paper / official impl, experimental + x = x + identity + x = self.attn(x) # no attn in the paper / official impl, experimental + x = self.act(x) + return x + + +class SelfAttnBlock(nn.Module): + """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1 + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, + downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None, + layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + super(SelfAttnBlock, self).__init__() + assert layers is not None + mid_chs = make_divisible(out_chs * bottle_ratio) + groups = num_groups(group_size, mid_chs) + + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = create_downsample( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], + apply_act=False, layers=layers) + else: + self.shortcut = nn.Identity() + + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + if extra_conv: + self.conv2_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block) + stride = 1 # striding done via conv if enabled + else: + self.conv2_kxk = nn.Identity() + opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size) + # FIXME need to dilate self attn to have dilated network support, moop moop + self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs) + self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity() + self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.act = nn.Identity() if linear_out else layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn: bool = False): + if zero_init_last_bn: + nn.init.zeros_(self.conv3_1x1.bn.weight) + if hasattr(self.self_attn, 'reset_parameters'): + self.self_attn.reset_parameters() + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.conv1_1x1(x) + x = self.conv2_kxk(x) + x = self.self_attn(x) + x = self.post_attn(x) + x = self.conv3_1x1(x) + x = self.drop_path(x) + + x = self.act(x + shortcut) + return x + + +_block_registry = dict( + basic=BasicBlock, + bottle=BottleneckBlock, + dark=DarkBlock, + edge=EdgeBlock, + rep=RepVggBlock, + self_attn=SelfAttnBlock, +) + + +def register_block(block_type:str, block_fn: nn.Module): + _block_registry[block_type] = block_fn + + +def create_block(block: Union[str, nn.Module], **kwargs): + if isinstance(block, (nn.Module, partial)): + return block(**kwargs) + assert block in _block_registry, f'Unknown block type ({block}' + return _block_registry[block](**kwargs) + + +class Stem(nn.Sequential): + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool', + num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None): + super().__init__() + assert stride in (2, 4) + layers = layers or LayerFn() + + if isinstance(out_chs, (list, tuple)): + num_rep = len(out_chs) + stem_chs = out_chs + else: + stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1] + + self.stride = stride + self.feature_info = [] # track intermediate features + prev_feat = '' + stem_strides = [2] + [1] * (num_rep - 1) + if stride == 4 and not pool: + # set last conv in stack to be strided if stride == 4 and no pooling layer + stem_strides[-1] = 2 + + num_act = num_rep if num_act is None else num_act + # if num_act < num_rep, first convs in stack won't have bn + act + stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act + prev_chs = in_chs + curr_stride = 1 + for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)): + layer_fn = layers.conv_norm_act if na else create_conv2d + conv_name = f'conv{i + 1}' + if i > 0 and s > 1: + self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) + self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s)) + prev_chs = ch + curr_stride *= s + prev_feat = conv_name + + if pool and 'max' in pool.lower(): + self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) + self.add_module('pool', nn.MaxPool2d(3, 2, 1)) + curr_stride *= 2 + prev_feat = 'pool' + + self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) + assert curr_stride == stride + + +def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None): + layers = layers or LayerFn() + assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3') + if 'quad' in stem_type: + # based on NFNet stem, stack of 4 3x3 convs + num_act = 2 if 'quad2' in stem_type else None + stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers) + elif 'tiered' in stem_type: + # 3x3 stack of 3 convs as in my ResNet-T + stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers) + elif 'deep' in stem_type: + # 3x3 stack of 3 convs as in ResNet-D + stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers) + elif 'rep' in stem_type: + stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers) + elif '7x7' in stem_type: + # 7x7 stem conv as in ResNet + if pool_type: + stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers) + else: + stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2) + else: + # 3x3 stem conv as in RegNet is the default + if pool_type: + stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers) + else: + stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2) + + if isinstance(stem, Stem): + feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info] + else: + feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)] + return stem, feature_info + + +def reduce_feat_size(feat_size, stride=2): + return None if feat_size is None else tuple([s // stride for s in feat_size]) + + +def override_kwargs(block_kwargs, model_kwargs): + """ Override model level attn/self-attn/block kwargs w/ block level + + NOTE: kwargs are NOT merged across levels, block_kwargs will fully replace model_kwargs + for the block if set to anything that isn't None. + + i.e. an empty block_kwargs dict will remove kwargs set at model level for that block + """ + out_kwargs = block_kwargs if block_kwargs is not None else model_kwargs + return out_kwargs or {} # make sure None isn't returned + + +def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg, ): + layer_fns = block_kwargs['layers'] + + # override attn layer / args with block local config + if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None: + # override attn layer config + if not block_cfg.attn_layer: + # empty string for attn_layer type will disable attn for this block + attn_layer = None + else: + attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs) + attn_layer = block_cfg.attn_layer or model_cfg.attn_layer + attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None + layer_fns = replace(layer_fns, attn=attn_layer) + + # override self-attn layer / args with block local cfg + if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None: + # override attn layer config + if not block_cfg.self_attn_layer: + # empty string for self_attn_layer type will disable attn for this block + self_attn_layer = None + else: + self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs) + self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer + self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \ + if self_attn_layer is not None else None + layer_fns = replace(layer_fns, self_attn=self_attn_layer) + + block_kwargs['layers'] = layer_fns + + # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set + block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs)) + + +def create_byob_stages( + cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any], + feat_size: Optional[int] = None, + layers: Optional[LayerFn] = None, + block_kwargs_fn: Optional[Callable] = update_block_kwargs): + + layers = layers or LayerFn() + feature_info = [] + block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks] + depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs] + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + dilation = 1 + net_stride = stem_feat['reduction'] + prev_chs = stem_feat['num_chs'] + prev_feat = stem_feat + stages = [] + for stage_idx, stage_block_cfgs in enumerate(block_cfgs): + stride = stage_block_cfgs[0].s + if stride != 1 and prev_feat: + feature_info.append(prev_feat) + if net_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + net_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + + blocks = [] + for block_idx, block_cfg in enumerate(stage_block_cfgs): + out_chs = make_divisible(block_cfg.c * cfg.width_factor) + group_size = block_cfg.gs + if isinstance(group_size, Callable): + group_size = group_size(out_chs, block_idx) + block_kwargs = dict( # Blocks used in this model must accept these arguments + in_chs=prev_chs, + out_chs=out_chs, + stride=stride if block_idx == 0 else 1, + dilation=(first_dilation, dilation), + group_size=group_size, + bottle_ratio=block_cfg.br, + downsample=cfg.downsample, + drop_path_rate=dpr[stage_idx][block_idx], + layers=layers, + ) + if block_cfg.type in ('self_attn',): + # add feat_size arg for blocks that support/need it + block_kwargs['feat_size'] = feat_size + block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg) + blocks += [create_block(block_cfg.type, **block_kwargs)] + first_dilation = dilation + prev_chs = out_chs + if stride > 1 and block_idx == 0: + feat_size = reduce_feat_size(feat_size, stride) + + stages += [nn.Sequential(*blocks)] + prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}') + + feature_info.append(prev_feat) + return nn.Sequential(*stages), feature_info + + +def get_layer_fns(cfg: ByoModelCfg): + act = get_act_layer(cfg.act_layer) + norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) + conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) + attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None + self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None + layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) + return layer_fn + + +class ByobNet(nn.Module): + """ 'Bring-your-own-blocks' Net + + A flexible network backbone that allows building model stem + blocks via + dataclass cfg definition w/ factory functions for module instantiation. + + Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act). + """ + def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + layers = get_layer_fns(cfg) + if cfg.fixed_input_size: + assert img_size is not None, 'img_size argument is required for fixed input size model' + feat_size = to_2tuple(img_size) if img_size is not None else None + + self.feature_info = [] + stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) + self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers) + self.feature_info.extend(stem_feat[:-1]) + feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction']) + + self.stages, stage_feat = create_byob_stages( + cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers, feat_size=feat_size) + self.feature_info.extend(stage_feat[:-1]) + + prev_chs = stage_feat[-1]['num_chs'] + if cfg.num_features: + self.num_features = int(round(cfg.width_factor * cfg.num_features)) + self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1) + else: + self.num_features = prev_chs + self.final_conv = nn.Identity() + self.feature_info += [ + dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')] + + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + for n, m in self.named_modules(): + _init_weights(m, n) + for m in self.modules(): + # call each block's weight init for block-specific overrides to init above + if hasattr(m, 'init_weights'): + m.init_weights(zero_init_last_bn=zero_init_last_bn) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.final_conv(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _init_weights(m, n=''): + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + +def _create_byobnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ByobNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) diff --git a/timm/models/cait.py b/timm/models/cait.py new file mode 100644 index 0000000000000000000000000000000000000000..69b4ba06c889196a19022d5938a73600734ebc2d --- /dev/null +++ b/timm/models/cait.py @@ -0,0 +1,394 @@ +""" Class-Attention in Image Transformers (CaiT) + +Paper: 'Going deeper with Image Transformers' - https://arxiv.org/abs/2103.17239 + +Original code and weights from https://github.com/facebookresearch/deit, copyright below + +""" +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +from copy import deepcopy + +import torch +import torch.nn as nn +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ +from .registry import register_model + + +__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + cait_xxs24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XXS24_224.pth', + input_size=(3, 224, 224), + ), + cait_xxs24_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XXS24_384.pth', + ), + cait_xxs36_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XXS36_224.pth', + input_size=(3, 224, 224), + ), + cait_xxs36_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XXS36_384.pth', + ), + cait_xs24_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XS24_384.pth', + ), + cait_s24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/S24_224.pth', + input_size=(3, 224, 224), + ), + cait_s24_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/S24_384.pth', + ), + cait_s36_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/S36_384.pth', + ), + cait_m36_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/M36_384.pth', + ), + cait_m48_448=_cfg( + url='https://dl.fbaipublicfiles.com/deit/M48_448.pth', + input_size=(3, 448, 448), + ), +) + + +class ClassAttn(nn.Module): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to do CA + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.k = nn.Linear(dim, dim, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + q = q * self.scale + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C) + x_cls = self.proj(x_cls) + x_cls = self.proj_drop(x_cls) + + return x_cls + + +class LayerScaleBlockClassAttn(nn.Module): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add CA and LayerScale + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn, + mlp_block=Mlp, init_values=1e-4): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = attn_block( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, x, x_cls): + u = torch.cat((x_cls, x), dim=1) + x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u))) + x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls))) + return x_cls + + +class TalkingHeadAttn(nn.Module): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + + self.num_heads = num_heads + + head_dim = dim // num_heads + + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Linear(dim, dim) + + self.proj_l = nn.Linear(num_heads, num_heads) + self.proj_w = nn.Linear(num_heads, num_heads) + + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) + + attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + attn = attn.softmax(dim=-1) + + attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScaleBlock(nn.Module): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to add layerScale + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn, + mlp_block=Mlp, init_values=1e-4): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = attn_block( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class Cait(nn.Module): + # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + # with slight modifications to adapt to our cait models + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + global_pool=None, + block_layers=LayerScaleBlock, + block_layers_token=LayerScaleBlockClassAttn, + patch_layer=PatchEmbed, + act_layer=nn.GELU, + attn_block=TalkingHeadAttn, + mlp_block=Mlp, + init_scale=1e-4, + attn_block_token_only=ClassAttn, + mlp_block_token_only=Mlp, + depth_token_only=2, + mlp_ratio_clstk=4.0 + ): + super().__init__() + + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = patch_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [drop_path_rate for i in range(depth)] + self.blocks = nn.ModuleList([ + block_layers( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale) + for i in range(depth)]) + + self.blocks_token_only = nn.ModuleList([ + block_layers_token( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, + drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer, + act_layer=act_layer, attn_block=attn_block_token_only, + mlp_block=mlp_block_token_only, init_values=init_scale) + for i in range(depth_token_only)]) + + self.norm = norm_layer(embed_dim) + + self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + + x = x + self.pos_embed + x = self.pos_drop(x) + + for i, blk in enumerate(self.blocks): + x = blk(x) + + for i, blk in enumerate(self.blocks_token_only): + cls_tokens = blk(x, cls_tokens) + + x = torch.cat((cls_tokens, x), dim=1) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model=None): + if 'model' in state_dict: + state_dict = state_dict['model'] + checkpoint_no_module = {} + for k, v in state_dict.items(): + checkpoint_no_module[k.replace('module.', '')] = v + return checkpoint_no_module + + +def _create_cait(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + Cait, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def cait_xxs24_224(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_xxs24_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_xxs36_224(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_xxs36_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_xs24_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_s24_224(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs) + model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_s24_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs) + model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_s36_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_scale=1e-6, **kwargs) + model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_m36_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_scale=1e-6, **kwargs) + model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args) + return model + + +@register_model +def cait_m48_448(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_scale=1e-6, **kwargs) + model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args) + return model diff --git a/timm/models/coat.py b/timm/models/coat.py new file mode 100644 index 0000000000000000000000000000000000000000..f071715a347120ce0ca4710eceddda61de28ce8f --- /dev/null +++ b/timm/models/coat.py @@ -0,0 +1,660 @@ +""" +CoaT architecture. + +Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399 + +Official CoaT code at: https://github.com/mlpc-ucsd/CoaT + +Modified from timm/models/vision_transformer.py +""" +from copy import deepcopy +from functools import partial +from typing import Tuple, List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .registry import register_model + + +__all__ = [ + "coat_tiny", + "coat_mini", + "coat_lite_tiny", + "coat_lite_mini", + "coat_lite_small" +] + + +def _cfg_coat(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed1.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'coat_tiny': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth' + ), + 'coat_mini': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth' + ), + 'coat_lite_tiny': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' + ), + 'coat_lite_mini': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' + ), + 'coat_lite_small': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth' + ), +} + + +class ConvRelPosEnc(nn.Module): + """ Convolutional relative position encoding. """ + def __init__(self, Ch, h, window): + """ + Initialization. + Ch: Channels per head. + h: Number of heads. + window: Window size(s) in convolutional relative positional encoding. It can have two forms: + 1. An integer of window size, which assigns all attention heads with the same window s + size in ConvRelPosEnc. + 2. A dict mapping window size to #attention head splits ( + e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}) + It will apply different window size to the attention head splits. + """ + super().__init__() + + if isinstance(window, int): + # Set the same window size for all attention heads. + window = {window: h} + self.window = window + elif isinstance(window, dict): + self.window = window + else: + raise ValueError() + + self.conv_list = nn.ModuleList() + self.head_splits = [] + for cur_window, cur_head_split in window.items(): + dilation = 1 + # Determine padding size. + # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338 + padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2 + cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch, + kernel_size=(cur_window, cur_window), + padding=(padding_size, padding_size), + dilation=(dilation, dilation), + groups=cur_head_split*Ch, + ) + self.conv_list.append(cur_conv) + self.head_splits.append(cur_head_split) + self.channel_splits = [x*Ch for x in self.head_splits] + + def forward(self, q, v, size: Tuple[int, int]): + B, h, N, Ch = q.shape + H, W = size + assert N == 1 + H * W + + # Convolutional relative position encoding. + q_img = q[:, :, 1:, :] # [B, h, H*W, Ch] + v_img = v[:, :, 1:, :] # [B, h, H*W, Ch] + + v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W) + v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels + conv_v_img_list = [] + for i, conv in enumerate(self.conv_list): + conv_v_img_list.append(conv(v_img_list[i])) + conv_v_img = torch.cat(conv_v_img_list, dim=1) + conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2) + + EV_hat = q_img * conv_v_img + EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch]. + return EV_hat + + +class FactorAtt_ConvRelPosEnc(nn.Module): + """ Factorized attention with convolutional relative position encoding class. """ + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + # Shared convolutional relative position encoding. + self.crpe = shared_crpe + + def forward(self, x, size: Tuple[int, int]): + B, N, C = x.shape + + # Generate Q, K, V. + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, h, N, Ch] + + # Factorized attention. + k_softmax = k.softmax(dim=2) + factor_att = k_softmax.transpose(-1, -2) @ v + factor_att = q @ factor_att + + # Convolutional relative position encoding. + crpe = self.crpe(q, v, size=size) # [B, h, N, Ch] + + # Merge and reshape. + x = self.scale * factor_att + crpe + x = x.transpose(1, 2).reshape(B, N, C) # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C] + + # Output projection. + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class ConvPosEnc(nn.Module): + """ Convolutional Position Encoding. + Note: This module is similar to the conditional position encoding in CPVT. + """ + def __init__(self, dim, k=3): + super(ConvPosEnc, self).__init__() + self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim) + + def forward(self, x, size: Tuple[int, int]): + B, N, C = x.shape + H, W = size + assert N == 1 + H * W + + # Extract CLS token and image tokens. + cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C] + + # Depthwise convolution. + feat = img_tokens.transpose(1, 2).view(B, C, H, W) + x = self.proj(feat) + feat + x = x.flatten(2).transpose(1, 2) + + # Combine with CLS token. + x = torch.cat((cls_token, x), dim=1) + + return x + + +class SerialBlock(nn.Module): + """ Serial block class. + Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None): + super().__init__() + + # Conv-Attention. + self.cpe = shared_cpe + + self.norm1 = norm_layer(dim) + self.factoratt_crpe = FactorAtt_ConvRelPosEnc( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + # MLP. + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, size: Tuple[int, int]): + # Conv-Attention. + x = self.cpe(x, size) + cur = self.norm1(x) + cur = self.factoratt_crpe(cur, size) + x = x + self.drop_path(cur) + + # MLP. + cur = self.norm2(x) + cur = self.mlp(cur) + x = x + self.drop_path(cur) + + return x + + +class ParallelBlock(nn.Module): + """ Parallel block class. """ + def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None): + super().__init__() + + # Conv-Attention. + self.norm12 = norm_layer(dims[1]) + self.norm13 = norm_layer(dims[2]) + self.norm14 = norm_layer(dims[3]) + self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( + dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + shared_crpe=shared_crpes[1] + ) + self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( + dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + shared_crpe=shared_crpes[2] + ) + self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( + dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + shared_crpe=shared_crpes[3] + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + # MLP. + self.norm22 = norm_layer(dims[1]) + self.norm23 = norm_layer(dims[2]) + self.norm24 = norm_layer(dims[3]) + # In parallel block, we assume dimensions are the same and share the linear transformation. + assert dims[1] == dims[2] == dims[3] + assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3] + mlp_hidden_dim = int(dims[1] * mlp_ratios[1]) + self.mlp2 = self.mlp3 = self.mlp4 = Mlp( + in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def upsample(self, x, factor: float, size: Tuple[int, int]): + """ Feature map up-sampling. """ + return self.interpolate(x, scale_factor=factor, size=size) + + def downsample(self, x, factor: float, size: Tuple[int, int]): + """ Feature map down-sampling. """ + return self.interpolate(x, scale_factor=1.0/factor, size=size) + + def interpolate(self, x, scale_factor: float, size: Tuple[int, int]): + """ Feature map interpolation. """ + B, N, C = x.shape + H, W = size + assert N == 1 + H * W + + cls_token = x[:, :1, :] + img_tokens = x[:, 1:, :] + + img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) + img_tokens = F.interpolate( + img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False) + img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) + + out = torch.cat((cls_token, img_tokens), dim=1) + + return out + + def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]): + _, S2, S3, S4 = sizes + cur2 = self.norm12(x2) + cur3 = self.norm13(x3) + cur4 = self.norm14(x4) + cur2 = self.factoratt_crpe2(cur2, size=S2) + cur3 = self.factoratt_crpe3(cur3, size=S3) + cur4 = self.factoratt_crpe4(cur4, size=S4) + upsample3_2 = self.upsample(cur3, factor=2., size=S3) + upsample4_3 = self.upsample(cur4, factor=2., size=S4) + upsample4_2 = self.upsample(cur4, factor=4., size=S4) + downsample2_3 = self.downsample(cur2, factor=2., size=S2) + downsample3_4 = self.downsample(cur3, factor=2., size=S3) + downsample2_4 = self.downsample(cur2, factor=4., size=S2) + cur2 = cur2 + upsample3_2 + upsample4_2 + cur3 = cur3 + upsample4_3 + downsample2_3 + cur4 = cur4 + downsample3_4 + downsample2_4 + x2 = x2 + self.drop_path(cur2) + x3 = x3 + self.drop_path(cur3) + x4 = x4 + self.drop_path(cur4) + + # MLP. + cur2 = self.norm22(x2) + cur3 = self.norm23(x3) + cur4 = self.norm24(x4) + cur2 = self.mlp2(cur2) + cur3 = self.mlp3(cur3) + cur4 = self.mlp4(cur4) + x2 = x2 + self.drop_path(cur2) + x3 = x3 + self.drop_path(cur3) + x4 = x4 + self.drop_path(cur4) + + return x1, x2, x3, x4 + + +class CoaT(nn.Module): + """ CoaT class. """ + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0), + serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + return_interm_layers=False, out_features=None, crpe_window=None, **kwargs): + super().__init__() + crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} + self.return_interm_layers = return_interm_layers + self.out_features = out_features + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + self.num_classes = num_classes + + # Patch embeddings. + img_size = to_2tuple(img_size) + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) + self.patch_embed2 = PatchEmbed( + img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) + self.patch_embed3 = PatchEmbed( + img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) + self.patch_embed4 = PatchEmbed( + img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) + + # Class tokens. + self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0])) + self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1])) + self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2])) + self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3])) + + # Convolutional position encodings. + self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3) + self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3) + self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3) + self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3) + + # Convolutional relative position encodings. + self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window) + self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window) + self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window) + self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window) + + # Disable stochastic depth. + dpr = drop_path_rate + assert dpr == 0.0 + + # Serial blocks 1. + self.serial_blocks1 = nn.ModuleList([ + SerialBlock( + dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpe=self.cpe1, shared_crpe=self.crpe1 + ) + for _ in range(serial_depths[0])] + ) + + # Serial blocks 2. + self.serial_blocks2 = nn.ModuleList([ + SerialBlock( + dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpe=self.cpe2, shared_crpe=self.crpe2 + ) + for _ in range(serial_depths[1])] + ) + + # Serial blocks 3. + self.serial_blocks3 = nn.ModuleList([ + SerialBlock( + dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpe=self.cpe3, shared_crpe=self.crpe3 + ) + for _ in range(serial_depths[2])] + ) + + # Serial blocks 4. + self.serial_blocks4 = nn.ModuleList([ + SerialBlock( + dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpe=self.cpe4, shared_crpe=self.crpe4 + ) + for _ in range(serial_depths[3])] + ) + + # Parallel blocks. + self.parallel_depth = parallel_depth + if self.parallel_depth > 0: + self.parallel_blocks = nn.ModuleList([ + ParallelBlock( + dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4) + ) + for _ in range(parallel_depth)] + ) + else: + self.parallel_blocks = None + + # Classification head(s). + if not self.return_interm_layers: + if self.parallel_blocks is not None: + self.norm2 = norm_layer(embed_dims[1]) + self.norm3 = norm_layer(embed_dims[2]) + else: + self.norm2 = self.norm3 = None + self.norm4 = norm_layer(embed_dims[3]) + + if self.parallel_depth > 0: + # CoaT series: Aggregate features of last three scales for classification. + assert embed_dims[1] == embed_dims[2] == embed_dims[3] + self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + else: + # CoaT-Lite series: Use feature of last scale for classification. + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # Initialize weights. + trunc_normal_(self.cls_token1, std=.02) + trunc_normal_(self.cls_token2, std=.02) + trunc_normal_(self.cls_token3, std=.02) + trunc_normal_(self.cls_token4, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def insert_cls(self, x, cls_token): + """ Insert CLS token. """ + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + return x + + def remove_cls(self, x): + """ Remove CLS token. """ + return x[:, 1:, :] + + def forward_features(self, x0): + B = x0.shape[0] + + # Serial blocks 1. + x1 = self.patch_embed1(x0) + H1, W1 = self.patch_embed1.grid_size + x1 = self.insert_cls(x1, self.cls_token1) + for blk in self.serial_blocks1: + x1 = blk(x1, size=(H1, W1)) + x1_nocls = self.remove_cls(x1) + x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() + + # Serial blocks 2. + x2 = self.patch_embed2(x1_nocls) + H2, W2 = self.patch_embed2.grid_size + x2 = self.insert_cls(x2, self.cls_token2) + for blk in self.serial_blocks2: + x2 = blk(x2, size=(H2, W2)) + x2_nocls = self.remove_cls(x2) + x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() + + # Serial blocks 3. + x3 = self.patch_embed3(x2_nocls) + H3, W3 = self.patch_embed3.grid_size + x3 = self.insert_cls(x3, self.cls_token3) + for blk in self.serial_blocks3: + x3 = blk(x3, size=(H3, W3)) + x3_nocls = self.remove_cls(x3) + x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() + + # Serial blocks 4. + x4 = self.patch_embed4(x3_nocls) + H4, W4 = self.patch_embed4.grid_size + x4 = self.insert_cls(x4, self.cls_token4) + for blk in self.serial_blocks4: + x4 = blk(x4, size=(H4, W4)) + x4_nocls = self.remove_cls(x4) + x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() + + # Only serial blocks: Early return. + if self.parallel_blocks is None: + if not torch.jit.is_scripting() and self.return_interm_layers: + # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). + feat_out = {} + if 'x1_nocls' in self.out_features: + feat_out['x1_nocls'] = x1_nocls + if 'x2_nocls' in self.out_features: + feat_out['x2_nocls'] = x2_nocls + if 'x3_nocls' in self.out_features: + feat_out['x3_nocls'] = x3_nocls + if 'x4_nocls' in self.out_features: + feat_out['x4_nocls'] = x4_nocls + return feat_out + else: + # Return features for classification. + x4 = self.norm4(x4) + x4_cls = x4[:, 0] + return x4_cls + + # Parallel blocks. + for blk in self.parallel_blocks: + x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4)) + x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) + + if not torch.jit.is_scripting() and self.return_interm_layers: + # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). + feat_out = {} + if 'x1_nocls' in self.out_features: + x1_nocls = self.remove_cls(x1) + x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() + feat_out['x1_nocls'] = x1_nocls + if 'x2_nocls' in self.out_features: + x2_nocls = self.remove_cls(x2) + x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() + feat_out['x2_nocls'] = x2_nocls + if 'x3_nocls' in self.out_features: + x3_nocls = self.remove_cls(x3) + x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() + feat_out['x3_nocls'] = x3_nocls + if 'x4_nocls' in self.out_features: + x4_nocls = self.remove_cls(x4) + x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() + feat_out['x4_nocls'] = x4_nocls + return feat_out + else: + x2 = self.norm2(x2) + x3 = self.norm3(x3) + x4 = self.norm4(x4) + x2_cls = x2[:, :1] # [B, 1, C] + x3_cls = x3[:, :1] + x4_cls = x4[:, :1] + merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # [B, 3, C] + merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C] + return merged_cls + + def forward(self, x): + if self.return_interm_layers: + # Return intermediate features (for down-stream tasks). + return self.forward_features(x) + else: + # Return features for classification. + x = self.forward_features(x) + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + out_dict = {} + for k, v in state_dict.items(): + # original model had unused norm layers, removing them requires filtering pretrained checkpoints + if k.startswith('norm1') or \ + (model.norm2 is None and k.startswith('norm2')) or \ + (model.norm3 is None and k.startswith('norm3')): + continue + out_dict[k] = v + return out_dict + + +def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + CoaT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def coat_tiny(pretrained=False, **kwargs): + model_cfg = dict( + patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, + num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) + model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg) + return model + + +@register_model +def coat_mini(pretrained=False, **kwargs): + model_cfg = dict( + patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, + num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) + model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg) + return model + + +@register_model +def coat_lite_tiny(pretrained=False, **kwargs): + model_cfg = dict( + patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg) + return model + + +@register_model +def coat_lite_mini(pretrained=False, **kwargs): + model_cfg = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg) + return model + + +@register_model +def coat_lite_small(pretrained=False, **kwargs): + model_cfg = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg) + return model \ No newline at end of file diff --git a/timm/models/convit.py b/timm/models/convit.py new file mode 100644 index 0000000000000000000000000000000000000000..f58249ec979dc32e5ddefa5aceb2e6143d6a4954 --- /dev/null +++ b/timm/models/convit.py @@ -0,0 +1,349 @@ +""" ConViT Model + +@article{d2021convit, + title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases}, + author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent}, + journal={arXiv preprint arXiv:2103.10697}, + year={2021} +} + +Paper link: https://arxiv.org/abs/2103.10697 +Original code: https://github.com/facebookresearch/convit, original copyright below +""" +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the CC-by-NC license found in the +# LICENSE file in the root directory of this source tree. +# +'''These modules are adapted from those of timm, see +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +''' + +import torch +import torch.nn as nn +from functools import partial +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp +from .registry import register_model +from .vision_transformer_hybrid import HybridEmbed + +import torch +import torch.nn as nn + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # ConViT + 'convit_tiny': _cfg( + url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"), + 'convit_small': _cfg( + url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"), + 'convit_base': _cfg( + url="https://dl.fbaipublicfiles.com/convit/convit_base.pth") +} + + +class GPSA(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., + locality_strength=1.): + super().__init__() + self.num_heads = num_heads + self.dim = dim + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.locality_strength = locality_strength + + self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.pos_proj = nn.Linear(3, num_heads) + self.proj_drop = nn.Dropout(proj_drop) + self.gating_param = nn.Parameter(torch.ones(self.num_heads)) + self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None + + def forward(self, x): + B, N, C = x.shape + if self.rel_indices is None or self.rel_indices.shape[1] != N: + self.rel_indices = self.get_rel_indices(N) + attn = self.get_attention(x) + v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def get_attention(self, x): + B, N, C = x.shape + qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k = qk[0], qk[1] + pos_score = self.rel_indices.expand(B, -1, -1, -1) + pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2) + patch_score = (q @ k.transpose(-2, -1)) * self.scale + patch_score = patch_score.softmax(dim=-1) + pos_score = pos_score.softmax(dim=-1) + + gating = self.gating_param.view(1, -1, 1, 1) + attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score + attn /= attn.sum(dim=-1).unsqueeze(-1) + attn = self.attn_drop(attn) + return attn + + def get_attention_map(self, x, return_map=False): + attn_map = self.get_attention(x).mean(0) # average over batch + distances = self.rel_indices.squeeze()[:, :, -1] ** .5 + dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0) + if return_map: + return dist, attn_map + else: + return dist + + def local_init(self): + self.v.weight.data.copy_(torch.eye(self.dim)) + locality_distance = 1 # max(1,1/locality_strength**.5) + + kernel_size = int(self.num_heads ** .5) + center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2 + for h1 in range(kernel_size): + for h2 in range(kernel_size): + position = h1 + kernel_size * h2 + self.pos_proj.weight.data[position, 2] = -1 + self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance + self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance + self.pos_proj.weight.data *= self.locality_strength + + def get_rel_indices(self, num_patches: int) -> torch.Tensor: + img_size = int(num_patches ** .5) + rel_indices = torch.zeros(1, num_patches, num_patches, 3) + ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1) + indx = ind.repeat(img_size, img_size) + indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) + indd = indx ** 2 + indy ** 2 + rel_indices[:, :, :, 2] = indd.unsqueeze(0) + rel_indices[:, :, :, 1] = indy.unsqueeze(0) + rel_indices[:, :, :, 0] = indx.unsqueeze(0) + device = self.qk.weight.device + return rel_indices.to(device) + + +class MHSA(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def get_attention_map(self, x, return_map=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + attn_map = (q @ k.transpose(-2, -1)) * self.scale + attn_map = attn_map.softmax(dim=-1).mean(0) + + img_size = int(N ** .5) + ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1) + indx = ind.repeat(img_size, img_size) + indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1) + indd = indx ** 2 + indy ** 2 + distances = indd ** .5 + distances = distances.to('cuda') + + dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N + if return_map: + return dist, attn_map + else: + return dist + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): + super().__init__() + self.norm1 = norm_layer(dim) + self.use_gpsa = use_gpsa + if self.use_gpsa: + self.attn = GPSA( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs) + else: + self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class ConViT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, + local_up_to_layer=3, locality_strength=1., use_pos_embed=True): + super().__init__() + embed_dim *= num_heads + self.num_classes = num_classes + self.local_up_to_layer = local_up_to_layer + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.locality_strength = locality_strength + self.use_pos_embed = use_pos_embed + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + self.num_patches = num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + if self.use_pos_embed: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.pos_embed, std=.02) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_gpsa=True, + locality_strength=locality_strength) + if i < local_up_to_layer else + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_gpsa=False) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')] + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + for n, m in self.named_modules(): + if hasattr(m, 'local_init'): + m.local_init() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) + + if self.use_pos_embed: + x = x + self.pos_embed + x = self.pos_drop(x) + + for u, blk in enumerate(self.blocks): + if u == self.local_up_to_layer: + x = torch.cat((cls_tokens, x), dim=1) + x = blk(x) + + x = self.norm(x) + return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_convit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + return build_model_with_cfg( + ConViT, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def convit_tiny(pretrained=False, **kwargs): + model_args = dict( + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args) + return model + + +@register_model +def convit_small(pretrained=False, **kwargs): + model_args = dict( + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args) + return model + + +@register_model +def convit_base(pretrained=False, **kwargs): + model_args = dict( + local_up_to_layer=10, locality_strength=1.0, embed_dim=48, + num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args) + return model diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py new file mode 100644 index 0000000000000000000000000000000000000000..a24007826594f0f23b30c778eabebce75b947915 --- /dev/null +++ b/timm/models/convmixer.py @@ -0,0 +1,101 @@ +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.registry import register_model +from .helpers import build_model_with_cfg + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .96, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + 'first_conv': 'stem.0', + **kwargs + } + + +default_cfgs = { + 'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'), + 'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'), + 'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar') +} + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x): + return self.fn(x) + x + + +class ConvMixer(nn.Module): + def __init__(self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, activation=nn.GELU, **kwargs): + super().__init__() + self.num_classes = num_classes + self.num_features = dim + self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size), + activation(), + nn.BatchNorm2d(dim) + ) + self.blocks = nn.Sequential( + *[nn.Sequential( + Residual(nn.Sequential( + nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), + activation(), + nn.BatchNorm2d(dim) + )), + nn.Conv2d(dim, dim, kernel_size=1), + activation(), + nn.BatchNorm2d(dim) + ) for i in range(depth)] + ) + self.pooling = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten() + ) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.pooling(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + + return x + + +def _create_convmixer(variant, pretrained=False, **kwargs): + return build_model_with_cfg(ConvMixer, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) + + +@register_model +def convmixer_1536_20(pretrained=False, **kwargs): + model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs) + return _create_convmixer('convmixer_1536_20', pretrained, **model_args) + + +@register_model +def convmixer_768_32(pretrained=False, **kwargs): + model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, activation=nn.ReLU, **kwargs) + return _create_convmixer('convmixer_768_32', pretrained, **model_args) + + +@register_model +def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs): + model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs) + return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args) \ No newline at end of file diff --git a/timm/models/convnext.py b/timm/models/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..3b8a049e2e39305b720fa0c81408736500ea8efc --- /dev/null +++ b/timm/models/convnext.py @@ -0,0 +1,427 @@ +""" ConvNeXt + +Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf + +Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below + +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman +""" +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# This source code is licensed under the MIT license +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_module +from .helpers import named_apply, build_model_with_cfg +from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp +from .registry import register_model + + +__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"), + convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"), + convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"), + convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), + + convnext_tiny_hnf=_cfg(url=''), + + convnext_base_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'), + convnext_large_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'), + convnext_xlarge_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'), + + convnext_base_384_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + convnext_large_384_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + convnext_xlarge_384_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + convnext_base_in22k=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), + convnext_large_in22k=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), + convnext_xlarge_in22k=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), +) + + +def _is_contiguous(tensor: torch.Tensor) -> bool: + # jit is oh so lovely :/ + # if torch.jit.is_tracing(): + # return True + if torch.jit.is_scripting(): + return tensor.is_contiguous() + else: + return tensor.is_contiguous(memory_format=torch.contiguous_format) + + +@register_notrace_module +class LayerNorm2d(nn.LayerNorm): + r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__(normalized_shape, eps=eps) + + def forward(self, x) -> torch.Tensor: + if _is_contiguous(x): + return F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + else: + s, u = torch.var_mean(x, dim=1, keepdim=True) + x = (x - u) * torch.rsqrt(s + self.eps) + x = x * self.weight[:, None, None] + self.bias[:, None, None] + return x + + +class ConvNeXtBlock(nn.Module): + """ ConvNeXt Block + There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + + Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate + choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear + is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None): + super().__init__() + if not norm_layer: + norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + mlp_layer = ConvMlp if conv_mlp else Mlp + self.use_conv_mlp = conv_mlp + self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = norm_layer(dim) + self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + else: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.mlp(x) + x = x.permute(0, 3, 1, 2) + if self.gamma is not None: + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + x = self.drop_path(x) + shortcut + return x + + +class ConvNeXtStage(nn.Module): + + def __init__( + self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, + norm_layer=None, cl_norm_layer=None, cross_stage=False): + super().__init__() + + if in_chs != out_chs or stride > 1: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), + ) + else: + self.downsample = nn.Identity() + + dp_rates = dp_rates or [0.] * depth + self.blocks = nn.Sequential(*[ConvNeXtBlock( + dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, + norm_layer=norm_layer if conv_mlp else cl_norm_layer) + for j in range(depth)] + ) + + def forward(self, x): + x = self.downsample(x) + x = self.blocks(x) + return x + + +class ConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_rate (float): Head dropout rate + drop_path_rate (float): Stochastic depth rate. Default: 0. + ls_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + + def __init__( + self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, + depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, + head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., + ): + super().__init__() + assert output_stride == 32 + if norm_layer is None: + norm_layer = partial(LayerNorm2d, eps=1e-6) + cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + else: + assert conv_mlp,\ + 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' + cl_norm_layer = norm_layer + + self.num_classes = num_classes + self.drop_rate = drop_rate + self.feature_info = [] + + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), + norm_layer(dims[0]) + ) + + self.stages = nn.Sequential() + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + curr_stride = patch_size + prev_chs = dims[0] + stages = [] + # 4 feature resolution stages, each consisting of multiple residual blocks + for i in range(4): + stride = 2 if i > 0 else 1 + # FIXME support dilation / output_stride + curr_stride *= stride + out_chs = dims[i] + stages.append(ConvNeXtStage( + prev_chs, out_chs, stride=stride, + depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, + norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) + ) + prev_chs = out_chs + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) + + self.num_features = prev_chs + if head_norm_first: + # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights) + self.norm_pre = norm_layer(self.num_features) # final norm layer, before pooling + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + else: + # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) + self.norm_pre = nn.Identity() + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', norm_layer(self.num_features)), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ])) + + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool='avg'): + if isinstance(self.head, ClassifierHead): + # norm -> global pool -> fc + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + else: + # pool -> norm -> fc + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', self.head.norm), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) + ])) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=.02) + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + nn.init.constant_(module.bias, 0) + if name and 'head.' in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def checkpoint_filter_fn(state_dict, model): + """ Remap FB checkpoints -> timm """ + if 'model' in state_dict: + state_dict = state_dict['model'] + out_dict = {} + import re + for k, v in state_dict.items(): + k = k.replace('downsample_layers.0.', 'stem.') + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) + k = k.replace('dwconv', 'conv_dw') + k = k.replace('pwconv', 'mlp.fc') + k = k.replace('head.', 'head.fc.') + if k.startswith('norm.'): + k = k.replace('norm', 'head.norm') + if v.ndim == 2 and 'head' not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + return out_dict + + +def _create_convnext(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + ConvNeXt, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model + + +@register_model +def convnext_tiny(pretrained=False, **kwargs): + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) + model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_tiny_hnf(pretrained=False, **kwargs): + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, **kwargs) + model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_small(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) + model = _create_convnext('convnext_small', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_base(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_base', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_large(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + model = _create_convnext('convnext_large', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_base_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_large_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_xlarge_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_base_384_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_large_384_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_base_in22k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_large_in22k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_xlarge_in22k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args) + return model + + + diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py new file mode 100644 index 0000000000000000000000000000000000000000..acf9a47de7c3b209ab23005ea638a121664f70b6 --- /dev/null +++ b/timm/models/crossvit.py @@ -0,0 +1,519 @@ +""" CrossViT Model + +@inproceedings{ + chen2021crossvit, + title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}}, + author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda}, + booktitle={International Conference on Computer Vision (ICCV)}, + year={2021} +} + +Paper link: https://arxiv.org/abs/2103.14899 +Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py + +NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408 + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" + +# Copyright IBM All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + + +""" +Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +""" +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.hub +from functools import partial +from typing import List + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg +from .layers import DropPath, to_2tuple, trunc_normal_, _assert +from .registry import register_model +from .vision_transformer import Mlp, Block + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, + 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'), + 'classifier': ('head.0', 'head.1'), + **kwargs + } + + +default_cfgs = { + 'crossvit_15_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'), + 'crossvit_15_dagger_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_15_dagger_408': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth', + input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0, + ), + 'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'), + 'crossvit_18_dagger_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_18_dagger_408': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth', + input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0, + ), + 'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'), + 'crossvit_9_dagger_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth', + first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), + ), + 'crossvit_base_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'), + 'crossvit_small_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'), + 'crossvit_tiny_240': _cfg( + url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'), +} + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + if multi_conv: + if patch_size[0] == 12: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1), + ) + elif patch_size[0] == 16: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), + ) + else: + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + _assert(H == self.img_size[0], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + _assert(W == self.img_size[1], + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).") + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class CrossAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.wq = nn.Linear(dim, dim, bias=qkv_bias) + self.wk = nn.Linear(dim, dim, bias=qkv_bias) + self.wv = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # B1C -> B1H(C/H) -> BH1(C/H) + q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + # BNC -> BNH(C/H) -> BHN(C/H) + k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + # BNC -> BNH(C/H) -> BHN(C/H) + v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CrossAttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossAttention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x))) + + return x + + +class MultiScaleBlock(nn.Module): + + def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + + num_branches = len(dim) + self.num_branches = num_branches + # different branch could have different embedding size, the first one is the base + self.blocks = nn.ModuleList() + for d in range(num_branches): + tmp = [] + for i in range(depth[d]): + tmp.append(Block( + dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer)) + if len(tmp) != 0: + self.blocks.append(nn.Sequential(*tmp)) + + if len(self.blocks) == 0: + self.blocks = None + + self.projs = nn.ModuleList() + for d in range(num_branches): + if dim[d] == dim[(d + 1) % num_branches] and False: + tmp = [nn.Identity()] + else: + tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])] + self.projs.append(nn.Sequential(*tmp)) + + self.fusion = nn.ModuleList() + for d in range(num_branches): + d_ = (d + 1) % num_branches + nh = num_heads[d_] + if depth[-1] == 0: # backward capability: + self.fusion.append( + CrossAttentionBlock( + dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) + else: + tmp = [] + for _ in range(depth[-1]): + tmp.append(CrossAttentionBlock( + dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) + self.fusion.append(nn.Sequential(*tmp)) + + self.revert_projs = nn.ModuleList() + for d in range(num_branches): + if dim[(d + 1) % num_branches] == dim[d] and False: + tmp = [nn.Identity()] + else: + tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(), + nn.Linear(dim[(d + 1) % num_branches], dim[d])] + self.revert_projs.append(nn.Sequential(*tmp)) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + + outs_b = [] + for i, block in enumerate(self.blocks): + outs_b.append(block(x[i])) + + # only take the cls token out + proj_cls_token = torch.jit.annotate(List[torch.Tensor], []) + for i, proj in enumerate(self.projs): + proj_cls_token.append(proj(outs_b[i][:, 0:1, ...])) + + # cross attention + outs = [] + for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)): + tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1) + tmp = fusion(tmp) + reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...]) + tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1) + outs.append(tmp) + return outs + + +def _compute_num_patches(img_size, patches): + return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)] + + +@register_notrace_function +def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript + """ + Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing. + Args: + x (Tensor): input image + ss (tuple[int, int]): height and width to scale to + crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False + Returns: + Tensor: the "scaled" image batch tensor + """ + H, W = x.shape[-2:] + if H != ss[0] or W != ss[1]: + if crop_scale and ss[0] <= H and ss[1] <= W: + cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.)) + x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]] + else: + x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) + return x + + +class CrossViT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__( + self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000, + embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.), + qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False, + ): + super().__init__() + + self.num_classes = num_classes + self.img_size = to_2tuple(img_size) + img_scale = to_2tuple(img_scale) + self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale] + self.crop_scale = crop_scale # crop instead of interpolate for scale + num_patches = _compute_num_patches(self.img_size_scaled, patch_size) + self.num_branches = len(patch_size) + self.embed_dim = embed_dim + self.num_features = embed_dim[0] # to pass the tests + self.patch_embed = nn.ModuleList() + + # hard-coded for torch jit script + for i in range(self.num_branches): + setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i]))) + setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i]))) + + for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim): + self.patch_embed.append( + PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv)) + + self.pos_drop = nn.Dropout(p=drop_rate) + + total_depth = sum([sum(x[-2:]) for x in depth]) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule + dpr_ptr = 0 + self.blocks = nn.ModuleList() + for idx, block_cfg in enumerate(depth): + curr_depth = max(block_cfg[:-1]) + block_cfg[-1] + dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth] + blk = MultiScaleBlock( + embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer) + dpr_ptr += curr_depth + self.blocks.append(blk) + + self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)]) + self.head = nn.ModuleList([ + nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() + for i in range(self.num_branches)]) + + for i in range(self.num_branches): + trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02) + trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + out = set() + for i in range(self.num_branches): + out.add(f'cls_token_{i}') + pe = getattr(self, f'pos_embed_{i}', None) + if pe is not None and pe.requires_grad: + out.add(f'pos_embed_{i}') + return out + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.ModuleList( + [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in + range(self.num_branches)]) + + def forward_features(self, x): + B = x.shape[0] + xs = [] + for i, patch_embed in enumerate(self.patch_embed): + x_ = x + ss = self.img_size_scaled[i] + x_ = scale_image(x_, ss, self.crop_scale) + x_ = patch_embed(x_) + cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script + cls_tokens = cls_tokens.expand(B, -1, -1) + x_ = torch.cat((cls_tokens, x_), dim=1) + pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script + x_ = x_ + pos_embed + x_ = self.pos_drop(x_) + xs.append(x_) + + for i, blk in enumerate(self.blocks): + xs = blk(xs) + + # NOTE: was before branch token section, move to here to assure all branch token are before layer norm + xs = [norm(xs[i]) for i, norm in enumerate(self.norm)] + return [xo[:, 0] for xo in xs] + + def forward(self, x): + xs = self.forward_features(x) + ce_logits = [head(xs[i]) for i, head in enumerate(self.head)] + if not isinstance(self.head[0], nn.Identity): + ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0) + return ce_logits + + +def _create_crossvit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + def pretrained_filter_fn(state_dict): + new_state_dict = {} + for key in state_dict.keys(): + if 'pos_embed' in key or 'cls_token' in key: + new_key = key.replace(".", "_") + else: + new_key = key + new_state_dict[new_key] = state_dict[key] + return new_state_dict + + return build_model_with_cfg( + CrossViT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=pretrained_filter_fn, + **kwargs) + + +@register_model +def crossvit_tiny_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], + num_heads=[3, 3], mlp_ratio=[4, 4, 1], **kwargs) + model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_small_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], + num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs) + model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_base_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], + num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs) + model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_9_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], + num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs) + model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_15_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], + num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs) + model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_18_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], + num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs) + model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_9_dagger_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], + num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_15_dagger_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], + num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_15_dagger_408(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], + num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_18_dagger_240(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], + num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **model_args) + return model + + +@register_model +def crossvit_18_dagger_408(pretrained=False, **kwargs): + model_args = dict( + img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], + num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs) + model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **model_args) + return model diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py new file mode 100644 index 0000000000000000000000000000000000000000..39d16200f869f352cf9289f962c8fd77ebd3f9f9 --- /dev/null +++ b/timm/models/cspnet.py @@ -0,0 +1,457 @@ +"""PyTorch CspNet + +A PyTorch implementation of Cross Stage Partial Networks including: +* CSPResNet50 +* CSPResNeXt50 +* CSPDarkNet53 +* and DarkNet53 for good measure + +Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929 + +Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, DropPath, create_attn, get_norm_act_layer +from .registry import register_model + + +__all__ = ['CspNet'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.887, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + 'cspresnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'), + 'cspresnet50d': _cfg(url=''), + 'cspresnet50w': _cfg(url=''), + 'cspresnext50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth', + input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl + ), + 'cspresnext50_iabn': _cfg(url=''), + 'cspdarknet53': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'), + 'cspdarknet53_iabn': _cfg(url=''), + 'darknet53': _cfg(url=''), +} + + +model_cfgs = dict( + cspresnet50=dict( + stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), + stage=dict( + out_chs=(128, 256, 512, 1024), + depth=(3, 3, 5, 2), + stride=(1,) + (2,) * 3, + exp_ratio=(2.,) * 4, + bottle_ratio=(0.5,) * 4, + block_ratio=(1.,) * 4, + cross_linear=True, + ) + ), + cspresnet50d=dict( + stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), + stage=dict( + out_chs=(128, 256, 512, 1024), + depth=(3, 3, 5, 2), + stride=(1,) + (2,) * 3, + exp_ratio=(2.,) * 4, + bottle_ratio=(0.5,) * 4, + block_ratio=(1.,) * 4, + cross_linear=True, + ) + ), + cspresnet50w=dict( + stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'), + stage=dict( + out_chs=(256, 512, 1024, 2048), + depth=(3, 3, 5, 2), + stride=(1,) + (2,) * 3, + exp_ratio=(1.,) * 4, + bottle_ratio=(0.25,) * 4, + block_ratio=(0.5,) * 4, + cross_linear=True, + ) + ), + cspresnext50=dict( + stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'), + stage=dict( + out_chs=(256, 512, 1024, 2048), + depth=(3, 3, 5, 2), + stride=(1,) + (2,) * 3, + groups=(32,) * 4, + exp_ratio=(1.,) * 4, + bottle_ratio=(1.,) * 4, + block_ratio=(0.5,) * 4, + cross_linear=True, + ) + ), + cspdarknet53=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 2, 8, 8, 4), + stride=(2,) * 5, + exp_ratio=(2.,) + (1.,) * 4, + bottle_ratio=(0.5,) + (1.0,) * 4, + block_ratio=(1.,) + (0.5,) * 4, + down_growth=True, + ) + ), + darknet53=dict( + stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''), + stage=dict( + out_chs=(64, 128, 256, 512, 1024), + depth=(1, 2, 8, 8, 4), + stride=(2,) * 5, + bottle_ratio=(0.5,) * 5, + block_ratio=(1.,) * 5, + ) + ) +) + + +def create_stem( + in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='', + act_layer=None, norm_layer=None, aa_layer=None): + stem = nn.Sequential() + if not isinstance(out_chs, (tuple, list)): + out_chs = [out_chs] + assert len(out_chs) + in_c = in_chans + for i, out_c in enumerate(out_chs): + conv_name = f'conv{i + 1}' + stem.add_module(conv_name, ConvBnAct( + in_c, out_c, kernel_size, stride=stride if i == 0 else 1, + act_layer=act_layer, norm_layer=norm_layer)) + in_c = out_c + last_conv = conv_name + if pool: + if aa_layer is not None: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)) + stem.add_module('aa', aa_layer(channels=in_c, stride=2)) + else: + stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv])) + + +class ResBottleneck(nn.Module): + """ ResNe(X)t Bottleneck Block + """ + + def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(ResBottleneck, self).__init__() + mid_chs = int(round(out_chs * bottle_ratio)) + ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) + + self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs) + self.conv2 = ConvBnAct(mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs) + self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None + self.conv3 = ConvBnAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) + self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None + self.drop_path = drop_path + self.act3 = act_layer(inplace=True) + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + if self.attn2 is not None: + x = self.attn2(x) + x = self.conv3(x) + if self.attn3 is not None: + x = self.attn3(x) + if self.drop_path is not None: + x = self.drop_path(x) + x = x + shortcut + # FIXME partial shortcut needed if first block handled as per original, not used for my current impl + #x[:, :shortcut.size(1)] += shortcut + x = self.act3(x) + return x + + +class DarkBlock(nn.Module): + """ DarkNet Block + """ + + def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, + drop_block=None, drop_path=None): + super(DarkBlock, self).__init__() + mid_chs = int(round(out_chs * bottle_ratio)) + ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) + self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs) + self.conv2 = ConvBnAct(mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs) + self.attn = create_attn(attn_layer, channels=out_chs) + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + if self.attn is not None: + x = self.attn(x) + if self.drop_path is not None: + x = self.drop_path(x) + x = x + shortcut + return x + + +class CrossStage(nn.Module): + """Cross Stage.""" + def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1., + groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None, + block_fn=ResBottleneck, **block_kwargs): + super(CrossStage, self).__init__() + first_dilation = first_dilation or dilation + down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels + exp_chs = int(round(out_chs * exp_ratio)) + block_out_chs = int(round(out_chs * block_ratio)) + conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) + + if stride != 1 or first_dilation != dilation: + self.conv_down = ConvBnAct( + in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) + prev_chs = down_chs + else: + self.conv_down = None + prev_chs = in_chs + + # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also, + # there is also special case for the first stage for some of the model that results in uneven split + # across the two paths. I did it this way for simplicity for now. + self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) + prev_chs = exp_chs // 2 # output of conv_exp is always split in two + + self.blocks = nn.Sequential() + for i in range(depth): + drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None + self.blocks.add_module(str(i), block_fn( + prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + prev_chs = block_out_chs + + # transition convs + self.conv_transition_b = ConvBnAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs) + self.conv_transition = ConvBnAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) + + def forward(self, x): + if self.conv_down is not None: + x = self.conv_down(x) + x = self.conv_exp(x) + split = x.shape[1] // 2 + xs, xb = x[:, :split], x[:, split:] + xb = self.blocks(xb) + xb = self.conv_transition_b(xb).contiguous() + out = self.conv_transition(torch.cat([xs, xb], dim=1)) + return out + + +class DarkStage(nn.Module): + """DarkNet stage.""" + + def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1, + first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs): + super(DarkStage, self).__init__() + first_dilation = first_dilation or dilation + + self.conv_down = ConvBnAct( + in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, + act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'), + aa_layer=block_kwargs.get('aa_layer', None)) + + prev_chs = out_chs + block_out_chs = int(round(out_chs * block_ratio)) + self.blocks = nn.Sequential() + for i in range(depth): + drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None + self.blocks.add_module(str(i), block_fn( + prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs)) + prev_chs = block_out_chs + + def forward(self, x): + x = self.conv_down(x) + x = self.blocks(x) + return x + + +def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.): + # get per stage args for stage and containing blocks, calculate strides to meet target output_stride + num_stages = len(cfg['depth']) + if 'groups' not in cfg: + cfg['groups'] = (1,) * num_stages + if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)): + cfg['down_growth'] = (cfg['down_growth'],) * num_stages + if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)): + cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages + cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \ + [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])] + stage_strides = [] + stage_dilations = [] + stage_first_dilations = [] + dilation = 1 + for cfg_stride in cfg['stride']: + stage_first_dilations.append(dilation) + if curr_stride >= output_stride: + dilation *= cfg_stride + stride = 1 + else: + stride = cfg_stride + curr_stride *= stride + stage_strides.append(stride) + stage_dilations.append(dilation) + cfg['stride'] = stage_strides + cfg['dilation'] = stage_dilations + cfg['first_dilation'] = stage_first_dilations + stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())] + return stage_args + + +class CspNet(nn.Module): + """Cross Stage Partial base model. + + Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929 + Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks + + NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the + darknet impl. I did it this way for simplicity and less special cases. + """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., + act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0., + zero_init_last_bn=True, stage_fn=CrossStage, block_fn=ResBottleneck): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + + # Construct the stem + self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args) + self.feature_info = [stem_feat_info] + prev_chs = stem_feat_info['num_chs'] + curr_stride = stem_feat_info['reduction'] # reduction does not include pool + if cfg['stem']['pool']: + curr_stride *= 2 + + # Construct the stages + per_stage_args = _cfg_to_stage_args( + cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate) + self.stages = nn.Sequential() + for i, sa in enumerate(per_stage_args): + self.stages.add_module( + str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn)) + prev_chs = sa['out_chs'] + curr_stride *= sa['stride'] + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] + + # Construct the head + self.num_features = prev_chs + self.head = ClassifierHead( + in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + nn.init.zeros_(m.bias) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_cspnet(variant, pretrained=False, **kwargs): + cfg_variant = variant.split('_')[0] + return build_model_with_cfg( + CspNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], + **kwargs) + + +@register_model +def cspresnet50(pretrained=False, **kwargs): + return _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs) + + +@register_model +def cspresnet50d(pretrained=False, **kwargs): + return _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs) + + +@register_model +def cspresnet50w(pretrained=False, **kwargs): + return _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs) + + +@register_model +def cspresnext50(pretrained=False, **kwargs): + return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs) + + +@register_model +def cspresnext50_iabn(pretrained=False, **kwargs): + norm_layer = get_norm_act_layer('iabn') + return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs) + + +@register_model +def cspdarknet53(pretrained=False, **kwargs): + return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs) + + +@register_model +def cspdarknet53_iabn(pretrained=False, **kwargs): + norm_layer = get_norm_act_layer('iabn') + return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs) + + +@register_model +def darknet53(pretrained=False, **kwargs): + return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs) diff --git a/timm/models/densenet.py b/timm/models/densenet.py new file mode 100644 index 0000000000000000000000000000000000000000..38a19727874bfb08cf20281ec8e85b8cc7c60688 --- /dev/null +++ b/timm/models/densenet.py @@ -0,0 +1,387 @@ +"""Pytorch Densenet implementation w/ tweaks +This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with +fixed kwargs passthrough and addition of dynamic global avg/max pool. +""" +import re +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch.jit.annotations import List + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier +from .registry import register_model + +__all__ = ['DenseNet'] + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'features.conv0', 'classifier': 'classifier', + } + + +default_cfgs = { + 'densenet121': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'), + 'densenet121d': _cfg(url=''), + 'densenetblur121d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'), + 'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'), + 'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'), + 'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'), + 'densenet264': _cfg(url=''), + 'densenet264d_iabn': _cfg(url=''), + 'tv_densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'), +} + + +class DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d, + drop_rate=0., memory_efficient=False): + super(DenseLayer, self).__init__() + self.add_module('norm1', norm_layer(num_input_features)), + self.add_module('conv1', nn.Conv2d( + num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', norm_layer(bn_size * growth_rate)), + self.add_module('conv2', nn.Conv2d( + bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bottleneck_fn(self, xs): + # type: (List[torch.Tensor]) -> torch.Tensor + concated_features = torch.cat(xs, 1) + bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484 + return bottleneck_output + + # todo: rewrite when torchscript supports any + def any_requires_grad(self, x): + # type: (List[torch.Tensor]) -> bool + for tensor in x: + if tensor.requires_grad: + return True + return False + + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, x): + # type: (List[torch.Tensor]) -> torch.Tensor + def closure(*xs): + return self.bottleneck_fn(xs) + + return cp.checkpoint(closure, *x) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + # torchscript does not yet support *args, so we overload method + # allowing it to take either a List[Tensor] or single Tensor + def forward(self, x): # noqa: F811 + if isinstance(x, torch.Tensor): + prev_features = [x] + else: + prev_features = x + + if self.memory_efficient and self.any_requires_grad(prev_features): + if torch.jit.is_scripting(): + raise Exception("Memory Efficient not supported in JIT") + bottleneck_output = self.call_checkpoint_bottleneck(prev_features) + else: + bottleneck_output = self.bottleneck_fn(prev_features) + + new_features = self.conv2(self.norm2(bottleneck_output)) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) + return new_features + + +class DenseBlock(nn.ModuleDict): + _version = 2 + + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU, + drop_rate=0., memory_efficient=False): + super(DenseBlock, self).__init__() + for i in range(num_layers): + layer = DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + norm_layer=norm_layer, + drop_rate=drop_rate, + memory_efficient=memory_efficient, + ) + self.add_module('denselayer%d' % (i + 1), layer) + + def forward(self, init_features): + features = [init_features] + for name, layer in self.items(): + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + + +class DenseTransition(nn.Sequential): + def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None): + super(DenseTransition, self).__init__() + self.add_module('norm', norm_layer(num_input_features)) + self.add_module('conv', nn.Conv2d( + num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) + if aa_layer is not None: + self.add_module('pool', aa_layer(num_output_features, stride=2)) + else: + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +class DenseNet(nn.Module): + r"""Densenet-BC model class, based on + `"Densely Connected Convolutional Networks" `_ + + Args: + growth_rate (int) - how many filters to add each layer (`k` in paper) + block_config (list of 4 ints) - how many layers in each pooling block + bn_size (int) - multiplicative factor for number of bottle neck layers + (i.e. bn_size * k features in the bottleneck layer) + drop_rate (float) - dropout rate after each dense layer + num_classes (int) - number of classification classes + memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, + but slower. Default: *False*. See `"paper" `_ + """ + + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='', + num_classes=1000, in_chans=3, global_pool='avg', + norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False, + aa_stem_only=True): + self.num_classes = num_classes + self.drop_rate = drop_rate + super(DenseNet, self).__init__() + + # Stem + deep_stem = 'deep' in stem_type # 3x3 deep stem + num_init_features = growth_rate * 2 + if aa_layer is None: + stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + else: + stem_pool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=num_init_features, stride=2)]) + if deep_stem: + stem_chs_1 = stem_chs_2 = growth_rate + if 'tiered' in stem_type: + stem_chs_1 = 3 * (growth_rate // 4) + stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4) + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)), + ('norm0', norm_layer(stem_chs_1)), + ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)), + ('norm1', norm_layer(stem_chs_2)), + ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)), + ('norm2', norm_layer(num_init_features)), + ('pool0', stem_pool), + ])) + else: + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ('norm0', norm_layer(num_init_features)), + ('pool0', stem_pool), + ])) + self.feature_info = [ + dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')] + current_stride = 4 + + # DenseBlocks + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + norm_layer=norm_layer, + drop_rate=drop_rate, + memory_efficient=memory_efficient + ) + module_name = f'denseblock{(i + 1)}' + self.features.add_module(module_name, block) + num_features = num_features + num_layers * growth_rate + transition_aa_layer = None if aa_stem_only else aa_layer + if i != len(block_config) - 1: + self.feature_info += [ + dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)] + current_stride *= 2 + trans = DenseTransition( + num_input_features=num_features, num_output_features=num_features // 2, + norm_layer=norm_layer, aa_layer=transition_aa_layer) + self.features.add_module(f'transition{i + 1}', trans) + num_features = num_features // 2 + + # Final batch norm + self.features.add_module('norm5', norm_layer(num_features)) + + self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')] + self.num_features = num_features + + # Linear layer + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + # Official init from torch repo. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + # both classifier and block drop? + # if self.drop_rate > 0.: + # x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x + + +def _filter_torchvision_pretrained(state_dict): + pattern = re.compile( + r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + + for key in list(state_dict.keys()): + res = pattern.match(key) + if res: + new_key = res.group(1) + res.group(2) + state_dict[new_key] = state_dict[key] + del state_dict[key] + return state_dict + + +def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): + kwargs['growth_rate'] = growth_rate + kwargs['block_config'] = block_config + return build_model_with_cfg( + DenseNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, + **kwargs) + + +@register_model +def densenet121(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenetblur121d(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep', + aa_layer=BlurPool2d, **kwargs) + return model + + +@register_model +def densenet121d(pretrained=False, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep', + pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet169(pretrained=False, **kwargs): + r"""Densenet-169 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet201(pretrained=False, **kwargs): + r"""Densenet-201 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet161(pretrained=False, **kwargs): + r"""Densenet-161 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet264(pretrained=False, **kwargs): + r"""Densenet-264 model from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs) + return model + + +@register_model +def densenet264d_iabn(pretrained=False, **kwargs): + r"""Densenet-264 model with deep stem and Inplace-ABN + """ + def norm_act_fn(num_features, **kwargs): + return create_norm_act('iabn', num_features, **kwargs) + model = _create_densenet( + 'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', + norm_layer=norm_act_fn, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tv_densenet121(pretrained=False, **kwargs): + r"""Densenet-121 model with original Torchvision weights, from + `"Densely Connected Convolutional Networks" ` + """ + model = _create_densenet( + 'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/dla.py b/timm/models/dla.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e4dd285db53cd547ecb1f913219890517e3c00 --- /dev/null +++ b/timm/models/dla.py @@ -0,0 +1,443 @@ +""" Deep Layer Aggregation and DLA w/ Res2Net +DLA original adapted from Official Pytorch impl at: +DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484 + +Res2Net additions from: https://github.com/gasvn/Res2Net/ +Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 +""" +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['DLA'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'base_layer.0', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'dla34': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla34-ba72cf86.pth'), + 'dla46_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46_c-2bfd52c3.pth'), + 'dla46x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46x_c-d761bae7.pth'), + 'dla60x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x_c-b870c45c.pth'), + 'dla60': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60-24839fc4.pth'), + 'dla60x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x-d15cacda.pth'), + 'dla102': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102-d94d9790.pth'), + 'dla102x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x-ad62be81.pth'), + 'dla102x2': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x2-262837b6.pth'), + 'dla169': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla169-0914e092.pth'), + 'dla60_res2net': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth'), + 'dla60_res2next': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next_dla60_4s-d327927b.pth'), +} + + +class DlaBasic(nn.Module): + """DLA Basic""" + + def __init__(self, inplanes, planes, stride=1, dilation=1, **_): + super(DlaBasic, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.stride = stride + + def forward(self, x, shortcut=None): + if shortcut is None: + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + out += shortcut + out = self.relu(out) + + return out + + +class DlaBottleneck(nn.Module): + """DLA/DLA-X Bottleneck""" + expansion = 2 + + def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64): + super(DlaBottleneck, self).__init__() + self.stride = stride + mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality) + mid_planes = mid_planes // self.expansion + + self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(mid_planes) + self.conv2 = nn.Conv2d( + mid_planes, mid_planes, kernel_size=3, stride=stride, padding=dilation, + bias=False, dilation=dilation, groups=cardinality) + self.bn2 = nn.BatchNorm2d(mid_planes) + self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(outplanes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, shortcut=None): + if shortcut is None: + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out += shortcut + out = self.relu(out) + + return out + + +class DlaBottle2neck(nn.Module): + """ Res2Net/Res2NeXT DLA Bottleneck + Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py + """ + expansion = 2 + + def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4): + super(DlaBottle2neck, self).__init__() + self.is_first = stride > 1 + self.scale = scale + mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality) + mid_planes = mid_planes // self.expansion + self.width = mid_planes + + self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(mid_planes * scale) + + num_scale_convs = max(1, scale - 1) + convs = [] + bns = [] + for _ in range(num_scale_convs): + convs.append(nn.Conv2d( + mid_planes, mid_planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, groups=cardinality, bias=False)) + bns.append(nn.BatchNorm2d(mid_planes)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + if self.is_first: + self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) + + self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(outplanes) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x, shortcut=None): + if shortcut is None: + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + spx = torch.split(out, self.width, 1) + spo = [] + for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): + sp = spx[i] if i == 0 or self.is_first else sp + spx[i] + sp = conv(sp) + sp = bn(sp) + sp = self.relu(sp) + spo.append(sp) + if self.scale > 1: + spo.append(self.pool(spx[-1]) if self.is_first else spx[-1]) + out = torch.cat(spo, 1) + + out = self.conv3(out) + out = self.bn3(out) + + out += shortcut + out = self.relu(out) + + return out + + +class DlaRoot(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, shortcut): + super(DlaRoot, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2) + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.shortcut = shortcut + + def forward(self, *x): + children = x + x = self.conv(torch.cat(x, 1)) + x = self.bn(x) + if self.shortcut: + x += children[0] + x = self.relu(x) + + return x + + +class DlaTree(nn.Module): + def __init__(self, levels, block, in_channels, out_channels, stride=1, + dilation=1, cardinality=1, base_width=64, + level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False): + super(DlaTree, self).__init__() + if root_dim == 0: + root_dim = 2 * out_channels + if level_root: + root_dim += in_channels + self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity() + self.project = nn.Identity() + cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width) + if levels == 1: + self.tree1 = block(in_channels, out_channels, stride, **cargs) + self.tree2 = block(out_channels, out_channels, 1, **cargs) + if in_channels != out_channels: + # NOTE the official impl/weights have project layers in levels > 1 case that are never + # used, I've moved the project layer here to avoid wasted params but old checkpoints will + # need strict=False while loading. + self.project = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), + nn.BatchNorm2d(out_channels)) + else: + cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut)) + self.tree1 = DlaTree( + levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs) + self.tree2 = DlaTree( + levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs) + if levels == 1: + self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut) + self.level_root = level_root + self.root_dim = root_dim + self.levels = levels + + def forward(self, x, shortcut=None, children=None): + children = [] if children is None else children + bottom = self.downsample(x) + shortcut = self.project(bottom) + if self.level_root: + children.append(bottom) + x1 = self.tree1(x, shortcut) + if self.levels == 1: + x2 = self.tree2(x1) + x = self.root(x2, x1, *children) + else: + children.append(x1) + x = self.tree2(x1, children=children) + return x + + +class DLA(nn.Module): + def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3, + cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False, + drop_rate=0.0, global_pool='avg'): + super(DLA, self).__init__() + self.channels = channels + self.num_classes = num_classes + self.cardinality = cardinality + self.base_width = base_width + self.drop_rate = drop_rate + assert output_stride == 32 # FIXME support dilation + + self.base_layer = nn.Sequential( + nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False), + nn.BatchNorm2d(channels[0]), + nn.ReLU(inplace=True)) + self.level0 = self._make_conv_level(channels[0], channels[0], levels[0]) + self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2) + cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root) + self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs) + self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs) + self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs) + self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs) + self.feature_info = [ + dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level + dict(num_chs=channels[1], reduction=2, module='level1'), + dict(num_chs=channels[2], reduction=4, module='level2'), + dict(num_chs=channels[3], reduction=8, module='level3'), + dict(num_chs=channels[4], reduction=16, module='level4'), + dict(num_chs=channels[5], reduction=32, module='level5'), + ] + + self.num_features = channels[-1] + self.global_pool, self.fc = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + modules = [] + for i in range(convs): + modules.extend([ + nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1, + padding=dilation, bias=False, dilation=dilation), + nn.BatchNorm2d(planes), + nn.ReLU(inplace=True)]) + inplanes = planes + return nn.Sequential(*modules) + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + + def forward_features(self, x): + x = self.base_layer(x) + x = self.level0(x) + x = self.level1(x) + x = self.level2(x) + x = self.level3(x) + x = self.level4(x) + x = self.level5(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + x = self.flatten(x) + return x + + +def _create_dla(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + DLA, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=False, + feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), + **kwargs) + + +@register_model +def dla60_res2net(pretrained=False, **kwargs): + model_kwargs = dict( + levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), + block=DlaBottle2neck, cardinality=1, base_width=28, **kwargs) + return _create_dla('dla60_res2net', pretrained, **model_kwargs) + + +@register_model +def dla60_res2next(pretrained=False,**kwargs): + model_kwargs = dict( + levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024), + block=DlaBottle2neck, cardinality=8, base_width=4, **kwargs) + return _create_dla('dla60_res2next', pretrained, **model_kwargs) + + +@register_model +def dla34(pretrained=False, **kwargs): # DLA-34 + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], + block=DlaBasic, **kwargs) + return _create_dla('dla34', pretrained, **model_kwargs) + + +@register_model +def dla46_c(pretrained=False, **kwargs): # DLA-46-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, **kwargs) + return _create_dla('dla46_c', pretrained, **model_kwargs) + + +@register_model +def dla46x_c(pretrained=False, **kwargs): # DLA-X-46-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla46x_c', pretrained, **model_kwargs) + + +@register_model +def dla60x_c(pretrained=False, **kwargs): # DLA-X-60-C + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla60x_c', pretrained, **model_kwargs) + + +@register_model +def dla60(pretrained=False, **kwargs): # DLA-60 + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, **kwargs) + return _create_dla('dla60', pretrained, **model_kwargs) + + +@register_model +def dla60x(pretrained=False, **kwargs): # DLA-X-60 + model_kwargs = dict( + levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=32, base_width=4, **kwargs) + return _create_dla('dla60x', pretrained, **model_kwargs) + + +@register_model +def dla102(pretrained=False, **kwargs): # DLA-102 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, shortcut_root=True, **kwargs) + return _create_dla('dla102', pretrained, **model_kwargs) + + +@register_model +def dla102x(pretrained=False, **kwargs): # DLA-X-102 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs) + return _create_dla('dla102x', pretrained, **model_kwargs) + + +@register_model +def dla102x2(pretrained=False, **kwargs): # DLA-X-102 64 + model_kwargs = dict( + levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs) + return _create_dla('dla102x2', pretrained, **model_kwargs) + + +@register_model +def dla169(pretrained=False, **kwargs): # DLA-169 + model_kwargs = dict( + levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024], + block=DlaBottleneck, shortcut_root=True, **kwargs) + return _create_dla('dla169', pretrained, **model_kwargs) diff --git a/timm/models/dpn.py b/timm/models/dpn.py new file mode 100644 index 0000000000000000000000000000000000000000..c4e380b1e31d63a6a6381352eb2ce1555fbbff3b --- /dev/null +++ b/timm/models/dpn.py @@ -0,0 +1,317 @@ +""" PyTorch implementation of DualPathNetworks +Based on original MXNet implementation https://github.com/cypw/DPNs with +many ideas from another PyTorch implementation https://github.com/oyam/pytorch-DPNs. + +This implementation is compatible with the pretrained weights from cypw's MXNet implementation. + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict +from functools import partial +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier +from .registry import register_model + +__all__ = ['DPN'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD, + 'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'dpn68': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'), + 'dpn68b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'dpn92': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'), + 'dpn98': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'), + 'dpn131': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'), + 'dpn107': _cfg( + url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth') +} + + +class CatBnAct(nn.Module): + def __init__(self, in_chs, norm_layer=BatchNormAct2d): + super(CatBnAct, self).__init__() + self.bn = norm_layer(in_chs, eps=0.001) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (torch.Tensor) + pass + + def forward(self, x): + if isinstance(x, tuple): + x = torch.cat(x, dim=1) + return self.bn(x) + + +class BnActConv2d(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, stride, groups=1, norm_layer=BatchNormAct2d): + super(BnActConv2d, self).__init__() + self.bn = norm_layer(in_chs, eps=0.001) + self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups) + + def forward(self, x): + return self.conv(self.bn(x)) + + +class DualPathBlock(nn.Module): + def __init__( + self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False): + super(DualPathBlock, self).__init__() + self.num_1x1_c = num_1x1_c + self.inc = inc + self.b = b + if block_type == 'proj': + self.key_stride = 1 + self.has_proj = True + elif block_type == 'down': + self.key_stride = 2 + self.has_proj = True + else: + assert block_type == 'normal' + self.key_stride = 1 + self.has_proj = False + + self.c1x1_w_s1 = None + self.c1x1_w_s2 = None + if self.has_proj: + # Using different member names here to allow easier parameter key matching for conversion + if self.key_stride == 2: + self.c1x1_w_s2 = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2) + else: + self.c1x1_w_s1 = BnActConv2d( + in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1) + + self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1) + self.c3x3_b = BnActConv2d( + in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups) + if b: + self.c1x1_c = CatBnAct(in_chs=num_3x3_b) + self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1) + self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1) + else: + self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1) + self.c1x1_c1 = None + self.c1x1_c2 = None + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] + pass + + def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(x, tuple): + x_in = torch.cat(x, dim=1) + else: + x_in = x + if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None: + # self.has_proj == False, torchscript requires condition on module == None + x_s1 = x[0] + x_s2 = x[1] + else: + # self.has_proj == True + if self.c1x1_w_s1 is not None: + # self.key_stride = 1 + x_s = self.c1x1_w_s1(x_in) + else: + # self.key_stride = 2 + x_s = self.c1x1_w_s2(x_in) + x_s1 = x_s[:, :self.num_1x1_c, :, :] + x_s2 = x_s[:, self.num_1x1_c:, :, :] + x_in = self.c1x1_a(x_in) + x_in = self.c3x3_b(x_in) + x_in = self.c1x1_c(x_in) + if self.c1x1_c1 is not None: + # self.b == True, using None check for torchscript compat + out1 = self.c1x1_c1(x_in) + out2 = self.c1x1_c2(x_in) + else: + out1 = x_in[:, :self.num_1x1_c, :, :] + out2 = x_in[:, self.num_1x1_c:, :, :] + resid = x_s1 + out1 + dense = torch.cat([x_s2, out2], dim=1) + return resid, dense + + +class DPN(nn.Module): + def __init__(self, small=False, num_init_features=64, k_r=96, groups=32, + b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32, + num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU): + super(DPN, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.b = b + assert output_stride == 32 # FIXME look into dilation support + norm_layer = partial(BatchNormAct2d, eps=.001) + fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False) + bw_factor = 1 if small else 4 + blocks = OrderedDict() + + # conv1 + blocks['conv1_1'] = ConvBnAct( + in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer) + blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] + + # conv2 + bw = 64 * bw_factor + inc = inc_sec[0] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[0] + 1): + blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')] + + # conv3 + bw = 128 * bw_factor + inc = inc_sec[1] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[1] + 1): + blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')] + + # conv4 + bw = 256 * bw_factor + inc = inc_sec[2] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[2] + 1): + blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')] + + # conv5 + bw = 512 * bw_factor + inc = inc_sec[3] + r = (k_r * bw) // (64 * bw_factor) + blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b) + in_chs = bw + 3 * inc + for i in range(2, k_sec[3] + 1): + blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b) + in_chs += inc + self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')] + + blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer) + + self.num_features = in_chs + self.features = nn.Sequential(blocks) + + # Using 1x1 conv for the FC layer to allow the extra pooling scheme + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + x = self.flatten(x) + return x + + +def _create_dpn(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + DPN, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_concat=True, flatten_sequential=True), + **kwargs) + + +@register_model +def dpn68(pretrained=False, **kwargs): + model_kwargs = dict( + small=True, num_init_features=10, k_r=128, groups=32, + k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) + return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn68b(pretrained=False, **kwargs): + model_kwargs = dict( + small=True, num_init_features=10, k_r=128, groups=32, + b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) + return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn92(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=64, k_r=96, groups=32, + k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs) + return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn98(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=96, k_r=160, groups=40, + k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs) + return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn131(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=128, k_r=160, groups=40, + k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs) + return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs) + + +@register_model +def dpn107(pretrained=False, **kwargs): + model_kwargs = dict( + num_init_features=128, k_r=200, groups=50, + k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs) + return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6426b540c303e14a42c8b213c961b108be0bf76d --- /dev/null +++ b/timm/models/efficientnet.py @@ -0,0 +1,2145 @@ +""" The EfficientNet Family in PyTorch + +An implementation of EfficienNet that covers variety of related models with efficient architectures: + +* EfficientNet-V2 + - `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* And likely more... + +The majority of the above models (EfficientNet*, MixNet, MnasNet) and original weights were made available +by Mingxing Tan, Quoc Le, and other members of their Google Brain team. Thanks for consistently releasing +the models and weights open source! + +Hacked together by / Copyright 2021 Ross Wightman +""" +from functools import partial +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from .features import FeatureInfo, FeatureHooks +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import create_conv2d, create_classifier +from .registry import register_model + +__all__ = ['EfficientNet', 'EfficientNetFeatures'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mnasnet_050': _cfg(url=''), + 'mnasnet_075': _cfg(url=''), + 'mnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'), + 'mnasnet_140': _cfg(url=''), + + 'semnasnet_050': _cfg(url=''), + 'semnasnet_075': _cfg(url=''), + 'semnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'), + 'semnasnet_140': _cfg(url=''), + 'mnasnet_small': _cfg(url=''), + + 'mobilenetv2_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth'), + 'mobilenetv2_110d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth'), + 'mobilenetv2_120d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth'), + 'mobilenetv2_140': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth'), + + 'fbnetc_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + interpolation='bilinear'), + 'spnasnet_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + interpolation='bilinear'), + + # NOTE experimenting with alternate attention + 'eca_efficientnet_b0': _cfg( + url=''), + 'gc_efficientnet_b0': _cfg( + url=''), + + 'efficientnet_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'), + 'efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + test_input_size=(3, 256, 256), crop_pct=1.0), + 'efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0), + 'efficientnet_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_b4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth', + input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0), + 'efficientnet_b5': _cfg( + url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'efficientnet_b6': _cfg( + url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'efficientnet_b7': _cfg( + url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'efficientnet_b8': _cfg( + url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + 'efficientnet_l2': _cfg( + url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), + + 'efficientnet_es': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'), + 'efficientnet_em': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_em_ra2-66250f76.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_el': _cfg( + url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_el.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_es_pruned': _cfg( + url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_es_pruned75.pth'), + 'efficientnet_el_pruned': _cfg( + url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_el_pruned70.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'efficientnet_cc_b0_4e': _cfg(url=''), + 'efficientnet_cc_b0_8e': _cfg(url=''), + 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'efficientnet_lite0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth'), + 'efficientnet_lite1': _cfg( + url='', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'efficientnet_lite2': _cfg( + url='', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'efficientnet_lite3': _cfg( + url='', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'efficientnet_lite4': _cfg( + url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + + 'efficientnet_b1_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b2_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'efficientnet_b3_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'efficientnetv2_rw_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_rw_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth', + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + + 'efficientnetv2_s': _cfg( + url='', + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_m': _cfg( + url='', + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + 'efficientnetv2_l': _cfg( + url='', + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnet_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)), + 'tf_efficientnet_b1_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_b8_ap': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954), + + 'tf_efficientnet_b0_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + input_size=(3, 224, 224)), + 'tf_efficientnet_b1_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_b2_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890), + 'tf_efficientnet_b3_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + 'tf_efficientnet_b4_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922), + 'tf_efficientnet_b5_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934), + 'tf_efficientnet_b6_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942), + 'tf_efficientnet_b7_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949), + 'tf_efficientnet_l2_ns_475': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936), + 'tf_efficientnet_l2_ns': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96), + + 'tf_efficientnet_es': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 224, 224), ), + 'tf_efficientnet_em': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + 'tf_efficientnet_el': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), + + 'tf_efficientnet_cc_b0_4e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b0_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_efficientnet_cc_b1_8e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882), + + 'tf_efficientnet_lite0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, + interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res + ), + 'tf_efficientnet_lite3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'), + 'tf_efficientnet_lite4': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'), + + 'tf_efficientnetv2_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_s_in21ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m_in21ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l_in21ft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_s_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth', + input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)), + 'tf_efficientnetv2_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth', + input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882), + 'tf_efficientnetv2_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth', + input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890), + 'tf_efficientnetv2_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth', + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904), + + 'mixnet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), + 'mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'), + 'mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'), + 'mixnet_xl': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth'), + 'mixnet_xxl': _cfg(), + + 'tf_mixnet_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'), + 'tf_mixnet_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'), + 'tf_mixnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'), +} + + +class EfficientNet(nn.Module): + """ (Generic) EfficientNet + + A flexible and performant PyTorch implementation of efficient network architectures, including: + * EfficientNet-V2 Small, Medium, Large & B0-B3 + * EfficientNet B0-B8, L2 + * EfficientNet-EdgeTPU + * EfficientNet-CondConv + * MixNet S, M, L, XL + * MnasNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 + + """ + + def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False, + output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None, + se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'): + super(EfficientNet, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = builder.features + head_chs = builder.in_chs + + # Head + Pooling + self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type) + self.bn2 = norm_layer(self.num_features) + self.act2 = act_layer(inplace=True) + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + efficientnet_init_weights(self) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool]) + layers.extend([nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class EfficientNetFeatures(nn.Module): + """ EfficientNet Feature Extractor + + A work-in-progress feature extraction module for EfficientNet, to use as a backbone for segmentation + and object detection models. + """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, + stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, + act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): + super(EfficientNetFeatures, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.drop_rate = drop_rate + + # Stem + if not fix_stem: + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + feature_location=feature_location) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = FeatureInfo(builder.features, out_indices) + self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} + + efficientnet_init_weights(self) + + # Register feature extraction hooks with FeatureHooks helper + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def forward(self, x) -> List[torch.Tensor]: + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if self.feature_hooks is None: + features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out + for i, b in enumerate(self.blocks): + x = b(x) + if i + 1 in self._stage_out_idx: + features.append(x) + return features + else: + self.blocks(x) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) + + +def _create_effnet(variant, pretrained=False, **kwargs): + features_only = False + model_cls = EfficientNet + kwargs_filter = None + if kwargs.pop('features_only', False): + features_only = True + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') + model_cls = EfficientNetFeatures + model = build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else round_chs_fn(1280), + stem_size=32, + fix_stem=fix_stem_head, + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'relu6'), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-EdgeTPU model + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu + """ + + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'relu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an EfficientNet-CondConv model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + # NOTE unlike official impl, this one uses `cc` option where x is the base number of experts for each stage and + # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'swish'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + act_layer=resolve_act_layer(kwargs, 'relu6'), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_base( + variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 base model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + arch_def = [ + ['cn_r1_k3_s1_e1_c16_skip'], + ['er_r2_k3_s2_e4_c32'], + ['er_r2_k3_s2_e4_c48'], + ['ir_r3_k3_s2_e4_c96_se0.25'], + ['ir_r5_k3_s1_e6_c112_se0.25'], + ['ir_r8_k3_s2_e6_c192_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_s( + variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Small model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + + NOTE: `rw` flag sets up 'small' variant to behave like my initial v2 small model, + before ref the impl was released. + """ + arch_def = [ + ['cn_r2_k3_s1_e1_c24_skip'], + ['er_r4_k3_s2_e4_c48'], + ['er_r4_k3_s2_e4_c64'], + ['ir_r6_k3_s2_e4_c128_se0.25'], + ['ir_r9_k3_s1_e6_c160_se0.25'], + ['ir_r15_k3_s2_e6_c256_se0.25'], + ] + num_features = 1280 + if rw: + # my original variant, based on paper figure differs from the official release + arch_def[0] = ['er_r2_k3_s1_e1_c24'] + arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25'] + num_features = 1792 + + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(num_features), + stem_size=24, + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Medium model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r3_k3_s1_e1_c24_skip'], + ['er_r5_k3_s2_e4_c48'], + ['er_r5_k3_s2_e4_c80'], + ['ir_r7_k3_s2_e4_c160_se0.25'], + ['ir_r14_k3_s1_e6_c176_se0.25'], + ['ir_r18_k3_s2_e6_c304_se0.25'], + ['ir_r5_k3_s1_e6_c512_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=24, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Large model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r4_k3_s1_e1_c32_skip'], + ['er_r7_k3_s2_e4_c64'], + ['er_r7_k3_s2_e4_c96'], + ['ir_r10_k3_s2_e4_c192_se0.25'], + ['ir_r19_k3_s1_e6_c224_se0.25'], + ['ir_r25_k3_s2_e6_c384_se0.25'], + ['ir_r7_k3_s1_e6_c640_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + **kwargs + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +@register_model +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + return mnasnet_100(pretrained, **kwargs) + + +@register_model +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +@register_model +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. """ + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_140(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_110d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv2_120d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def eca_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 w/ ECA attn """ + # NOTE experimental config + model = _gen_efficientnet( + 'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def gc_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 w/ GlobalContext """ + # NOTE experminetal config + model = _gen_efficientnet( + 'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2a(pretrained=False, **kwargs): + """ EfficientNet-B2 @ 288x288 w/ 1.0 test crop""" + # WARN this model def is deprecated, different train/test res + test crop handled by default_cfg now + return efficientnet_b2(pretrained=pretrained, **kwargs) + + +@register_model +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3a(pretrained=False, **kwargs): + """ EfficientNet-B3 @ 320x320 w/ 1.0 test crop-pct """ + # WARN this model def is deprecated, different train/test res + test crop handled by default_cfg now + return efficientnet_b3(pretrained=pretrained, **kwargs) + + +@register_model +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 """ + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_l2(pretrained=False, **kwargs): + """ EfficientNet-L2.""" + # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. """ + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_es_pruned(pretrained=False, **kwargs): + """ EfficientNet-Edge Small Pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0""" + model = _gen_efficientnet_edge( + 'efficientnet_es_pruned', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. """ + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_el_pruned(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0""" + model = _gen_efficientnet_edge( + 'efficientnet_el_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + +@register_model +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b1_pruned(pretrained=False, **kwargs): + """ EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + variant = 'efficientnet_b1_pruned' + model = _gen_efficientnet( + variant, channel_multiplier=1.0, depth_multiplier=1.1, pruned=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b2_pruned(pretrained=False, **kwargs): + """ EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pruned=True, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3_pruned(pretrained=False, **kwargs): + """ EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pruned=True, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_rw_s(pretrained=False, **kwargs): + """ EfficientNet-V2 Small RW variant. + NOTE: This is my initial (pre official code release) w/ some differences. + See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding + """ + model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_rw_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium RW variant. + """ + model = _gen_efficientnetv2_s( + 'efficientnetv2_rw_m', channel_multiplier=1.2, depth_multiplier=(1.2,) * 4 + (1.6,) * 2, rw=True, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_s(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. """ + model = _gen_efficientnetv2_s('efficientnetv2_s', pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. """ + model = _gen_efficientnetv2_m('efficientnetv2_m', pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_l(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. """ + model = _gen_efficientnetv2_l('efficientnetv2_l', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5 AdvProp. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6 AdvProp. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7 AdvProp. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B8 AdvProp. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b0_ns(pretrained=False, **kwargs): + """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b1_ns(pretrained=False, **kwargs): + """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b2_ns(pretrained=False, **kwargs): + """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b3_ns(pretrained=False, **kwargs): + """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b4_ns(pretrained=False, **kwargs): + """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b5_ns(pretrained=False, **kwargs): + """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b6_ns(pretrained=False, **kwargs): + """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_b7_ns(pretrained=False, **kwargs): + """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_l2_ns(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + + +@register_model +def tf_efficientnetv2_s(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_s_in21ft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s_in21ft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m_in21ft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m_in21ft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l_in21ft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l_in21ft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_s_in21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Small w/ ImageNet-21k pretrained weights. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s_in21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m_in21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium w/ ImageNet-21k pretrained weights. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m_in21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l_in21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l_in21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b0(pretrained=False, **kwargs): + """ EfficientNet-V2-B0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b1(pretrained=False, **kwargs): + """ EfficientNet-V2-B1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b2(pretrained=False, **kwargs): + """ EfficientNet-V2-B2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b3(pretrained=False, **kwargs): + """ EfficientNet-V2-B3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. + """ + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. + """ + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. + """ + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_xl(pretrained=False, **kwargs): + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mixnet_xxl(pretrained=False, **kwargs): + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..b43f38f5868cd78a9ba821154c7c6247b7572a0d --- /dev/null +++ b/timm/models/efficientnet_blocks.py @@ -0,0 +1,324 @@ +""" EfficientNet, MobileNetV3, etc Blocks + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from .layers import create_conv2d, drop_path, make_divisible, create_act_layer +from .layers.activations import sigmoid + +__all__ = [ + 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] + + +class SqueezeExcite(nn.Module): + """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family + + Args: + in_chs (int): input channels to layer + rd_ratio (float): ratio of squeeze reduction + act_layer (nn.Module): activation layer of containing block + gate_layer (Callable): attention gate function + force_act_layer (nn.Module): override block's activation fn if this is set/bound + rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs + """ + + def __init__( + self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU, + gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None): + super(SqueezeExcite, self).__init__() + if rd_channels is None: + rd_round_fn = rd_round_fn or round + rd_channels = rd_round_fn(in_chs * rd_ratio) + act_layer = force_act_layer or act_layer + self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True) + self.act1 = create_act_layer(act_layer, inplace=True) + self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + return x * self.gate(x_se) + + +class ConvBnAct(nn.Module): + """ Conv + Norm Layer + Activation w/ optional skip connection + """ + def __init__( + self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='', + skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.): + super(ConvBnAct, self).__init__() + self.has_residual = skip and stride == 1 and in_chs == out_chs + self.drop_path_rate = drop_path_rate + self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn1 = norm_layer(out_chs) + self.act1 = act_layer(inplace=True) + + def feature_info(self, location): + if location == 'expansion': # output of conv after act, same as block coutput + info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels) + else: # location == 'bottleneck', block output + info = dict(module='', hook_type='', num_chs=self.conv.out_channels) + return info + + def forward(self, x): + shortcut = x + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += shortcut + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion + (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. + """ + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + se_layer=None, drop_path_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.has_pw_act = pw_act # activation after point-wise conv + self.drop_path_rate = drop_path_rate + + self.conv_dw = create_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity() + + self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs) + self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PW + info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) + else: # location == 'bottleneck', block output + info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) + return info + + def forward(self, x): + shortcut = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += shortcut + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE + + Originally used in MobileNet-V2 - https://arxiv.org/abs/1801.04381v4, this layer is often + referred to as 'MBConv' for (Mobile inverted bottleneck conv) and is also used in + * MNasNet - https://arxiv.org/abs/1807.11626 + * EfficientNet - https://arxiv.org/abs/1905.11946 + * MobileNet-V3 - https://arxiv.org/abs/1905.02244 + """ + + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.): + super(InvertedResidual, self).__init__() + conv_kwargs = conv_kwargs or {} + mid_chs = make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_path_rate = drop_path_rate + + # Point-wise expansion + self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = create_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + + # Point-wise linear projection + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs) + + def feature_info(self, location): + if location == 'expansion': # after SE, input to PWL + info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) + return info + + def forward(self, x): + shortcut = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += shortcut + + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs, + drop_path_rate=drop_path_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + shortcut = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += shortcut + return x + + +class EdgeResidual(nn.Module): + """ Residual block with expansion convolution followed by pointwise-linear w/ stride + + Originally introduced in `EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML` + - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html + + This layer is also called FusedMBConv in the MobileDet, EfficientNet-X, and EfficientNet-V2 papers + * MobileDet - https://arxiv.org/abs/2004.14525 + * EfficientNet-X - https://arxiv.org/abs/2102.05610 + * EfficientNet-V2 - https://arxiv.org/abs/2104.00298 + """ + + def __init__( + self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='', + force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): + super(EdgeResidual, self).__init__() + if force_in_chs > 0: + mid_chs = make_divisible(force_in_chs * exp_ratio) + else: + mid_chs = make_divisible(in_chs * exp_ratio) + has_se = se_layer is not None and se_ratio > 0. + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_path_rate = drop_path_rate + + # Expansion convolution + self.conv_exp = create_conv2d( + in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.bn1 = norm_layer(mid_chs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() + + # Point-wise linear projection + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs) + + def feature_info(self, location): + if location == 'expansion': # after SE, before PWL + info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + else: # location == 'bottleneck', block output + info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) + return info + + def forward(self, x): + shortcut = x + + # Expansion convolution + x = self.conv_exp(x) + x = self.bn1(x) + x = self.act1(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn2(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += shortcut + + return x diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a23e8273d98891b4663e7e8fa409700ef0818122 --- /dev/null +++ b/timm/models/efficientnet_builder.py @@ -0,0 +1,463 @@ +""" EfficientNet, MobileNetV3, etc Builder + +Assembles EfficieNet and related network feature blocks from string definitions. +Handles stride, dilation calculations, and selects feature extraction points. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +import math +import re +from copy import deepcopy +from functools import partial + +import torch.nn as nn + +from .efficientnet_blocks import * +from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible + +__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", + 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] + +_logger = logging.getLogger(__name__) + + +_DEBUG_BUILDER = False + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +def resolve_act_layer(kwargs, default='relu'): + return get_act_layer(kwargs.pop('act_layer', default)) + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit) + + +def _log_info_if(msg, condition): + if condition: + _logger.info(msg) + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. + Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + skip = None + for op in ops: + # string options being checked on individual basis, combine if they grow + if op == 'noskip': + skip = False # force no skip connection + elif op == 'skip': + skip = True # force a skip connection + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = get_act_layer('relu') + elif v == 'r6': + value = get_act_layer('relu6') + elif v == 'hs': + value = get_act_layer('hard_swish') + elif v == 'sw': + value = get_act_layer('swish') # aka SiLU + elif v == 'mi': + value = get_act_layer('mish') + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else 0., + stride=int(options['s']), + act_layer=act_layer, + noskip=skip is False, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else 0., + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or skip is False, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + force_in_chs=force_in_chs, + se_ratio=float(options['se']) if 'se' in options else 0., + stride=int(options['s']), + act_layer=act_layer, + noskip=skip is False, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + skip=skip is True, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): + arch_args = [] + if isinstance(depth_multiplier, tuple): + assert len(depth_multiplier) == len(arch_def) + else: + depth_multiplier = (depth_multiplier,) * len(arch_def) + for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): + arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) + else: + arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc)) + return arch_args + + +class EfficientNetBuilder: + """ Build Trunk Blocks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False, + act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''): + self.output_stride = output_stride + self.pad_type = pad_type + self.round_chs_fn = round_chs_fn + self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs + self.act_layer = act_layer + self.norm_layer = norm_layer + self.se_layer = get_attn(se_layer) + try: + self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg + self.se_has_ratio = True + except TypeError: + self.se_has_ratio = False + self.drop_path_rate = drop_path_rate + if feature_location == 'depthwise': + # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense + _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'") + feature_location = 'expansion' + self.feature_location = feature_location + assert feature_location in ('bottleneck', 'expansion', '') + self.verbose = _DEBUG_BUILDER + + # state updated during build, consumed by model + self.in_chs = None + self.features = [] + + def _make_block(self, ba, block_idx, block_count): + drop_path_rate = self.drop_path_rate * block_idx / block_count + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self.round_chs_fn(ba['out_chs']) + if 'force_in_chs' in ba and ba['force_in_chs']: + # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl + ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs']) + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + ba['norm_layer'] = self.norm_layer + ba['drop_path_rate'] = drop_path_rate + if bt != 'cn': + se_ratio = ba.pop('se_ratio') + if se_ratio and self.se_layer is not None: + if not self.se_from_exp: + # adjust se_ratio by expansion ratio if calculating se channels from block input + se_ratio /= ba.get('exp_ratio', 1.0) + if self.se_has_ratio: + ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio) + else: + ba['se_layer'] = self.se_layer + + if bt == 'ir': + _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = DepthwiseSeparableConv(**ba) + elif bt == 'er': + _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = EdgeResidual(**ba) + elif bt == 'cn': + _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose) + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + return block + + def __call__(self, in_chs, model_block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + model_block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose) + self.in_chs = in_chs + total_block_count = sum([len(x) for x in model_block_args]) + total_block_idx = 0 + current_stride = 2 + current_dilation = 1 + stages = [] + if model_block_args[0][0]['stride'] > 1: + # if the first block starts with a stride, we need to extract first level feat from stem + feature_info = dict( + module='act1', num_chs=in_chs, stage=0, reduction=current_stride, + hook_type='forward' if self.feature_location != 'bottleneck' else '') + self.features.append(feature_info) + + # outer list of block_args defines the stacks + for stack_idx, stack_args in enumerate(model_block_args): + last_stack = stack_idx + 1 == len(model_block_args) + _log_info_if('Stack: {}'.format(stack_idx), self.verbose) + assert isinstance(stack_args, list) + + blocks = [] + # each stack (stage of blocks) contains a list of block arguments + for block_idx, block_args in enumerate(stack_args): + last_block = block_idx + 1 == len(stack_args) + _log_info_if(' Block: {}'.format(block_idx), self.verbose) + + assert block_args['stride'] in (1, 2) + if block_idx >= 1: # only the first block in any stack can have a stride > 1 + block_args['stride'] = 1 + + extract_features = False + if last_block: + next_stack_idx = stack_idx + 1 + extract_features = next_stack_idx >= len(model_block_args) or \ + model_block_args[next_stack_idx][0]['stride'] > 1 + + next_dilation = current_dilation + if block_args['stride'] > 1: + next_output_stride = current_stride * block_args['stride'] + if next_output_stride > self.output_stride: + next_dilation = current_dilation * block_args['stride'] + block_args['stride'] = 1 + _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format( + self.output_stride), self.verbose) + else: + current_stride = next_output_stride + block_args['dilation'] = current_dilation + if next_dilation != current_dilation: + current_dilation = next_dilation + + # create the block + block = self._make_block(block_args, total_block_idx, total_block_count) + blocks.append(block) + + # stash feature module name and channel info for model feature extraction + if extract_features: + feature_info = dict( + stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location)) + module_name = f'blocks.{stack_idx}.{block_idx}' + leaf_name = feature_info.get('module', '') + feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name + self.features.append(feature_info) + + total_block_idx += 1 # incr global block idx (across all stacks) + stages.append(nn.Sequential(*blocks)) + return stages + + +def _init_weight_goog(m, n='', fix_group_fanout=True): + """ Weight initialization as per Tensorflow official implementations. + + Args: + m (nn.Module): module to init + n (str): module name + fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs + + Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: + * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + """ + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + init_weight_fn = get_condconv_initializer( + lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + nn.init.uniform_(m.weight, -init_range, init_range) + nn.init.zeros_(m.bias) + + +def efficientnet_init_weights(model: nn.Module, init_fn=None): + init_fn = init_fn or _init_weight_goog + for n, m in model.named_modules(): + init_fn(m, n) + diff --git a/timm/models/factory.py b/timm/models/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..d040a9ff62c0a4089536078fee0e9552ab3cdabc --- /dev/null +++ b/timm/models/factory.py @@ -0,0 +1,86 @@ +from .registry import is_model, is_model_in_modules, model_entrypoint +from .helpers import load_checkpoint +from .layers import set_layer_config +from .hub import load_model_config_from_hf + + +def split_model_name(model_name): + model_split = model_name.split(':', 1) + if len(model_split) == 1: + return '', model_split[0] + else: + source_name, model_name = model_split + assert source_name in ('timm', 'hf_hub') + return source_name, model_name + + +def safe_model_name(model_name, remove_source=True): + def make_safe(name): + return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') + if remove_source: + model_name = split_model_name(model_name)[-1] + return make_safe(model_name) + + +def create_model( + model_name, + pretrained=False, + checkpoint_path='', + scriptable=None, + exportable=None, + no_jit=None, + **kwargs): + """Create a model + + Args: + model_name (str): name of model to instantiate + pretrained (bool): load pretrained ImageNet-1k weights if true + checkpoint_path (str): path of checkpoint to load after model is initialized + scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet) + exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet) + no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only) + + Keyword Args: + drop_rate (float): dropout rate for training (default: 0.0) + global_pool (str): global pool type (default: 'avg') + **: other kwargs are model specific + """ + source_name, model_name = split_model_name(model_name) + + # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args + is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) + if not is_efficientnet: + kwargs.pop('bn_tf', None) + kwargs.pop('bn_momentum', None) + kwargs.pop('bn_eps', None) + + # handle backwards compat with drop_connect -> drop_path change + drop_connect_rate = kwargs.pop('drop_connect_rate', None) + if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: + print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'." + " Setting drop_path to %f." % drop_connect_rate) + kwargs['drop_path_rate'] = drop_connect_rate + + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + if source_name == 'hf_hub': + # For model names specified in the form `hf_hub:path/architecture_name#revision`, + # load model weights + default_cfg from Hugging Face hub. + hf_default_cfg, model_name = load_model_config_from_hf(model_name) + kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday + + if is_model(model_name): + create_fn = model_entrypoint(model_name) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + + with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): + model = create_fn(pretrained=pretrained, **kwargs) + + if checkpoint_path: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/timm/models/features.py b/timm/models/features.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d6890f3ed07311c5484b4a397c3b1da555880a --- /dev/null +++ b/timm/models/features.py @@ -0,0 +1,284 @@ +""" PyTorch Feature Extraction Helpers + +A collection of classes, functions, modules to help extract features from models +and provide a common interface for describing them. + +The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter +https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + + +class FeatureInfo: + + def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + prev_reduction = 1 + for fi in feature_info: + # sanity check the mandatory fields, there may be additional fields depending on the model + assert 'num_chs' in fi and fi['num_chs'] > 0 + assert 'reduction' in fi and fi['reduction'] >= prev_reduction + prev_reduction = fi['reduction'] + assert 'module' in fi + self.out_indices = out_indices + self.info = feature_info + + def from_other(self, out_indices: Tuple[int]): + return FeatureInfo(deepcopy(self.info), out_indices) + + def get(self, key, idx=None): + """ Get value by key at specified index (indices) + if idx == None, returns value for key at each output index + if idx is an integer, return value for that feature module index (ignoring output indices) + if idx is a list/tupple, return value for each module index (ignoring output indices) + """ + if idx is None: + return [self.info[i][key] for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i][key] for i in idx] + else: + return self.info[idx][key] + + def get_dicts(self, keys=None, idx=None): + """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) + """ + if idx is None: + if keys is None: + return [self.info[i] for i in self.out_indices] + else: + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] + else: + return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} + + def channels(self, idx=None): + """ feature channels accessor + """ + return self.get('num_chs', idx) + + def reduction(self, idx=None): + """ feature reduction (output stride) accessor + """ + return self.get('reduction', idx) + + def module_name(self, idx=None): + """ feature module name accessor + """ + return self.get('module', idx) + + def __getitem__(self, item): + return self.info[item] + + def __len__(self): + return len(self.info) + + +class FeatureHooks: + """ Feature Hook Helper + + This module helps with the setup and extraction of hooks for extracting features from + internal nodes in a model by node name. This works quite well in eager Python but needs + redesign for torcscript. + """ + + def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for i, h in enumerate(hooks): + hook_name = h['module'] + m = modules[hook_name] + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type + if hook_type == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif hook_type == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, hook_id, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][hook_id] = x + + def get_output(self, device) -> Dict[str, torch.tensor]: + output = self._feature_outputs[device] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + combined = [name, child_name] + ml.append(('_'.join(combined), '.'.join(combined), child_module)) + else: + ml.append((name, name, module)) + return ml + + +def _get_feature_info(net, out_indices): + feature_info = getattr(net, 'feature_info') + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +def _get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return return_layers + + +class FeatureDictNet(nn.ModuleDict): + """ Feature extractor with OrderedDict return + + Wrap a model and extract features as specified by the out indices, the network is + partially re-built from contained modules. + + There is a strong assumption that the modules have been registered into the model in the same + order as they are used. There should be no reuse of the same nn.Module more than once, including + trivial modules like `self.relu = nn.ReLU`. + + Only submodules that are directly assigned to the model class (`model.feature1`) or at most + one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. + All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + + Arguments: + model (nn.Module): model from which we will extract the features + out_indices (tuple[int]): model output indices to extract features for + out_map (sequence): list or tuple specifying desired return id for each out index, + otherwise str(index) is used + feature_concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureDictNet, self).__init__() + self.feature_info = _get_feature_info(model, out_indices) + self.concat = feature_concat + self.return_layers = {} + return_layers = _get_return_layers(self.feature_info, out_map) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + layers = OrderedDict() + for new_name, old_name, module in modules: + layers[new_name] = module + if old_name in remaining: + # return id has to be consistently str type for torchscript + self.return_layers[new_name] = str(return_layers[old_name]) + remaining.remove(old_name) + if not remaining: + break + assert not remaining and len(self.return_layers) == len(return_layers), \ + f'Return layers ({remaining}) are not present in model' + self.update(layers) + + def _collect(self, x) -> (Dict[str, torch.Tensor]): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_id = self.return_layers[name] + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out[out_id] = torch.cat(x, 1) if self.concat else x[0] + else: + out[out_id] = x + return out + + def forward(self, x) -> Dict[str, torch.Tensor]: + return self._collect(x) + + +class FeatureListNet(FeatureDictNet): + """ Feature extractor with list return + + See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. + In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): + super(FeatureListNet, self).__init__( + model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, + flatten_sequential=flatten_sequential) + + def forward(self, x) -> (List[torch.Tensor]): + return list(self._collect(x).values()) + + +class FeatureHookNet(nn.ModuleDict): + """ FeatureHookNet + + Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. + + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. + + FIXME this does not currently work with Torchscript, see FeatureHooks class + """ + def __init__( + self, model, + out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, + feature_concat=False, flatten_sequential=False, default_hook_type='forward'): + super(FeatureHookNet, self).__init__() + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.out_as_dict = out_as_dict + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, 'reset_classifier'): # make sure classifier is removed? + model.reset_classifier(0) + layers['body'] = model + hooks.extend(self.feature_info.get_dicts()) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type + for f in self.feature_info.get_dicts()} + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + + def forward(self, x): + for name, module in self.items(): + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.out_as_dict else list(out.values()) diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py new file mode 100644 index 0000000000000000000000000000000000000000..5a25ee3e503cabbe41dd57a0877941b376481b00 --- /dev/null +++ b/timm/models/fx_features.py @@ -0,0 +1,73 @@ +""" PyTorch FX Based Feature Extraction Helpers +Using https://pytorch.org/vision/stable/feature_extraction.html +""" +from typing import Callable +from torch import nn + +from .features import _get_feature_info + +try: + from torchvision.models.feature_extraction import create_feature_extractor + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False + +# Layers we went to treat as leaf modules +from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath +from .layers.non_local_attn import BilinearAttnTransform +from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame + +# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here +# BUT modules from timm.models should use the registration mechanism below +_leaf_modules = { + BatchNormAct2d, # reason: flow control for jit scripting + BilinearAttnTransform, # reason: flow control t <= 1 + BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1] + # Reason: get_same_padding has a max which raises a control flow error + Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, + CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) + DropPath, # reason: TypeError: rand recieved Proxy in `size` argument +} + +try: + from .layers import InplaceAbn + _leaf_modules.add(InplaceAbn) +except ImportError: + pass + + +def register_notrace_module(module: nn.Module): + """ + Any module not under timm.models.layers should get this decorator if we don't want to trace through it. + """ + _leaf_modules.add(module) + return module + + +# Functions we want to autowrap (treat them as leaves) +_autowrap_functions = set() + + +def register_notrace_function(func: Callable): + """ + Decorator for functions which ought not to be traced through + """ + _autowrap_functions.add(func) + return func + + +class FeatureGraphNet(nn.Module): + def __init__(self, model, out_indices, out_map=None): + super().__init__() + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + self.feature_info = _get_feature_info(model, out_indices) + if out_map is not None: + assert len(out_map) == len(out_indices) + return_nodes = {info['module']: out_map[i] if out_map is not None else info['module'] + for i, info in enumerate(self.feature_info) if i in out_indices} + self.graph_module = create_feature_extractor( + model, return_nodes, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) + + def forward(self, x): + return list(self.graph_module(x).values()) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3b6f90a42fa02099f7c2d769d97ec359bd82733d --- /dev/null +++ b/timm/models/ghostnet.py @@ -0,0 +1,276 @@ +""" +An implementation of GhostNet Model as defined in: +GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907 +The train script of the model is similar to that of MobileNetV3 +Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch +""" +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import SelectAdaptivePool2d, Linear, make_divisible +from .efficientnet_blocks import SqueezeExcite, ConvBnAct +from .helpers import build_model_with_cfg +from .registry import register_model + + +__all__ = ['GhostNet'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'ghostnet_050': _cfg(url=''), + 'ghostnet_100': _cfg( + url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'), + 'ghostnet_130': _cfg(url=''), +} + + +_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4)) + + +class GhostModule(nn.Module): + def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): + super(GhostModule, self).__init__() + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels * (ratio - 1) + + self.primary_conv = nn.Sequential( + nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + + def forward(self, x): + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1, x2], dim=1) + return out[:, :self.oup, :, :] + + +class GhostBottleneck(nn.Module): + """ Ghost bottleneck w/ optional SE""" + + def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3, + stride=1, act_layer=nn.ReLU, se_ratio=0.): + super(GhostBottleneck, self).__init__() + has_se = se_ratio is not None and se_ratio > 0. + self.stride = stride + + # Point-wise expansion + self.ghost1 = GhostModule(in_chs, mid_chs, relu=True) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False) + self.bn_dw = nn.BatchNorm2d(mid_chs) + else: + self.conv_dw = None + self.bn_dw = None + + # Squeeze-and-excitation + self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None + + # Point-wise linear projection + self.ghost2 = GhostModule(mid_chs, out_chs, relu=False) + + # shortcut + if in_chs == out_chs and self.stride == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + shortcut = x + + # 1st ghost bottleneck + x = self.ghost1(x) + + # Depth-wise convolution + if self.conv_dw is not None: + x = self.conv_dw(x) + x = self.bn_dw(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # 2nd ghost bottleneck + x = self.ghost2(x) + + x += self.shortcut(shortcut) + return x + + +class GhostNet(nn.Module): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'): + super(GhostNet, self).__init__() + # setting of inverted residual blocks + assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' + self.cfgs = cfgs + self.num_classes = num_classes + self.dropout = dropout + self.feature_info = [] + + # building first layer + stem_chs = make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False) + self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem')) + self.bn1 = nn.BatchNorm2d(stem_chs) + self.act1 = nn.ReLU(inplace=True) + prev_chs = stem_chs + + # building inverted residual blocks + stages = nn.ModuleList([]) + block = GhostBottleneck + stage_idx = 0 + net_stride = 2 + for cfg in self.cfgs: + layers = [] + s = 1 + for k, exp_size, c, se_ratio, s in cfg: + out_chs = make_divisible(c * width, 4) + mid_chs = make_divisible(exp_size * width, 4) + layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio)) + prev_chs = out_chs + if s > 1: + net_stride *= 2 + self.feature_info.append(dict( + num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}')) + stages.append(nn.Sequential(*layers)) + stage_idx += 1 + + out_chs = make_divisible(exp_size * width, 4) + stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1))) + self.pool_dim = prev_chs = out_chs + + self.blocks = nn.Sequential(*stages) + + # building last several layers + self.num_features = out_chs = 1280 + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) + self.act2 = nn.ReLU(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity() + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + # cannot meaningfully change pooling of efficient head after creation + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.flatten(x) + if self.dropout > 0.: + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.classifier(x) + return x + + +def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs): + """ + Constructs a GhostNet model + """ + cfgs = [ + # k, t, c, SE, s + # stage1 + [[3, 16, 16, 0, 1]], + # stage2 + [[3, 48, 24, 0, 2]], + [[3, 72, 24, 0, 1]], + # stage3 + [[5, 72, 40, 0.25, 2]], + [[5, 120, 40, 0.25, 1]], + # stage4 + [[3, 240, 80, 0, 2]], + [[3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 0.25, 1], + [3, 672, 112, 0.25, 1] + ], + # stage5 + [[5, 672, 160, 0.25, 2]], + [[5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1], + [5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1] + ] + ] + model_kwargs = dict( + cfgs=cfgs, + width=width, + **kwargs, + ) + return build_model_with_cfg( + GhostNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), + **model_kwargs) + + +@register_model +def ghostnet_050(pretrained=False, **kwargs): + """ GhostNet-0.5x """ + model = _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs) + return model + + +@register_model +def ghostnet_100(pretrained=False, **kwargs): + """ GhostNet-1.0x """ + model = _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def ghostnet_130(pretrained=False, **kwargs): + """ GhostNet-1.3x """ + model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..027a10b534b50e94775d463304235256f1cfc16f --- /dev/null +++ b/timm/models/gluon_resnet.py @@ -0,0 +1,248 @@ +"""Pytorch impl of MxNet Gluon ResNet/(SE)ResNeXt variants +This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-ResNeXt additions +and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py) +by Ross Wightman +""" + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SEModule +from .registry import register_model +from .resnet import ResNet, Bottleneck, BasicBlock + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'gluon_resnet18_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'), + 'gluon_resnet34_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'), + 'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'), + 'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'), + 'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'), + 'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth', + first_conv='conv1.0'), + 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), + 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), + 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), + 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), + 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), + 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), + 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth', + first_conv='conv1.0'), +} + + +def _create_resnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def gluon_resnet18_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('gluon_resnet18_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet34_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('gluon_resnet34_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('gluon_resnet50_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('gluon_resnet101_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1b(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('gluon_resnet152_v1b', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet50_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet101_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1c(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet152_v1c', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet50_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet101_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet101_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1d(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('gluon_resnet152_v1d', pretrained, **model_args) + + +@register_model +def gluon_resnet50_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet50_v1s', pretrained, **model_args) + + + +@register_model +def gluon_resnet101_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet101_v1s', pretrained, **model_args) + + +@register_model +def gluon_resnet152_v1s(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep', **kwargs) + return _create_resnet('gluon_resnet152_v1s', pretrained, **model_args) + + + +@register_model +def gluon_resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('gluon_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def gluon_resnext101_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('gluon_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def gluon_resnext101_64x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs) + return _create_resnet('gluon_resnext101_64x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext50_32x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt50-32x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext101_32x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt-101-32x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext101_32x4d', pretrained, **model_args) + + +@register_model +def gluon_seresnext101_64x4d(pretrained=False, **kwargs): + """Constructs a SEResNeXt-101-64x4d model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, + block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_seresnext101_64x4d', pretrained, **model_args) + + +@register_model +def gluon_senet154(pretrained=False, **kwargs): + """Constructs an SENet-154 model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', + down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer=SEModule), **kwargs) + return _create_resnet('gluon_senet154', pretrained, **model_args) diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd668a585e676726a7a6f8bd43642e57e4566e2 --- /dev/null +++ b/timm/models/gluon_xception.py @@ -0,0 +1,246 @@ +"""Pytorch impl of Gluon Xception +This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl. + +Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html) +Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier, get_padding +from .registry import register_model + +__all__ = ['Xception65'] + +default_cfgs = { + 'gluon_xception65': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth', + 'input_size': (3, 299, 299), + 'crop_pct': 0.903, + 'pool_size': (10, 10), + 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, + 'std': IMAGENET_DEFAULT_STD, + 'num_classes': 1000, + 'first_conv': 'conv1', + 'classifier': 'fc' + # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + }, +} + +""" PADDING NOTES +The original PyTorch and Gluon impl of these models dutifully reproduced the +aligned padding added to Tensorflow models for Deeplab. This padding was compensating +for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to. +""" + + +class SeparableConv2d(nn.Module): + def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None): + super(SeparableConv2d, self).__init__() + self.kernel_size = kernel_size + self.dilation = dilation + + # depthwise convolution + padding = get_padding(kernel_size, stride, dilation) + self.conv_dw = nn.Conv2d( + inplanes, inplanes, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=inplanes, bias=bias) + self.bn = norm_layer(num_features=inplanes) + # pointwise convolution + self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.conv_dw(x) + x = self.bn(x) + x = self.conv_pw(x) + return x + + +class Block(nn.Module): + def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None): + super(Block, self).__init__() + if isinstance(planes, (list, tuple)): + assert len(planes) == 3 + else: + planes = (planes,) * 3 + outplanes = planes[-1] + + if outplanes != inplanes or stride != 1: + self.skip = nn.Sequential() + self.skip.add_module('conv1', nn.Conv2d( + inplanes, outplanes, 1, stride=stride, bias=False)), + self.skip.add_module('bn1', norm_layer(num_features=outplanes)) + else: + self.skip = None + + rep = OrderedDict() + for i in range(3): + rep['act%d' % (i + 1)] = nn.ReLU(inplace=True) + rep['conv%d' % (i + 1)] = SeparableConv2d( + inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer) + rep['bn%d' % (i + 1)] = norm_layer(planes[i]) + inplanes = planes[i] + + if not start_with_relu: + del rep['act1'] + else: + rep['act1'] = nn.ReLU(inplace=False) + self.rep = nn.Sequential(rep) + + def forward(self, x): + skip = x + if self.skip is not None: + skip = self.skip(skip) + x = self.rep(x) + skip + return x + + +class Xception65(nn.Module): + """Modified Aligned Xception. + + NOTE: only the 65 layer version is included here, the 71 layer variant + was not correct and had no pretrained weights + """ + + def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d, + drop_rate=0., global_pool='avg'): + super(Xception65, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + if output_stride == 32: + entry_block3_stride = 2 + exit_block20_stride = 2 + middle_dilation = 1 + exit_dilation = (1, 1) + elif output_stride == 16: + entry_block3_stride = 2 + exit_block20_stride = 1 + middle_dilation = 1 + exit_dilation = (1, 2) + elif output_stride == 8: + entry_block3_stride = 1 + exit_block20_stride = 1 + middle_dilation = 2 + exit_dilation = (2, 4) + else: + raise NotImplementedError + + # Entry flow + self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = norm_layer(num_features=32) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = norm_layer(num_features=64) + self.act2 = nn.ReLU(inplace=True) + + self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer) + self.block1_act = nn.ReLU(inplace=True) + self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer) + self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer) + + # Middle flow + self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block( + 728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)])) + + # Exit flow + self.block20 = Block( + 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer) + self.block20_act = nn.ReLU(inplace=True) + + self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn3 = norm_layer(num_features=1536) + self.act3 = nn.ReLU(inplace=True) + + self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn4 = norm_layer(num_features=1536) + self.act4 = nn.ReLU(inplace=True) + + self.num_features = 2048 + self.conv5 = SeparableConv2d( + 1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer) + self.bn5 = norm_layer(num_features=self.num_features) + self.act5 = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='act2'), + dict(num_chs=128, reduction=4, module='block1_act'), + dict(num_chs=256, reduction=8, module='block3.rep.act1'), + dict(num_chs=728, reduction=16, module='block20.rep.act1'), + dict(num_chs=2048, reduction=32, module='act5'), + ] + + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + # Entry flow + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + x = self.block1(x) + x = self.block1_act(x) + # c1 = x + x = self.block2(x) + # c2 = x + x = self.block3(x) + + # Middle flow + x = self.mid(x) + # c3 = x + + # Exit flow + x = self.block20(x) + x = self.block20_act(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.act3(x) + + x = self.conv4(x) + x = self.bn4(x) + x = self.act4(x) + + x = self.conv5(x) + x = self.bn5(x) + x = self.act5(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate: + F.dropout(x, self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +def _create_gluon_xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + Xception65, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), + **kwargs) + + +@register_model +def gluon_xception65(pretrained=False, **kwargs): + """ Modified Aligned Xception-65 + """ + return _create_gluon_xception('gluon_xception65', pretrained, **kwargs) diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py new file mode 100644 index 0000000000000000000000000000000000000000..9988a0444558d9e7f4b640ff468cc63b1dc1d7f4 --- /dev/null +++ b/timm/models/hardcorenas.py @@ -0,0 +1,152 @@ +from functools import partial + +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import get_act_fn +from .mobilenetv3 import MobileNetV3, MobileNetV3Features +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'hardcorenas_a': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_A_Green_38ms_75.9_23474aeb.pth'), + 'hardcorenas_b': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_B_Green_40ms_76.5_1f882d1e.pth'), + 'hardcorenas_c': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_C_Green_44ms_77.1_d4148c9e.pth'), + 'hardcorenas_d': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_D_Green_50ms_77.4_23e3cdde.pth'), + 'hardcorenas_e': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_E_Green_55ms_77.9_90f20e8a.pth'), + 'hardcorenas_f': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_F_Green_60ms_78.1_2855edf1.pth'), +} + + +def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): + """Creates a hardcorenas model + + Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS + Paper: https://arxiv.org/abs/2102.11646 + + """ + num_features = 1280 + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=32, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=se_layer, + **kwargs, + ) + + features_only = False + model_cls = MobileNetV3 + kwargs_filter = None + if model_kwargs.pop('features_only', False): + features_only = True + kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool') + model_cls = MobileNetV3Features + model = build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +@register_model +def hardcorenas_a(pretrained=False, **kwargs): + """ hardcorenas_A """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'], + ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_a', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_b(pretrained=False, **kwargs): + """ hardcorenas_B """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], + ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'], + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'], + ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'], + ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'], + ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_c(pretrained=False, **kwargs): + """ hardcorenas_C """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', + 'ir_r1_k5_s1_e3_c40_nre'], + ['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'], + ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_c', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_d(pretrained=False, **kwargs): + """ hardcorenas_D """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'], + ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', + 'ir_r1_k3_s1_e3_c80_se0.25'], + ['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25', + 'ir_r1_k5_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_d', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_e(pretrained=False, **kwargs): + """ hardcorenas_E """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', + 'ir_r1_k3_s1_e3_c40_nre_se0.25'], ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'], + ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', + 'ir_r1_k5_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=arch_def, **kwargs) + return model + + +@register_model +def hardcorenas_f(pretrained=False, **kwargs): + """ hardcorenas_F """ + arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'], + ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'], + ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', + 'ir_r1_k3_s1_e3_c80_se0.25'], + ['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', + 'ir_r1_k3_s1_e3_c112_se0.25'], + ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25', + 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']] + model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_f', arch_def=arch_def, **kwargs) + return model diff --git a/timm/models/helpers.py b/timm/models/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..662a7a483b1e40f9f00d931e84762878c612c0c6 --- /dev/null +++ b/timm/models/helpers.py @@ -0,0 +1,508 @@ +""" Model creation / weight loading / state_dict helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import os +import math +from collections import OrderedDict +from copy import deepcopy +from typing import Any, Callable, Optional, Tuple + +import torch +import torch.nn as nn + + +from .features import FeatureListNet, FeatureDictNet, FeatureHookNet +from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url +from .layers import Conv2dSame, Linear + + +_logger = logging.getLogger(__name__) + + +def load_state_dict(checkpoint_path, use_ema=False): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + state_dict_key = 'state_dict' + if isinstance(checkpoint, dict): + if use_ema and 'state_dict_ema' in checkpoint: + state_dict_key = 'state_dict_ema' + if state_dict_key and state_dict_key in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint[state_dict_key].items(): + # strip `module.` prefix + name = k[7:] if k.startswith('module') else k + new_state_dict[name] = v + state_dict = new_state_dict + else: + state_dict = checkpoint + _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) + return state_dict + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + model.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return + state_dict = load_state_dict(checkpoint_path, use_ema) + model.load_state_dict(state_dict, strict=strict) + + +def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): + resume_epoch = None + if os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + if log_info: + _logger.info('Restoring model state from checkpoint...') + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + name = k[7:] if k.startswith('module') else k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + + if optimizer is not None and 'optimizer' in checkpoint: + if log_info: + _logger.info('Restoring optimizer state from checkpoint...') + optimizer.load_state_dict(checkpoint['optimizer']) + + if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: + if log_info: + _logger.info('Restoring AMP loss scaler state from checkpoint...') + loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + + if log_info: + _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + else: + model.load_state_dict(checkpoint) + if log_info: + _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) + return resume_epoch + else: + _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False): + r"""Loads a custom (read non .pth) weight file + + Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls + a passed in custom load fun, or the `load_pretrained` model member fn. + + If the object is already present in `model_dir`, it's deserialized and returned. + The default value of `model_dir` is ``/checkpoints`` where + `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`. + + Args: + model: The instantiated model to load weights into + default_cfg (dict): Default pretrained model cfg + load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named + 'laod_pretrained' on the model will be called if it exists + progress (bool, optional): whether or not to display a progress bar to stderr. Default: False + check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention + ``filename-.ext`` where ```` is the first eight or more + digits of the SHA256 hash of the contents of the file. The hash is used to + ensure unique names and to verify the contents of the file. Default: False + """ + default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} + pretrained_url = default_cfg.get('url', None) + if not pretrained_url: + _logger.warning("No pretrained weights exist for this model. Using random initialization.") + return + cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress) + + if load_fn is not None: + load_fn(model, cached_file) + elif hasattr(model, 'load_pretrained'): + model.load_pretrained(cached_file) + else: + _logger.warning("Valid function to load pretrained weights is not available, using random initialization.") + + +def adapt_input_conv(in_chans, conv_weight): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + conv_weight = conv_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + # NOTE this strategy should be better than random init, but there could be other combinations of + # the original RGB input layer weights that'd work better for specific cases. + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): + """ Load pretrained checkpoint + + Args: + model (nn.Module) : PyTorch model module + default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset + num_classes (int): num_classes for model + in_chans (int): in_chans for model + filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) + strict (bool): strict load of checkpoint + progress (bool): enable progress bar for weight download + + """ + default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} + pretrained_url = default_cfg.get('url', None) + hf_hub_id = default_cfg.get('hf_hub', None) + if not pretrained_url and not hf_hub_id: + _logger.warning("No pretrained weights exist for this model. Using random initialization.") + return + if hf_hub_id and has_hf_hub(necessary=not pretrained_url): + _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})') + state_dict = load_state_dict_from_hf(hf_hub_id) + else: + _logger.info(f'Loading pretrained weights from url ({pretrained_url})') + state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu') + if filter_fn is not None: + # for backwards compat with filter fn that take one arg, try one first, the two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) + + input_convs = default_cfg.get('first_conv', None) + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs,) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name]) + _logger.info( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)') + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + _logger.warning( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') + + classifiers = default_cfg.get('classifier', None) + label_offset = default_cfg.get('label_offset', 0) + if classifiers is not None: + if isinstance(classifiers, str): + classifiers = (classifiers,) + if num_classes != default_cfg['num_classes']: + for classifier_name in classifiers: + # completely discard fully connected if model num_classes doesn't match pretrained weights + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] + strict = False + elif label_offset > 0: + for classifier_name in classifiers: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + model.load_state_dict(state_dict, strict=strict) + + +def extract_layer(model, layer): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + if not hasattr(model, 'module') and layer[0] == 'module': + layer = layer[1:] + for l in layer: + if hasattr(module, l): + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + else: + return module + return module + + +def set_layer(model, layer, val): + layer = layer.split('.') + module = model + if hasattr(model, 'module') and layer[0] != 'module': + module = model.module + lst_index = 0 + module2 = module + for l in layer: + if hasattr(module2, l): + if not l.isdigit(): + module2 = getattr(module2, l) + else: + module2 = module2[int(l)] + lst_index += 1 + lst_index -= 1 + for l in layer[:lst_index]: + if not l.isdigit(): + module = getattr(module, l) + else: + module = module[int(l)] + l = layer[lst_index] + setattr(module, l, val) + + +def adapt_model_from_string(parent_module, model_string): + separator = '***' + state_dict = {} + lst_shape = model_string.split(separator) + for k in lst_shape: + k = k.split(':') + key = k[0] + shape = k[1][1:-1].split(',') + if shape[0] != '': + state_dict[key] = [int(i) for i in shape] + + new_module = deepcopy(parent_module) + for n, m in parent_module.named_modules(): + old_module = extract_layer(parent_module, n) + if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): + if isinstance(old_module, Conv2dSame): + conv = Conv2dSame + else: + conv = nn.Conv2d + s = state_dict[n + '.weight'] + in_channels = s[1] + out_channels = s[0] + g = 1 + if old_module.groups > 1: + in_channels = out_channels + g = in_channels + new_conv = conv( + in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, + bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, + groups=g, stride=old_module.stride) + set_layer(new_module, n, new_conv) + if isinstance(old_module, nn.BatchNorm2d): + new_bn = nn.BatchNorm2d( + num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + set_layer(new_module, n, new_bn) + if isinstance(old_module, nn.Linear): + # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? + num_features = state_dict[n + '.weight'][1] + new_fc = Linear( + in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) + set_layer(new_module, n, new_fc) + if hasattr(new_module, 'num_features'): + new_module.num_features = num_features + new_module.eval() + parent_module.eval() + + return new_module + + +def adapt_model_from_file(parent_module, model_variant): + adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') + with open(adapt_file, 'r') as f: + return adapt_model_from_string(parent_module, f.read().strip()) + + +def default_cfg_for_features(default_cfg): + default_cfg = deepcopy(default_cfg) + # remove default pretrained cfg fields that don't have much relevance for feature backbone + to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size? + for tr in to_remove: + default_cfg.pop(tr, None) + return default_cfg + + +def overlay_external_default_cfg(default_cfg, kwargs): + """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg. + """ + external_default_cfg = kwargs.pop('external_default_cfg', None) + if external_default_cfg: + default_cfg.pop('url', None) # url should come from external cfg + default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg + default_cfg.update(external_default_cfg) + + +def set_default_kwargs(kwargs, names, default_cfg): + for n in names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # default_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = default_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = default_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = default_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, default_cfg[n]) + + +def filter_kwargs(kwargs, names): + if not kwargs or not names: + return + for n in names: + kwargs.pop(n, None) + + +def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): + """ Update the default_cfg and kwargs before passing to model + + FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs + could/should be replaced by an improved configuration mechanism + + Args: + default_cfg: input default_cfg (updated in-place) + kwargs: keyword args passed to model build fn (updated in-place) + kwargs_filter: keyword arg keys that must be removed before model __init__ + """ + # Overlay default cfg values from `external_default_cfg` if it exists in kwargs + overlay_external_default_cfg(default_cfg, kwargs) + # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if default_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) + # Filter keyword args for task specific model variants (some 'features only' models, etc.) + filter_kwargs(kwargs, names=kwargs_filter) + + +def build_model_with_cfg( + model_cls: Callable, + variant: str, + pretrained: bool, + default_cfg: dict, + model_cfg: Optional[Any] = None, + feature_cfg: Optional[dict] = None, + pretrained_strict: bool = True, + pretrained_filter_fn: Optional[Callable] = None, + pretrained_custom_load: bool = False, + kwargs_filter: Optional[Tuple[str]] = None, + **kwargs): + """ Build model with specified default_cfg and optional model_cfg + + This helper fn aids in the construction of a model including: + * handling default_cfg and associated pretained weight loading + * passing through optional model_cfg for models with config based arch spec + * features_only model adaptation + * pruning config / model adaptation + + Args: + model_cls (nn.Module): model class + variant (str): model variant name + pretrained (bool): load pretrained weights + default_cfg (dict): model's default pretrained/task config + model_cfg (Optional[Dict]): model's architecture config + feature_cfg (Optional[Dict]: feature extraction adapter config + pretrained_strict (bool): load pretrained weights strictly + pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights + pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights + kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + **kwargs: model args passed through to model __init__ + """ + pruned = kwargs.pop('pruned', False) + features = False + feature_cfg = feature_cfg or {} + default_cfg = deepcopy(default_cfg) if default_cfg else {} + update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter) + default_cfg.setdefault('architecture', variant) + + # Setup for feature extraction wrapper done at end of this fn + if kwargs.pop('features_only', False): + features = True + feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) + if 'out_indices' in kwargs: + feature_cfg['out_indices'] = kwargs.pop('out_indices') + + # Build the model + model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) + model.default_cfg = default_cfg + + if pruned: + model = adapt_model_from_file(model, variant) + + # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) + if pretrained: + if pretrained_custom_load: + load_custom_pretrained(model) + else: + load_pretrained( + model, + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict) + + # Wrap the model in a feature extraction module if enabled + if features: + feature_cls = FeatureListNet + if 'feature_cls' in feature_cfg: + feature_cls = feature_cfg.pop('feature_cls') + if isinstance(feature_cls, str): + feature_cls = feature_cls.lower() + if 'hook' in feature_cls: + feature_cls = FeatureHookNet + else: + assert False, f'Unknown feature class {feature_cls}' + model = feature_cls(model, **feature_cfg) + model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg + + return model + + +def model_parameters(model, exclude_head=False): + if exclude_head: + # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering + return [p for p in model.parameters()][:-2] + else: + return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c56964f64feec08f10b02ad368987eecd46db618 --- /dev/null +++ b/timm/models/hrnet.py @@ -0,0 +1,836 @@ +""" HRNet + +Copied from https://github.com/HRNet/HRNet-Image-Classification + +Original header: + Copyright (c) Microsoft + Licensed under the MIT License. + Written by Bin Xiao (Bin.Xiao@microsoft.com) + Modified by Ke Sun (sunk@mail.ustc.edu.cn) +""" +import logging +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .features import FeatureInfo +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import create_classifier +from .registry import register_model +from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE + +_BN_MOMENTUM = 0.1 +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'hrnet_w18_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v1-f460c6bc.pth'), + 'hrnet_w18_small_v2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v2-4c50a8cb.pth'), + 'hrnet_w18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth'), + 'hrnet_w30': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w30-8d7f8dab.pth'), + 'hrnet_w32': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w32-90d8c5fb.pth'), + 'hrnet_w40': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'), + 'hrnet_w44': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'), + 'hrnet_w48': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'), + 'hrnet_w64': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'), +} + +cfg_cls = dict( + hrnet_w18_small=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(1,), + NUM_CHANNELS=(32,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(16, 32), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=1, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(16, 32, 64), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=1, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(16, 32, 64, 128), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18_small_v2=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(2,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=3, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=2, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(2, 2, 2, 2), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w18=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(18, 36), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(18, 36, 72), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(18, 36, 72, 144), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w30=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(30, 60), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(30, 60, 120), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(30, 60, 120, 240), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w32=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(32, 64), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(32, 64, 128), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(32, 64, 128, 256), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w40=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(40, 80), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(40, 80, 160), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(40, 80, 160, 320), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w44=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(44, 88), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(44, 88, 176), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(44, 88, 176, 352), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w48=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(48, 96), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(48, 96, 192), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(48, 96, 192, 384), + FUSE_METHOD='SUM', + ), + ), + + hrnet_w64=dict( + STEM_WIDTH=64, + STAGE1=dict( + NUM_MODULES=1, + NUM_BRANCHES=1, + BLOCK='BOTTLENECK', + NUM_BLOCKS=(4,), + NUM_CHANNELS=(64,), + FUSE_METHOD='SUM', + ), + STAGE2=dict( + NUM_MODULES=1, + NUM_BRANCHES=2, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4), + NUM_CHANNELS=(64, 128), + FUSE_METHOD='SUM' + ), + STAGE3=dict( + NUM_MODULES=4, + NUM_BRANCHES=3, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4), + NUM_CHANNELS=(64, 128, 256), + FUSE_METHOD='SUM' + ), + STAGE4=dict( + NUM_MODULES=3, + NUM_BRANCHES=4, + BLOCK='BASIC', + NUM_BLOCKS=(4, 4, 4, 4), + NUM_CHANNELS=(64, 128, 256, 512), + FUSE_METHOD='SUM', + ), + ) +) + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.fuse_act = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): + error_msg = '' + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks)) + elif num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels)) + elif num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels)) + if error_msg: + _logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): + downsample = None + if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)] + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + for i in range(num_branches): + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return nn.Identity() + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM), + nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) + elif j == i: + fuse_layer.append(nn.Identity()) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x: List[torch.Tensor]): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i, branch in enumerate(self.branches): + x[i] = branch(x[i]) + + x_fuse = [] + for i, fuse_outer in enumerate(self.fuse_layers): + y = x[0] if i == 0 else fuse_outer[0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + fuse_outer[j](x[j]) + x_fuse.append(self.fuse_act(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, head='classification'): + super(HighResolutionNet, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + + stem_width = cfg['STEM_WIDTH'] + self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM) + self.act2 = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + self.stage2_cfg = cfg['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) + + self.head = head + self.head_channels = None # set if _make_head called + if head == 'classification': + # Classification Head + self.num_features = 2048 + self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels) + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + elif head == 'incre': + self.num_features = 2048 + self.incre_modules, _, _ = self._make_head(pre_stage_channels, True) + else: + self.incre_modules = None + self.num_features = 256 + + curr_stride = 2 + # module names aren't actually valid here, hook or FeatureNet based extraction would not work + self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem')] + for i, c in enumerate(self.head_channels if self.head_channels else num_channels): + curr_stride *= 2 + c = c * 4 if self.head_channels else c # head block expansion factor of 4 + self.feature_info += [dict(num_chs=c, reduction=curr_stride, module=f'stage{i + 1}')] + + self.init_weights() + + def _make_head(self, pre_stage_channels, incre_only=False): + head_block = Bottleneck + self.head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_modules.append(self._make_layer(head_block, channels, self.head_channels[i], 1, stride=1)) + incre_modules = nn.ModuleList(incre_modules) + if incre_only: + return incre_modules, None, None + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels) - 1): + in_channels = self.head_channels[i] * head_block.expansion + out_channels = self.head_channels[i + 1] * head_block.expansion + downsamp_module = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1), + nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=self.head_channels[3] * head_block.expansion, + out_channels=self.num_features, kernel_size=1, stride=1, padding=0 + ), + nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(nn.Identity()) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=_BN_MOMENTUM), + ) + + layers = [block(inplanes, planes, stride, downsample)] + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + reset_multi_scale_output = multi_scale_output or i < num_modules - 1 + modules.append(HighResolutionModule( + num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classifier = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def stages(self, x) -> List[torch.Tensor]: + x = self.layer1(x) + + xl = [t(x) for i, t in enumerate(self.transition1)] + yl = self.stage2(xl) + + xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition2)] + yl = self.stage3(xl) + + xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition3)] + yl = self.stage4(xl) + return yl + + def forward_features(self, x): + # Stem + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + # Stages + yl = self.stages(x) + + # Classification Head + y = self.incre_modules[0](yl[0]) + for i, down in enumerate(self.downsamp_modules): + y = self.incre_modules[i + 1](yl[i + 1]) + down(y) + y = self.final_layer(y) + return y + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classifier(x) + return x + + +class HighResolutionNetFeatures(HighResolutionNet): + """HighResolutionNet feature extraction + + The design of HRNet makes it easy to grab feature maps, this class provides a simple wrapper to do so. + It would be more complicated to use the FeatureNet helpers. + + The `feature_location=incre` allows grabbing increased channel count features using part of the + classification head. If `feature_location=''` the default HRNet features are returned. First stem + conv is used for stride 2 features. + """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, + feature_location='incre', out_indices=(0, 1, 2, 3, 4)): + assert feature_location in ('incre', '') + super(HighResolutionNetFeatures, self).__init__( + cfg, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool, + drop_rate=drop_rate, head=feature_location) + self.feature_info = FeatureInfo(self.feature_info, out_indices) + self._out_idx = {i for i in out_indices} + + def forward_features(self, x): + assert False, 'Not supported' + + def forward(self, x) -> List[torch.tensor]: + out = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + if 0 in self._out_idx: + out.append(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + x = self.stages(x) + if self.incre_modules is not None: + x = [incre(f) for f, incre in zip(x, self.incre_modules)] + for i, f in enumerate(x): + if i + 1 in self._out_idx: + out.append(f) + return out + + +def _create_hrnet(variant, pretrained, **model_kwargs): + model_cls = HighResolutionNet + features_only = False + kwargs_filter = None + if model_kwargs.pop('features_only', False): + model_cls = HighResolutionNetFeatures + kwargs_filter = ('num_classes', 'global_pool') + features_only = True + model = build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=cfg_cls[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +@register_model +def hrnet_w18_small(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w18_small', pretrained, **kwargs) + + +@register_model +def hrnet_w18_small_v2(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs) + + +@register_model +def hrnet_w18(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w18', pretrained, **kwargs) + + +@register_model +def hrnet_w30(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w30', pretrained, **kwargs) + + +@register_model +def hrnet_w32(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w32', pretrained, **kwargs) + + +@register_model +def hrnet_w40(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w40', pretrained, **kwargs) + + +@register_model +def hrnet_w44(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w44', pretrained, **kwargs) + + +@register_model +def hrnet_w48(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w48', pretrained, **kwargs) + + +@register_model +def hrnet_w64(pretrained=True, **kwargs): + return _create_hrnet('hrnet_w64', pretrained, **kwargs) diff --git a/timm/models/hub.py b/timm/models/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9b553031fb9d1846338990cd3b6f77228174c6 --- /dev/null +++ b/timm/models/hub.py @@ -0,0 +1,96 @@ +import json +import logging +import os +from functools import partial +from typing import Union, Optional + +import torch +from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX +try: + from torch.hub import get_dir +except ImportError: + from torch.hub import _get_torch_home as get_dir + +from timm import __version__ +try: + from huggingface_hub import hf_hub_url + from huggingface_hub import cached_download + cached_download = partial(cached_download, library_name="timm", library_version=__version__) +except ImportError: + hf_hub_url = None + cached_download = None + +_logger = logging.getLogger(__name__) + + +def get_cache_dir(child_dir=''): + """ + Returns the location of the directory where models are cached (and creates it if necessary). + """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + child_dir = () if not child_dir else (child_dir,) + model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) + os.makedirs(model_dir, exist_ok=True) + return model_dir + + +def download_cached_file(url, check_hash=True, progress=False): + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(get_cache_dir(), filename) + if not os.path.exists(cached_file): + _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + return cached_file + + +def has_hf_hub(necessary=False): + if hf_hub_url is None and necessary: + # if no HF Hub module installed and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return hf_hub_url is not None + + +def hf_split(hf_id): + rev_split = hf_id.split('@') + assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' + hf_model_id = rev_split[0] + hf_revision = rev_split[-1] if len(rev_split) > 1 else None + return hf_model_id, hf_revision + + +def load_cfg_from_json(json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + +def _download_from_hf(model_id: str, filename: str): + hf_model_id, hf_revision = hf_split(model_id) + url = hf_hub_url(hf_model_id, filename, revision=hf_revision) + return cached_download(url, cache_dir=get_cache_dir('hf')) + + +def load_model_config_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'config.json') + default_cfg = load_cfg_from_json(cached_file) + default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation + model_name = default_cfg.get('architecture') + return default_cfg, model_name + + +def load_state_dict_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'pytorch_model.bin') + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..716728495a668ae4a11257d32780059b51b28763 --- /dev/null +++ b/timm/models/inception_resnet_v2.py @@ -0,0 +1,358 @@ +""" Pytorch Inception-Resnet-V2 implementation +Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is +based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['InceptionResnetV2'] + +default_cfgs = { + # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz + 'inception_resnet_v2': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights + }, + # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz + 'ens_adv_inception_resnet_v2': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + 'label_offset': 1, # 1001 classes in pretrained weights + } +} + + +class BasicConv2d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_planes, eps=.001) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed_5b(nn.Module): + def __init__(self): + super(Mixed_5b, self).__init__() + + self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(192, 48, kernel_size=1, stride=1), + BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(192, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(192, 64, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block35(nn.Module): + def __init__(self, scale=1.0): + super(Block35, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(320, 32, kernel_size=1, stride=1), + BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1), + BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1) + ) + + self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_6a(nn.Module): + def __init__(self): + super(Mixed_6a, self).__init__() + + self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(320, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class Block17(nn.Module): + def __init__(self, scale=1.0): + super(Block17, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 128, kernel_size=1, stride=1), + BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)) + ) + + self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1) + self.relu = nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + out = self.relu(out) + return out + + +class Mixed_7a(nn.Module): + def __init__(self): + super(Mixed_7a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 384, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=2) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1088, 256, kernel_size=1, stride=1), + BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1), + BasicConv2d(288, 320, kernel_size=3, stride=2) + ) + + self.branch3 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class Block8(nn.Module): + + def __init__(self, scale=1.0, no_relu=False): + super(Block8, self).__init__() + + self.scale = scale + + self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(2080, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)), + BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + ) + + self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1) + self.relu = None if no_relu else nn.ReLU(inplace=False) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + out = self.conv2d(out) + out = out * self.scale + x + if self.relu is not None: + out = self.relu(out) + return out + + +class InceptionResnetV2(nn.Module): + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'): + super(InceptionResnetV2, self).__init__() + self.drop_rate = drop_rate + self.num_classes = num_classes + self.num_features = 1536 + assert output_stride == 32 + + self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) + self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) + self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')] + + self.maxpool_3a = nn.MaxPool2d(3, stride=2) + self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) + self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) + self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')] + + self.maxpool_5a = nn.MaxPool2d(3, stride=2) + self.mixed_5b = Mixed_5b() + self.repeat = nn.Sequential( + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17), + Block35(scale=0.17) + ) + self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')] + + self.mixed_6a = Mixed_6a() + self.repeat_1 = nn.Sequential( + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10), + Block17(scale=0.10) + ) + self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')] + + self.mixed_7a = Mixed_7a() + self.repeat_2 = nn.Sequential( + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20), + Block8(scale=0.20) + ) + self.block8 = Block8(no_relu=True) + self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) + self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')] + + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.classif + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.maxpool_3a(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.maxpool_5a(x) + x = self.mixed_5b(x) + x = self.repeat(x) + x = self.mixed_6a(x) + x = self.repeat_1(x) + x = self.mixed_7a(x) + x = self.repeat_2(x) + x = self.block8(x) + x = self.conv2d_7b(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.classif(x) + return x + + +def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + InceptionResnetV2, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def inception_resnet_v2(pretrained=False, **kwargs): + r"""InceptionResnetV2 model architecture from the + `"InceptionV4, Inception-ResNet..." ` paper. + """ + return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs) + + +@register_model +def ens_adv_inception_resnet_v2(pretrained=False, **kwargs): + r""" Ensemble Adversarially trained InceptionResnetV2 model architecture + As per https://arxiv.org/abs/1705.07204 and + https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models. + """ + return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs) diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb1107b39b18418769f7cf775490cec4e95bb5b --- /dev/null +++ b/timm/models/inception_v3.py @@ -0,0 +1,470 @@ +""" Inception-V3 + +Originally from torchvision Inception3 model +Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .layers import trunc_normal_, create_classifier, Linear + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + # original PyTorch weights, ported from Tensorflow but modified + 'inception_v3': _cfg( + url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', + has_aux=True), # checkpoint has aux logit layer weights + # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) + 'tf_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', + num_classes=1000, has_aux=False, label_offset=1), + # my port of Tensorflow adversarially trained Inception V3 from + # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz + 'adv_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', + num_classes=1000, has_aux=False, label_offset=1), + # from gluon pretrained models, best performing in terms of accuracy/loss metrics + # https://gluon-cv.mxnet.io/model_zoo/classification.html + 'gluon_inception_v3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth', + mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults + std=IMAGENET_DEFAULT_STD, # also works well with inception defaults + has_aux=False, + ) +} + + +class InceptionA(nn.Module): + + def __init__(self, in_channels, pool_features, conv_block=None): + super(InceptionA, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionB(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionB, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3(x) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionC(nn.Module): + + def __init__(self, in_channels, channels_7x7, conv_block=None): + super(InceptionC, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionD(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionD, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionE(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionE, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3), + ] + branch3x3 = torch.cat(branch3x3, 1) + + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [ + self.branch3x3dbl_3a(branch3x3dbl), + self.branch3x3dbl_3b(branch3x3dbl), + ] + branch3x3dbl = torch.cat(branch3x3dbl, 1) + + branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def forward(self, x): + outputs = self._forward(x) + return torch.cat(outputs, 1) + + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if conv_block is None: + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = Linear(768, num_classes) + self.fc.stddev = 0.001 + + def forward(self, x): + # N x 768 x 17 x 17 + x = F.avg_pool2d(x, kernel_size=5, stride=3) + # N x 768 x 5 x 5 + x = self.conv0(x) + # N x 128 x 5 x 5 + x = self.conv1(x) + # N x 768 x 1 x 1 + # Adaptive average pooling + x = F.adaptive_avg_pool2d(x, (1, 1)) + # N x 768 x 1 x 1 + x = torch.flatten(x, 1) + # N x 768 + x = self.fc(x) + # N x 1000 + return x + + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) + + +class InceptionV3(nn.Module): + """Inception-V3 with no AuxLogits + FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False): + super(InceptionV3, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.aux_logits = aux_logits + + self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) + self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2) + self.Mixed_5b = InceptionA(192, pool_features=32) + self.Mixed_5c = InceptionA(256, pool_features=64) + self.Mixed_5d = InceptionA(288, pool_features=64) + self.Mixed_6a = InceptionB(288) + self.Mixed_6b = InceptionC(768, channels_7x7=128) + self.Mixed_6c = InceptionC(768, channels_7x7=160) + self.Mixed_6d = InceptionC(768, channels_7x7=160) + self.Mixed_6e = InceptionC(768, channels_7x7=192) + if aux_logits: + self.AuxLogits = InceptionAux(768, num_classes) + else: + self.AuxLogits = None + self.Mixed_7a = InceptionD(768) + self.Mixed_7b = InceptionE(1280) + self.Mixed_7c = InceptionE(2048) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'), + dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'), + dict(num_chs=288, reduction=8, module='Mixed_5d'), + dict(num_chs=768, reduction=16, module='Mixed_6e'), + dict(num_chs=2048, reduction=32, module='Mixed_7c'), + ] + + self.num_features = 2048 + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + stddev = m.stddev if hasattr(m, 'stddev') else 0.1 + trunc_normal_(m.weight, std=stddev) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward_preaux(self, x): + # N x 3 x 299 x 299 + x = self.Conv2d_1a_3x3(x) + # N x 32 x 149 x 149 + x = self.Conv2d_2a_3x3(x) + # N x 32 x 147 x 147 + x = self.Conv2d_2b_3x3(x) + # N x 64 x 147 x 147 + x = self.Pool1(x) + # N x 64 x 73 x 73 + x = self.Conv2d_3b_1x1(x) + # N x 80 x 73 x 73 + x = self.Conv2d_4a_3x3(x) + # N x 192 x 71 x 71 + x = self.Pool2(x) + # N x 192 x 35 x 35 + x = self.Mixed_5b(x) + # N x 256 x 35 x 35 + x = self.Mixed_5c(x) + # N x 288 x 35 x 35 + x = self.Mixed_5d(x) + # N x 288 x 35 x 35 + x = self.Mixed_6a(x) + # N x 768 x 17 x 17 + x = self.Mixed_6b(x) + # N x 768 x 17 x 17 + x = self.Mixed_6c(x) + # N x 768 x 17 x 17 + x = self.Mixed_6d(x) + # N x 768 x 17 x 17 + x = self.Mixed_6e(x) + # N x 768 x 17 x 17 + return x + + def forward_postaux(self, x): + x = self.Mixed_7a(x) + # N x 1280 x 8 x 8 + x = self.Mixed_7b(x) + # N x 2048 x 8 x 8 + x = self.Mixed_7c(x) + # N x 2048 x 8 x 8 + return x + + def forward_features(self, x): + x = self.forward_preaux(x) + x = self.forward_postaux(x) + return x + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +class InceptionV3Aux(InceptionV3): + """InceptionV3 with AuxLogits + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True): + super(InceptionV3Aux, self).__init__( + num_classes, in_chans, drop_rate, global_pool, aux_logits) + + def forward_features(self, x): + x = self.forward_preaux(x) + aux = self.AuxLogits(x) if self.training else None + x = self.forward_postaux(x) + return x, aux + + def forward(self, x): + x, aux = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x, aux + + +def _create_inception_v3(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + aux_logits = kwargs.pop('aux_logits', False) + if aux_logits: + assert not kwargs.pop('features_only', False) + model_cls = InceptionV3Aux + load_strict = default_cfg['has_aux'] + else: + model_cls = InceptionV3 + load_strict = not default_cfg['has_aux'] + return build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfg, + pretrained_strict=load_strict, + **kwargs) + + +@register_model +def inception_v3(pretrained=False, **kwargs): + # original PyTorch weights, ported from Tensorflow but modified + model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_inception_v3(pretrained=False, **kwargs): + # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) + model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def adv_inception_v3(pretrained=False, **kwargs): + # my port of Tensorflow adversarially trained Inception V3 from + # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz + model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs) + return model + + +@register_model +def gluon_inception_v3(pretrained=False, **kwargs): + # from gluon pretrained models, best performing in terms of accuracy/loss metrics + # https://gluon-cv.mxnet.io/model_zoo/classification.html + model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py new file mode 100644 index 0000000000000000000000000000000000000000..cc899e15daf8087ae6acb17017079c292a1e3aa7 --- /dev/null +++ b/timm/models/inception_v4.py @@ -0,0 +1,316 @@ +""" Pytorch Inception-V4 implementation +Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is +based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License) +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['InceptionV4'] + +default_cfgs = { + 'inception_v4': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth', + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'features.0.conv', 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights + } +} + + +class BasicConv2d(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d( + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_planes, eps=0.001) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.relu(x) + return x + + +class Mixed3a(nn.Module): + def __init__(self): + super(Mixed3a, self).__init__() + self.maxpool = nn.MaxPool2d(3, stride=2) + self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2) + + def forward(self, x): + x0 = self.maxpool(x) + x1 = self.conv(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed4a(nn.Module): + def __init__(self): + super(Mixed4a, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(160, 64, kernel_size=1, stride=1), + BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(64, 96, kernel_size=(3, 3), stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + out = torch.cat((x0, x1), 1) + return out + + +class Mixed5a(nn.Module): + def __init__(self): + super(Mixed5a, self).__init__() + self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2) + self.maxpool = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.conv(x) + x1 = self.maxpool(x) + out = torch.cat((x0, x1), 1) + return out + + +class InceptionA(nn.Module): + def __init__(self): + super(InceptionA, self).__init__() + self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(384, 64, kernel_size=1, stride=1), + BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(384, 96, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class ReductionA(nn.Module): + def __init__(self): + super(ReductionA, self).__init__() + self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2) + + self.branch1 = nn.Sequential( + BasicConv2d(384, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1), + BasicConv2d(224, 256, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class InceptionB(nn.Module): + def __init__(self): + super(InceptionB, self).__init__() + self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)) + ) + + self.branch2 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)) + ) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1024, 128, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + x3 = self.branch3(x) + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class ReductionB(nn.Module): + def __init__(self): + super(ReductionB, self).__init__() + + self.branch0 = nn.Sequential( + BasicConv2d(1024, 192, kernel_size=1, stride=1), + BasicConv2d(192, 192, kernel_size=3, stride=2) + ) + + self.branch1 = nn.Sequential( + BasicConv2d(1024, 256, kernel_size=1, stride=1), + BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(320, 320, kernel_size=3, stride=2) + ) + + self.branch2 = nn.MaxPool2d(3, stride=2) + + def forward(self, x): + x0 = self.branch0(x) + x1 = self.branch1(x) + x2 = self.branch2(x) + out = torch.cat((x0, x1, x2), 1) + return out + + +class InceptionC(nn.Module): + def __init__(self): + super(InceptionC, self).__init__() + + self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) + + self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + + self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) + self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)) + self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) + + self.branch3 = nn.Sequential( + nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), + BasicConv2d(1536, 256, kernel_size=1, stride=1) + ) + + def forward(self, x): + x0 = self.branch0(x) + + x1_0 = self.branch1_0(x) + x1_1a = self.branch1_1a(x1_0) + x1_1b = self.branch1_1b(x1_0) + x1 = torch.cat((x1_1a, x1_1b), 1) + + x2_0 = self.branch2_0(x) + x2_1 = self.branch2_1(x2_0) + x2_2 = self.branch2_2(x2_1) + x2_3a = self.branch2_3a(x2_2) + x2_3b = self.branch2_3b(x2_2) + x2 = torch.cat((x2_3a, x2_3b), 1) + + x3 = self.branch3(x) + + out = torch.cat((x0, x1, x2, x3), 1) + return out + + +class InceptionV4(nn.Module): + def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'): + super(InceptionV4, self).__init__() + assert output_stride == 32 + self.drop_rate = drop_rate + self.num_classes = num_classes + self.num_features = 1536 + + self.features = nn.Sequential( + BasicConv2d(in_chans, 32, kernel_size=3, stride=2), + BasicConv2d(32, 32, kernel_size=3, stride=1), + BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), + Mixed3a(), + Mixed4a(), + Mixed5a(), + InceptionA(), + InceptionA(), + InceptionA(), + InceptionA(), + ReductionA(), # Mixed6a + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + InceptionB(), + ReductionB(), # Mixed7a + InceptionC(), + InceptionC(), + InceptionC(), + ) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='features.2'), + dict(num_chs=160, reduction=4, module='features.3'), + dict(num_chs=384, reduction=8, module='features.9'), + dict(num_chs=1024, reduction=16, module='features.17'), + dict(num_chs=1536, reduction=32, module='features.21'), + ] + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + return self.features(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + +def _create_inception_v4(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + InceptionV4, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def inception_v4(pretrained=False, **kwargs): + return _create_inception_v4('inception_v4', pretrained, **kwargs) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77d1026e8c88c466250df4fe8e2bbcc9e2b42290 --- /dev/null +++ b/timm/models/layers/__init__.py @@ -0,0 +1,40 @@ +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .blur_pool import BlurPool2d +from .classifier import ClassifierHead, create_classifier +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ + set_layer_config +from .conv2d_same import Conv2dSame, conv2d_same +from .conv_bn_act import ConvBnAct +from .create_act import create_act_layer, get_act_layer, get_act_fn +from .create_attn import get_attn, create_attn +from .create_conv2d import create_conv2d +from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act +from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn +from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible +from .inplace_abn import InplaceAbn +from .involution import Involution +from .linear import Linear +from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp, GatedMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .norm import GroupNorm, LayerNorm2d +from .norm_act import BatchNormAct2d, GroupNormAct +from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed +from .pool2d_same import AvgPool2dSame, create_pool2d +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite +from .selective_kernel import SelectiveKernel +from .separable_conv import SeparableConv2d, SeparableConvBnAct +from .space_to_depth import SpaceToDepthModule +from .split_attn import SplitAttn +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_ diff --git a/timm/models/layers/__pycache__/__init__.cpython-310.pyc b/timm/models/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6b24f6b3f435474a934b86ba4f19498b02ff1eb Binary files /dev/null and b/timm/models/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/__init__.cpython-311.pyc b/timm/models/layers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30dec1dbe25fd138ed86fba58dcd084734ca159c Binary files /dev/null and b/timm/models/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/__init__.cpython-313.pyc b/timm/models/layers/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..509d63ca79c11f0c662a06e1917c70c0d3577ec2 Binary files /dev/null and b/timm/models/layers/__pycache__/__init__.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/__init__.cpython-38.pyc b/timm/models/layers/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cfd2bf4aff3f6c9686bec5b23bb8fdae7774679 Binary files /dev/null and b/timm/models/layers/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/activations.cpython-310.pyc b/timm/models/layers/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed612053c6419706a03648c80476e788d1a039b7 Binary files /dev/null and b/timm/models/layers/__pycache__/activations.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/activations.cpython-311.pyc b/timm/models/layers/__pycache__/activations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9033000e10a21b529f931c40dec079cd2268b52f Binary files /dev/null and b/timm/models/layers/__pycache__/activations.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/activations.cpython-313.pyc b/timm/models/layers/__pycache__/activations.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..853fffc3f0a74c7384806fca72a8737e21af450b Binary files /dev/null and b/timm/models/layers/__pycache__/activations.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/activations.cpython-38.pyc b/timm/models/layers/__pycache__/activations.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c61bfa67235fb3f3f32874992c67381a85173a4 Binary files /dev/null and b/timm/models/layers/__pycache__/activations.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/activations_jit.cpython-310.pyc b/timm/models/layers/__pycache__/activations_jit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fecf1ab194986349f4926f2f7eaa10fbcc35f76 Binary files /dev/null and b/timm/models/layers/__pycache__/activations_jit.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/activations_jit.cpython-311.pyc b/timm/models/layers/__pycache__/activations_jit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5aca932dac841cd439894107acd7725fa13d2ec Binary files /dev/null and b/timm/models/layers/__pycache__/activations_jit.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/activations_jit.cpython-313.pyc b/timm/models/layers/__pycache__/activations_jit.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31419b093fac27bbea565d868a89e3a089933ca6 Binary files /dev/null and b/timm/models/layers/__pycache__/activations_jit.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/activations_jit.cpython-38.pyc b/timm/models/layers/__pycache__/activations_jit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..044fdd3debe951677d3ba18917173ad36a7e76ac Binary files /dev/null and b/timm/models/layers/__pycache__/activations_jit.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/activations_me.cpython-310.pyc b/timm/models/layers/__pycache__/activations_me.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5414b4775748c881002d0beb946f3f1c3d9c05c Binary files /dev/null and b/timm/models/layers/__pycache__/activations_me.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/activations_me.cpython-311.pyc b/timm/models/layers/__pycache__/activations_me.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..820270a73ab39990d9c5198eeca0d397b083d521 Binary files /dev/null and b/timm/models/layers/__pycache__/activations_me.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/activations_me.cpython-313.pyc b/timm/models/layers/__pycache__/activations_me.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e68402d33aff707f6300455a10189266b84c66af Binary files /dev/null and b/timm/models/layers/__pycache__/activations_me.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/activations_me.cpython-38.pyc b/timm/models/layers/__pycache__/activations_me.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a91b61f35812c4d2da88edfd0d43a38361399d Binary files /dev/null and b/timm/models/layers/__pycache__/activations_me.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-310.pyc b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..163d2a528d37c5f765fcd9e43d0ecad63a1f9275 Binary files /dev/null and b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-311.pyc b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..995605817b2280c1763d2824ee7496d6b879c01f Binary files /dev/null and b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-313.pyc b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ad0d074e15c80131276659865f14c075261fc07 Binary files /dev/null and b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..384d2b3ffe35788a57b65718c2c1d41edd5a39b4 Binary files /dev/null and b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/attention_pool2d.cpython-38.pyc b/timm/models/layers/__pycache__/attention_pool2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a41b93ccb9061f6c635871ec8fc5e5adb5c7560e Binary files /dev/null and b/timm/models/layers/__pycache__/attention_pool2d.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/blur_pool.cpython-310.pyc b/timm/models/layers/__pycache__/blur_pool.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..494b4ea8dee9dcc0053a6982168bb1ba527a57fd Binary files /dev/null and b/timm/models/layers/__pycache__/blur_pool.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/blur_pool.cpython-311.pyc b/timm/models/layers/__pycache__/blur_pool.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa4a1b92333c5310a59d688f323a04c9820472fb Binary files /dev/null and b/timm/models/layers/__pycache__/blur_pool.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/blur_pool.cpython-313.pyc b/timm/models/layers/__pycache__/blur_pool.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29a43fc02a01c70944b6809042c28e47e8732d4f Binary files /dev/null and b/timm/models/layers/__pycache__/blur_pool.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/blur_pool.cpython-38.pyc b/timm/models/layers/__pycache__/blur_pool.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2c28f5be8c002d05ce62cf8f3430db47b1cd6cd Binary files /dev/null and b/timm/models/layers/__pycache__/blur_pool.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/bottleneck_attn.cpython-310.pyc b/timm/models/layers/__pycache__/bottleneck_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5019f4bb509959454f6f680531367130d6d1b59a Binary files /dev/null and b/timm/models/layers/__pycache__/bottleneck_attn.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/bottleneck_attn.cpython-311.pyc b/timm/models/layers/__pycache__/bottleneck_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f289f21c08965a24a559466942c472f4d488967d Binary files /dev/null and b/timm/models/layers/__pycache__/bottleneck_attn.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/bottleneck_attn.cpython-313.pyc b/timm/models/layers/__pycache__/bottleneck_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa37f88b30a519c21fa9cba8a62b61aa62b7d6f8 Binary files /dev/null and b/timm/models/layers/__pycache__/bottleneck_attn.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/bottleneck_attn.cpython-38.pyc b/timm/models/layers/__pycache__/bottleneck_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ca2f9a5a3851741c1fb0d7d783f65aed6d60dae Binary files /dev/null and b/timm/models/layers/__pycache__/bottleneck_attn.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/cbam.cpython-310.pyc b/timm/models/layers/__pycache__/cbam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bb73942be11c5139e2eb5949a7ccf52c27e18f4 Binary files /dev/null and b/timm/models/layers/__pycache__/cbam.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/cbam.cpython-311.pyc b/timm/models/layers/__pycache__/cbam.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ad4e31c7fbd3b5405c0987984f9828b279bb672 Binary files /dev/null and b/timm/models/layers/__pycache__/cbam.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/cbam.cpython-313.pyc b/timm/models/layers/__pycache__/cbam.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19603cb4e2e9602009ddabd3f6276f2c7d74f938 Binary files /dev/null and b/timm/models/layers/__pycache__/cbam.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/cbam.cpython-38.pyc b/timm/models/layers/__pycache__/cbam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da2d2d4e9b46f505d792fa67b52c41199b64a35b Binary files /dev/null and b/timm/models/layers/__pycache__/cbam.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/classifier.cpython-310.pyc b/timm/models/layers/__pycache__/classifier.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3f34aead7d18caea8ca97b6775750aff1289ab8 Binary files /dev/null and b/timm/models/layers/__pycache__/classifier.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/classifier.cpython-311.pyc b/timm/models/layers/__pycache__/classifier.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ce3938dde5f6b196c0f794454e9c3564ea3398e Binary files /dev/null and b/timm/models/layers/__pycache__/classifier.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/classifier.cpython-313.pyc b/timm/models/layers/__pycache__/classifier.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca07dd474495b6552bd0e4705d6f8eead20b5972 Binary files /dev/null and b/timm/models/layers/__pycache__/classifier.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/classifier.cpython-38.pyc b/timm/models/layers/__pycache__/classifier.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9bbebbb6df5282b3ac58d7a98656a56d3c023fc1 Binary files /dev/null and b/timm/models/layers/__pycache__/classifier.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/cond_conv2d.cpython-310.pyc b/timm/models/layers/__pycache__/cond_conv2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f6bd2a9d398f48ddc2b0810c7f0d59751e5d705 Binary files /dev/null and b/timm/models/layers/__pycache__/cond_conv2d.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/cond_conv2d.cpython-311.pyc b/timm/models/layers/__pycache__/cond_conv2d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eee9290f638dcd4636b5250419f9b1b3eb7a782 Binary files /dev/null and b/timm/models/layers/__pycache__/cond_conv2d.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/cond_conv2d.cpython-313.pyc b/timm/models/layers/__pycache__/cond_conv2d.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe905c24f8c57cdc278f9f32d883fc53a0a34adf Binary files /dev/null and b/timm/models/layers/__pycache__/cond_conv2d.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/cond_conv2d.cpython-38.pyc b/timm/models/layers/__pycache__/cond_conv2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02ec95fc74afe0354b7a4c561ef9655721324833 Binary files /dev/null and b/timm/models/layers/__pycache__/cond_conv2d.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/config.cpython-310.pyc b/timm/models/layers/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0a3ed37cde5ec9f6572b4f9026c7f5a202e723c Binary files /dev/null and b/timm/models/layers/__pycache__/config.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/config.cpython-311.pyc b/timm/models/layers/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1adc9c69672b788f96f0bdf95011102b8fd0dce Binary files /dev/null and b/timm/models/layers/__pycache__/config.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/config.cpython-313.pyc b/timm/models/layers/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05a8af9e83f3db84a316e2d6c075fc608bbfb37b Binary files /dev/null and b/timm/models/layers/__pycache__/config.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/config.cpython-38.pyc b/timm/models/layers/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5969a3681a4a40b92ab2cf18a4a33eb5b6e6d78a Binary files /dev/null and b/timm/models/layers/__pycache__/config.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/conv2d_same.cpython-310.pyc b/timm/models/layers/__pycache__/conv2d_same.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cec9674f4c921837abf72e17e3b1fcd17e6447c Binary files /dev/null and b/timm/models/layers/__pycache__/conv2d_same.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/conv2d_same.cpython-311.pyc b/timm/models/layers/__pycache__/conv2d_same.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61b34b2b81145abb06f1ba37fc8d643f12b44377 Binary files /dev/null and b/timm/models/layers/__pycache__/conv2d_same.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/conv2d_same.cpython-313.pyc b/timm/models/layers/__pycache__/conv2d_same.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcb7e2a309a59e2388f930bafbb31038959266e2 Binary files /dev/null and b/timm/models/layers/__pycache__/conv2d_same.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/conv2d_same.cpython-38.pyc b/timm/models/layers/__pycache__/conv2d_same.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d679e73f081f4c68db3e5f2b7646205ed739e9cd Binary files /dev/null and b/timm/models/layers/__pycache__/conv2d_same.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/conv_bn_act.cpython-310.pyc b/timm/models/layers/__pycache__/conv_bn_act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..108e308e3fe9978f66733dd5c157c98a33ecb49f Binary files /dev/null and b/timm/models/layers/__pycache__/conv_bn_act.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/conv_bn_act.cpython-311.pyc b/timm/models/layers/__pycache__/conv_bn_act.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a76c7b4ef31e69594a2c591293dcb01627dc7b7b Binary files /dev/null and b/timm/models/layers/__pycache__/conv_bn_act.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/conv_bn_act.cpython-313.pyc b/timm/models/layers/__pycache__/conv_bn_act.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eedcb448ddbf143682ba2f534eab065800f5d975 Binary files /dev/null and b/timm/models/layers/__pycache__/conv_bn_act.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/conv_bn_act.cpython-38.pyc b/timm/models/layers/__pycache__/conv_bn_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc713cae96749cbd836f7fc83ba0a3bdf01b6b29 Binary files /dev/null and b/timm/models/layers/__pycache__/conv_bn_act.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/create_act.cpython-310.pyc b/timm/models/layers/__pycache__/create_act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5f49d29f01a502855a2e869fe42d46704d6aa54 Binary files /dev/null and b/timm/models/layers/__pycache__/create_act.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/create_act.cpython-311.pyc b/timm/models/layers/__pycache__/create_act.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5045efdb5278278e113517eed2bf790ff72d03ff Binary files /dev/null and b/timm/models/layers/__pycache__/create_act.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/create_act.cpython-313.pyc b/timm/models/layers/__pycache__/create_act.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f60edb19890d63fb4a6d6f86fabbb47d306b06e9 Binary files /dev/null and b/timm/models/layers/__pycache__/create_act.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/create_act.cpython-38.pyc b/timm/models/layers/__pycache__/create_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68adb64a84683d7901bdf187fecbd4575622e6f0 Binary files /dev/null and b/timm/models/layers/__pycache__/create_act.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/create_attn.cpython-310.pyc b/timm/models/layers/__pycache__/create_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55784b80031b73d7e114be9c4d19135b58e5c6be Binary files /dev/null and b/timm/models/layers/__pycache__/create_attn.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/create_attn.cpython-311.pyc b/timm/models/layers/__pycache__/create_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1a8b70f9b8a0df2f06b45e91ea5b15bb0c29883 Binary files /dev/null and b/timm/models/layers/__pycache__/create_attn.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/create_attn.cpython-313.pyc b/timm/models/layers/__pycache__/create_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a418aaf6a1ab1f28c0f89fcf3d23d7b5f3f3d67 Binary files /dev/null and b/timm/models/layers/__pycache__/create_attn.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/create_attn.cpython-38.pyc b/timm/models/layers/__pycache__/create_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aea2c0586a81784ed5259177fbdbcfb28cc97759 Binary files /dev/null and b/timm/models/layers/__pycache__/create_attn.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/create_conv2d.cpython-310.pyc b/timm/models/layers/__pycache__/create_conv2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fae7c649a7bbd64c70031a20c80d3eb27480535b Binary files /dev/null and b/timm/models/layers/__pycache__/create_conv2d.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/create_conv2d.cpython-311.pyc b/timm/models/layers/__pycache__/create_conv2d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21a40d515287f2b9c1d46529e4bbcfb484f045ba Binary files /dev/null and b/timm/models/layers/__pycache__/create_conv2d.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/create_conv2d.cpython-313.pyc b/timm/models/layers/__pycache__/create_conv2d.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6519ac58973fea5415af510710598bb45e882b09 Binary files /dev/null and b/timm/models/layers/__pycache__/create_conv2d.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/create_conv2d.cpython-38.pyc b/timm/models/layers/__pycache__/create_conv2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..868425070fe7982ecf5939dfe75fbff17a8cdd04 Binary files /dev/null and b/timm/models/layers/__pycache__/create_conv2d.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/create_norm_act.cpython-310.pyc b/timm/models/layers/__pycache__/create_norm_act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c645999d5e7975df750a141a448b68a151cf44a4 Binary files /dev/null and b/timm/models/layers/__pycache__/create_norm_act.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/create_norm_act.cpython-311.pyc b/timm/models/layers/__pycache__/create_norm_act.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e77cf81f2414338093541b8ab87cdbe2a0ad1bd Binary files /dev/null and b/timm/models/layers/__pycache__/create_norm_act.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/create_norm_act.cpython-313.pyc b/timm/models/layers/__pycache__/create_norm_act.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4abc92130a34c45a1ed19c4441437b2626da24ee Binary files /dev/null and b/timm/models/layers/__pycache__/create_norm_act.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/create_norm_act.cpython-38.pyc b/timm/models/layers/__pycache__/create_norm_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56dd76b6af80599e2d553ae5f5489d6a7f185a35 Binary files /dev/null and b/timm/models/layers/__pycache__/create_norm_act.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/drop.cpython-310.pyc b/timm/models/layers/__pycache__/drop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1196e5a9a0508e3b4e6cc72131f6402e415379e Binary files /dev/null and b/timm/models/layers/__pycache__/drop.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/drop.cpython-311.pyc b/timm/models/layers/__pycache__/drop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18d24aa288efb3e6bd6161ed699add313db1834d Binary files /dev/null and b/timm/models/layers/__pycache__/drop.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/drop.cpython-313.pyc b/timm/models/layers/__pycache__/drop.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c9ab39cc89e2c2799cc4fb91d2bc27f607107c2 Binary files /dev/null and b/timm/models/layers/__pycache__/drop.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/drop.cpython-38.pyc b/timm/models/layers/__pycache__/drop.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11b76a55bbe3d642f1ed356beaf3ee9b608ff334 Binary files /dev/null and b/timm/models/layers/__pycache__/drop.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/eca.cpython-310.pyc b/timm/models/layers/__pycache__/eca.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48bdcf318f98d95b706b893d16fa9307eeafabbf Binary files /dev/null and b/timm/models/layers/__pycache__/eca.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/eca.cpython-311.pyc b/timm/models/layers/__pycache__/eca.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb596ac8f81209c70b455abddabf8c383894663c Binary files /dev/null and b/timm/models/layers/__pycache__/eca.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/eca.cpython-313.pyc b/timm/models/layers/__pycache__/eca.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f6ceeafccd5d963c2a23d7111a678d54a7f3053 Binary files /dev/null and b/timm/models/layers/__pycache__/eca.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/eca.cpython-38.pyc b/timm/models/layers/__pycache__/eca.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51c89247e1d0c72fe627658d0d1cd0916bd10b89 Binary files /dev/null and b/timm/models/layers/__pycache__/eca.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/evo_norm.cpython-310.pyc b/timm/models/layers/__pycache__/evo_norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1711e3a7046278c9a6046139110256ba39ad75d Binary files /dev/null and b/timm/models/layers/__pycache__/evo_norm.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/evo_norm.cpython-311.pyc b/timm/models/layers/__pycache__/evo_norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff7a4252a72f93b7223161a641d3df7cb59948f0 Binary files /dev/null and b/timm/models/layers/__pycache__/evo_norm.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/evo_norm.cpython-313.pyc b/timm/models/layers/__pycache__/evo_norm.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f48da5818315a64e03524aff282b00d2b789ae82 Binary files /dev/null and b/timm/models/layers/__pycache__/evo_norm.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/evo_norm.cpython-38.pyc b/timm/models/layers/__pycache__/evo_norm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..691993bc5df2abc4419f9057dc095b0cdfd3eb73 Binary files /dev/null and b/timm/models/layers/__pycache__/evo_norm.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/gather_excite.cpython-310.pyc b/timm/models/layers/__pycache__/gather_excite.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a58f0116283dd9d0a36665007b529f14223a2088 Binary files /dev/null and b/timm/models/layers/__pycache__/gather_excite.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/gather_excite.cpython-311.pyc b/timm/models/layers/__pycache__/gather_excite.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea94d84d55b392885aac85db9ede2978cc274f47 Binary files /dev/null and b/timm/models/layers/__pycache__/gather_excite.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/gather_excite.cpython-313.pyc b/timm/models/layers/__pycache__/gather_excite.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..115d38cd793fb66cb41f5bd2cf610e369300cdcf Binary files /dev/null and b/timm/models/layers/__pycache__/gather_excite.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/gather_excite.cpython-38.pyc b/timm/models/layers/__pycache__/gather_excite.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71a5f49c8af30cd746fae6a336984270034b46e9 Binary files /dev/null and b/timm/models/layers/__pycache__/gather_excite.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/global_context.cpython-310.pyc b/timm/models/layers/__pycache__/global_context.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54b80fa26e9540e017d103f4edf990a25676a858 Binary files /dev/null and b/timm/models/layers/__pycache__/global_context.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/global_context.cpython-311.pyc b/timm/models/layers/__pycache__/global_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..161d7a97e83c91727e0f17607cb233d3a8176545 Binary files /dev/null and b/timm/models/layers/__pycache__/global_context.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/global_context.cpython-313.pyc b/timm/models/layers/__pycache__/global_context.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55defd1bba653c45c6160cb1db62dd3fa4329300 Binary files /dev/null and b/timm/models/layers/__pycache__/global_context.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/global_context.cpython-38.pyc b/timm/models/layers/__pycache__/global_context.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bd7336a7bdde03e84292d09951b3e9539b6d725 Binary files /dev/null and b/timm/models/layers/__pycache__/global_context.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/halo_attn.cpython-310.pyc b/timm/models/layers/__pycache__/halo_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36ff02c8632045723281989f77d8e34f388557ea Binary files /dev/null and b/timm/models/layers/__pycache__/halo_attn.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/halo_attn.cpython-311.pyc b/timm/models/layers/__pycache__/halo_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..068b8e33e74828edb3b2a07e33bf1d4d35e517eb Binary files /dev/null and b/timm/models/layers/__pycache__/halo_attn.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/halo_attn.cpython-313.pyc b/timm/models/layers/__pycache__/halo_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..563d46c77db0110544ea2d3f9328250f05e5fb67 Binary files /dev/null and b/timm/models/layers/__pycache__/halo_attn.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/halo_attn.cpython-38.pyc b/timm/models/layers/__pycache__/halo_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a03b93ea52be5c79d4d185aea7bee6e9371a3f9 Binary files /dev/null and b/timm/models/layers/__pycache__/halo_attn.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/helpers.cpython-310.pyc b/timm/models/layers/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f6aadb7cd48df1c5e159d891c155b4b649e42ca Binary files /dev/null and b/timm/models/layers/__pycache__/helpers.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/helpers.cpython-311.pyc b/timm/models/layers/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e79cb0910934c595dad50437083c1b053a024589 Binary files /dev/null and b/timm/models/layers/__pycache__/helpers.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/helpers.cpython-313.pyc b/timm/models/layers/__pycache__/helpers.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56bff7f4a0b904032a2b801a942822b63cfcb6c9 Binary files /dev/null and b/timm/models/layers/__pycache__/helpers.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/helpers.cpython-38.pyc b/timm/models/layers/__pycache__/helpers.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99e2ae30bad1205668ee75edcff579169e2d043d Binary files /dev/null and b/timm/models/layers/__pycache__/helpers.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/inplace_abn.cpython-310.pyc b/timm/models/layers/__pycache__/inplace_abn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53b76ef66a5bbd2b32b2088a0f8ed08b79574244 Binary files /dev/null and b/timm/models/layers/__pycache__/inplace_abn.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/inplace_abn.cpython-311.pyc b/timm/models/layers/__pycache__/inplace_abn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a17bef3bba795574e942e3d9faebb4d6df72a47 Binary files /dev/null and b/timm/models/layers/__pycache__/inplace_abn.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/inplace_abn.cpython-313.pyc b/timm/models/layers/__pycache__/inplace_abn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b5806ced39ffdaf85d1d5efb55e9a4e7ef2923a Binary files /dev/null and b/timm/models/layers/__pycache__/inplace_abn.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/inplace_abn.cpython-38.pyc b/timm/models/layers/__pycache__/inplace_abn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc29e26d881af82bbcecab2ea62b56f9ce282c2b Binary files /dev/null and b/timm/models/layers/__pycache__/inplace_abn.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/involution.cpython-310.pyc b/timm/models/layers/__pycache__/involution.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b590eeca06f4c101abb39df3f3c86ffb6de0e47f Binary files /dev/null and b/timm/models/layers/__pycache__/involution.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/involution.cpython-311.pyc b/timm/models/layers/__pycache__/involution.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..684f53b920e24f3fe603710099909bb61a9cc7cb Binary files /dev/null and b/timm/models/layers/__pycache__/involution.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/involution.cpython-313.pyc b/timm/models/layers/__pycache__/involution.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2af67dead431ed47653c74dc115c38937c6fd85 Binary files /dev/null and b/timm/models/layers/__pycache__/involution.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/involution.cpython-38.pyc b/timm/models/layers/__pycache__/involution.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f65f741acb86b934f6806c62993f295622033ffe Binary files /dev/null and b/timm/models/layers/__pycache__/involution.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/lambda_layer.cpython-310.pyc b/timm/models/layers/__pycache__/lambda_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9480d8a7d85a79fa12d5b46009bab501a828c6dc Binary files /dev/null and b/timm/models/layers/__pycache__/lambda_layer.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/lambda_layer.cpython-311.pyc b/timm/models/layers/__pycache__/lambda_layer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..863239ec9f34912168e1779e15d6e333b1be294f Binary files /dev/null and b/timm/models/layers/__pycache__/lambda_layer.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/lambda_layer.cpython-313.pyc b/timm/models/layers/__pycache__/lambda_layer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cada0bb283cd36eed1e9e4bf25514bddcd9f6474 Binary files /dev/null and b/timm/models/layers/__pycache__/lambda_layer.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/lambda_layer.cpython-38.pyc b/timm/models/layers/__pycache__/lambda_layer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc3060bf1ac2da18ebdfb97421b223281fa254b5 Binary files /dev/null and b/timm/models/layers/__pycache__/lambda_layer.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/linear.cpython-310.pyc b/timm/models/layers/__pycache__/linear.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3cc1fac54a82211cd7e85b67bccebe6057413a4 Binary files /dev/null and b/timm/models/layers/__pycache__/linear.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/linear.cpython-311.pyc b/timm/models/layers/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e857ebe4ede62a910d54cdf2c9450b57c4bc161 Binary files /dev/null and b/timm/models/layers/__pycache__/linear.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/linear.cpython-313.pyc b/timm/models/layers/__pycache__/linear.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef43e2bb712debc179aebbbd16721698174a89e5 Binary files /dev/null and b/timm/models/layers/__pycache__/linear.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/linear.cpython-38.pyc b/timm/models/layers/__pycache__/linear.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ebbc3d338500a32b5363445ec9a7030028dc8af Binary files /dev/null and b/timm/models/layers/__pycache__/linear.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/median_pool.cpython-38.pyc b/timm/models/layers/__pycache__/median_pool.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..806ba48a599cb6bde6ccdd7ccb4a35969f45f4a6 Binary files /dev/null and b/timm/models/layers/__pycache__/median_pool.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/mixed_conv2d.cpython-310.pyc b/timm/models/layers/__pycache__/mixed_conv2d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff0660f64aed326645f18571a8618ce681a5a0ee Binary files /dev/null and b/timm/models/layers/__pycache__/mixed_conv2d.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/mixed_conv2d.cpython-311.pyc b/timm/models/layers/__pycache__/mixed_conv2d.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d42786a7893bf69b37187b798d76f796c827ae2 Binary files /dev/null and b/timm/models/layers/__pycache__/mixed_conv2d.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/mixed_conv2d.cpython-313.pyc b/timm/models/layers/__pycache__/mixed_conv2d.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8964f1a39510b1c2e0678c53b5947a48fd40ea69 Binary files /dev/null and b/timm/models/layers/__pycache__/mixed_conv2d.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/mixed_conv2d.cpython-38.pyc b/timm/models/layers/__pycache__/mixed_conv2d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11ab8d5fcce3457f2660deb50804f573466e9755 Binary files /dev/null and b/timm/models/layers/__pycache__/mixed_conv2d.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/mlp.cpython-310.pyc b/timm/models/layers/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7f754b2b6dedefab1705e5fc508a70bce97f947 Binary files /dev/null and b/timm/models/layers/__pycache__/mlp.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/mlp.cpython-311.pyc b/timm/models/layers/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..908bdc1b1f00b1d2e745d24300eec5bf7b7c5577 Binary files /dev/null and b/timm/models/layers/__pycache__/mlp.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/mlp.cpython-313.pyc b/timm/models/layers/__pycache__/mlp.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843bc524e42246f9c1145cef7c27231aff2d4487 Binary files /dev/null and b/timm/models/layers/__pycache__/mlp.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/mlp.cpython-38.pyc b/timm/models/layers/__pycache__/mlp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffba1bca5a3a2037976dd471d7f269e596b3dc7a Binary files /dev/null and b/timm/models/layers/__pycache__/mlp.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/non_local_attn.cpython-310.pyc b/timm/models/layers/__pycache__/non_local_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0dff5433ae9b52193ef51790fd7164457f8d9b8 Binary files /dev/null and b/timm/models/layers/__pycache__/non_local_attn.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/non_local_attn.cpython-311.pyc b/timm/models/layers/__pycache__/non_local_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..892bd537288fd464fb0ff0612767bd97a2d03cc3 Binary files /dev/null and b/timm/models/layers/__pycache__/non_local_attn.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/non_local_attn.cpython-313.pyc b/timm/models/layers/__pycache__/non_local_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30ea7c11680630a1a80049f50c21c9a8d1cf4fe2 Binary files /dev/null and b/timm/models/layers/__pycache__/non_local_attn.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/non_local_attn.cpython-38.pyc b/timm/models/layers/__pycache__/non_local_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aa7a46ac1cdfff1b3494cd060df75a79c9c4a09 Binary files /dev/null and b/timm/models/layers/__pycache__/non_local_attn.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/norm.cpython-310.pyc b/timm/models/layers/__pycache__/norm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66b466380db2cc0cb7497baf160cb973b892968e Binary files /dev/null and b/timm/models/layers/__pycache__/norm.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/norm.cpython-311.pyc b/timm/models/layers/__pycache__/norm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..188b8c3e2e6e51677d672c29866936cfc106d2a3 Binary files /dev/null and b/timm/models/layers/__pycache__/norm.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/norm.cpython-313.pyc b/timm/models/layers/__pycache__/norm.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3de126d3f7a417f85df5e2bd1b0856bfc2d52774 Binary files /dev/null and b/timm/models/layers/__pycache__/norm.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/norm.cpython-38.pyc b/timm/models/layers/__pycache__/norm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..512350fcf81f772fe72969d94768e142ce30ec47 Binary files /dev/null and b/timm/models/layers/__pycache__/norm.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/norm_act.cpython-310.pyc b/timm/models/layers/__pycache__/norm_act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd4f0d45426423dea06ecec3aa41f71792213463 Binary files /dev/null and b/timm/models/layers/__pycache__/norm_act.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/norm_act.cpython-311.pyc b/timm/models/layers/__pycache__/norm_act.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85ef8e3624f20053b32086b647c7f1f12f50bf99 Binary files /dev/null and b/timm/models/layers/__pycache__/norm_act.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/norm_act.cpython-313.pyc b/timm/models/layers/__pycache__/norm_act.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8eff47198e6111388a3f7c36a56c337e8233461 Binary files /dev/null and b/timm/models/layers/__pycache__/norm_act.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/norm_act.cpython-38.pyc b/timm/models/layers/__pycache__/norm_act.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52da2abd36d6b7c8426414e4d37d9b4941b55cfe Binary files /dev/null and b/timm/models/layers/__pycache__/norm_act.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/padding.cpython-310.pyc b/timm/models/layers/__pycache__/padding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..062dc72be43fe62be1e2df51933455a20bea164d Binary files /dev/null and b/timm/models/layers/__pycache__/padding.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/padding.cpython-311.pyc b/timm/models/layers/__pycache__/padding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e19732260bd2be90cb9de04c25591c62e4d83beb Binary files /dev/null and b/timm/models/layers/__pycache__/padding.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/padding.cpython-313.pyc b/timm/models/layers/__pycache__/padding.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b86336f42ff7f6f3ef8ea39d3eff1de657facd0b Binary files /dev/null and b/timm/models/layers/__pycache__/padding.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/padding.cpython-38.pyc b/timm/models/layers/__pycache__/padding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4086d57ffece38b31153f6b34443185fb22e96e4 Binary files /dev/null and b/timm/models/layers/__pycache__/padding.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/patch_embed.cpython-310.pyc b/timm/models/layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb2ad55e3a80c2a861d993ad728321e59f06ec2a Binary files /dev/null and b/timm/models/layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/patch_embed.cpython-311.pyc b/timm/models/layers/__pycache__/patch_embed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72c4a11a56731e834d3aa5da6bfe3da5eba3121c Binary files /dev/null and b/timm/models/layers/__pycache__/patch_embed.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/patch_embed.cpython-313.pyc b/timm/models/layers/__pycache__/patch_embed.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22d3da8ca842a71e2f8ea97248b42d0436642bb5 Binary files /dev/null and b/timm/models/layers/__pycache__/patch_embed.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/patch_embed.cpython-38.pyc b/timm/models/layers/__pycache__/patch_embed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f40d058da658095bfc6900a4e6a57caaf29e1a1 Binary files /dev/null and b/timm/models/layers/__pycache__/patch_embed.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/pool2d_same.cpython-310.pyc b/timm/models/layers/__pycache__/pool2d_same.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdb145cb2da188050e57680b4a942681d1943f06 Binary files /dev/null and b/timm/models/layers/__pycache__/pool2d_same.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/pool2d_same.cpython-311.pyc b/timm/models/layers/__pycache__/pool2d_same.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe1373e3c10ad7d957403e4d1fd11190243d80fb Binary files /dev/null and b/timm/models/layers/__pycache__/pool2d_same.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/pool2d_same.cpython-313.pyc b/timm/models/layers/__pycache__/pool2d_same.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44ccff05935fcd173fae1a6025993836ece888c9 Binary files /dev/null and b/timm/models/layers/__pycache__/pool2d_same.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/pool2d_same.cpython-38.pyc b/timm/models/layers/__pycache__/pool2d_same.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78c07c11ce43de2f109e5c12738ed9943cb6ca2e Binary files /dev/null and b/timm/models/layers/__pycache__/pool2d_same.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/selective_kernel.cpython-310.pyc b/timm/models/layers/__pycache__/selective_kernel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9913a446e635179da2394943cc6b4a4209039217 Binary files /dev/null and b/timm/models/layers/__pycache__/selective_kernel.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/selective_kernel.cpython-311.pyc b/timm/models/layers/__pycache__/selective_kernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fc23a97f2aeb93bd24a2fee5002c8c3d9b6419e Binary files /dev/null and b/timm/models/layers/__pycache__/selective_kernel.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/selective_kernel.cpython-313.pyc b/timm/models/layers/__pycache__/selective_kernel.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0007cecd48af79f39af3d6fcf3e6f35fdc4aaba Binary files /dev/null and b/timm/models/layers/__pycache__/selective_kernel.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/selective_kernel.cpython-38.pyc b/timm/models/layers/__pycache__/selective_kernel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83a83003985518d29aaa02bbff856885761672f4 Binary files /dev/null and b/timm/models/layers/__pycache__/selective_kernel.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/separable_conv.cpython-310.pyc b/timm/models/layers/__pycache__/separable_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1429376fb8c4a460084eaa5614d766edb94873b Binary files /dev/null and b/timm/models/layers/__pycache__/separable_conv.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/separable_conv.cpython-311.pyc b/timm/models/layers/__pycache__/separable_conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04726b63886c8dc74cae75d35e457fbdb906a476 Binary files /dev/null and b/timm/models/layers/__pycache__/separable_conv.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/separable_conv.cpython-313.pyc b/timm/models/layers/__pycache__/separable_conv.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92ac223a222434b9c3852ef2e729b0dabada287c Binary files /dev/null and b/timm/models/layers/__pycache__/separable_conv.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/separable_conv.cpython-38.pyc b/timm/models/layers/__pycache__/separable_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc945d6178d1127a5a2cff9b1d37c5846dc0d430 Binary files /dev/null and b/timm/models/layers/__pycache__/separable_conv.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/space_to_depth.cpython-310.pyc b/timm/models/layers/__pycache__/space_to_depth.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95e9da8ca7aed20a93745e70f94541e126c314be Binary files /dev/null and b/timm/models/layers/__pycache__/space_to_depth.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/space_to_depth.cpython-311.pyc b/timm/models/layers/__pycache__/space_to_depth.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f13fed01e10e567e405d00c2244c42ed1e9f177 Binary files /dev/null and b/timm/models/layers/__pycache__/space_to_depth.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/space_to_depth.cpython-313.pyc b/timm/models/layers/__pycache__/space_to_depth.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..575ff837bcff3200455d1f4db191158ba928031e Binary files /dev/null and b/timm/models/layers/__pycache__/space_to_depth.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/space_to_depth.cpython-38.pyc b/timm/models/layers/__pycache__/space_to_depth.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55e5e060ad6c735fb99c7ce33c2d6451d6041993 Binary files /dev/null and b/timm/models/layers/__pycache__/space_to_depth.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/split_attn.cpython-310.pyc b/timm/models/layers/__pycache__/split_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ccc2b0a8e0b921ae2aac3a595e8d815c0bfd735 Binary files /dev/null and b/timm/models/layers/__pycache__/split_attn.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/split_attn.cpython-311.pyc b/timm/models/layers/__pycache__/split_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01ef729ccfa0c9e186581615d4614b4c7a88b074 Binary files /dev/null and b/timm/models/layers/__pycache__/split_attn.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/split_attn.cpython-313.pyc b/timm/models/layers/__pycache__/split_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3f4d13fb86ca19757855e8a7632f7a2121ee621 Binary files /dev/null and b/timm/models/layers/__pycache__/split_attn.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/split_attn.cpython-38.pyc b/timm/models/layers/__pycache__/split_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a777b7acc76bccae2c88eb1e905f45ce93c0c43e Binary files /dev/null and b/timm/models/layers/__pycache__/split_attn.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/split_batchnorm.cpython-310.pyc b/timm/models/layers/__pycache__/split_batchnorm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c58935c797c85e2c02fdf5f6d96d472f74241720 Binary files /dev/null and b/timm/models/layers/__pycache__/split_batchnorm.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/split_batchnorm.cpython-311.pyc b/timm/models/layers/__pycache__/split_batchnorm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6705fb7ae5011c44ada859048cc0e47c8760564a Binary files /dev/null and b/timm/models/layers/__pycache__/split_batchnorm.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/split_batchnorm.cpython-313.pyc b/timm/models/layers/__pycache__/split_batchnorm.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78aef51b210509b98ee42f5ebf12d7c12b62e15f Binary files /dev/null and b/timm/models/layers/__pycache__/split_batchnorm.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/split_batchnorm.cpython-38.pyc b/timm/models/layers/__pycache__/split_batchnorm.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fb5ac577e82355b3d84b3258c7c539926d0c0f9 Binary files /dev/null and b/timm/models/layers/__pycache__/split_batchnorm.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/squeeze_excite.cpython-310.pyc b/timm/models/layers/__pycache__/squeeze_excite.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8175c25bec0ecd576b2ee4327c75886e68c73b6 Binary files /dev/null and b/timm/models/layers/__pycache__/squeeze_excite.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/squeeze_excite.cpython-311.pyc b/timm/models/layers/__pycache__/squeeze_excite.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2503b8839c24e84242cdb19efd06abefca8db6b Binary files /dev/null and b/timm/models/layers/__pycache__/squeeze_excite.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/squeeze_excite.cpython-313.pyc b/timm/models/layers/__pycache__/squeeze_excite.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..409437d7b5790b4603d60f2607fa0d7197a80458 Binary files /dev/null and b/timm/models/layers/__pycache__/squeeze_excite.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/squeeze_excite.cpython-38.pyc b/timm/models/layers/__pycache__/squeeze_excite.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4379451a57c1bd72fe206129be3e23340697e755 Binary files /dev/null and b/timm/models/layers/__pycache__/squeeze_excite.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/std_conv.cpython-310.pyc b/timm/models/layers/__pycache__/std_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8519ae903aea1d456e1dad8a406df8e5df44269 Binary files /dev/null and b/timm/models/layers/__pycache__/std_conv.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/std_conv.cpython-311.pyc b/timm/models/layers/__pycache__/std_conv.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1691eb4b72203b1eceecc7b6f693147c11af6a61 Binary files /dev/null and b/timm/models/layers/__pycache__/std_conv.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/std_conv.cpython-313.pyc b/timm/models/layers/__pycache__/std_conv.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59e33c32263ef61f3d6497d7c2c01d26e4a2ece0 Binary files /dev/null and b/timm/models/layers/__pycache__/std_conv.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/std_conv.cpython-38.pyc b/timm/models/layers/__pycache__/std_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..572331f5ddbb4c382337bf2949caeb5521334e6e Binary files /dev/null and b/timm/models/layers/__pycache__/std_conv.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/swin_attn.cpython-310.pyc b/timm/models/layers/__pycache__/swin_attn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3537a26154459936c3fba8d46019d33cc3241a78 Binary files /dev/null and b/timm/models/layers/__pycache__/swin_attn.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/swin_attn.cpython-311.pyc b/timm/models/layers/__pycache__/swin_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..007853e61b4b8a84d82bc5206b1821335940d170 Binary files /dev/null and b/timm/models/layers/__pycache__/swin_attn.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/swin_attn.cpython-313.pyc b/timm/models/layers/__pycache__/swin_attn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10c5f2c05b5d367a9159e7f9b08bf806d29dc0ee Binary files /dev/null and b/timm/models/layers/__pycache__/swin_attn.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/swin_attn.cpython-38.pyc b/timm/models/layers/__pycache__/swin_attn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..894cda313ab8a478c0bf9097ceb6262385caf2cb Binary files /dev/null and b/timm/models/layers/__pycache__/swin_attn.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/test_time_pool.cpython-310.pyc b/timm/models/layers/__pycache__/test_time_pool.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c63e998d26f6d999acb6b78bb2d10cb24e61fcaf Binary files /dev/null and b/timm/models/layers/__pycache__/test_time_pool.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/test_time_pool.cpython-311.pyc b/timm/models/layers/__pycache__/test_time_pool.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ea04908b0742a1d4da0ca4f200b1dca43ff3d8d Binary files /dev/null and b/timm/models/layers/__pycache__/test_time_pool.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/test_time_pool.cpython-313.pyc b/timm/models/layers/__pycache__/test_time_pool.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1459aacb6156c52d1f24c4df44ab225d15331cb3 Binary files /dev/null and b/timm/models/layers/__pycache__/test_time_pool.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/test_time_pool.cpython-38.pyc b/timm/models/layers/__pycache__/test_time_pool.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4832a13a8b4061937bdea65a6beb7dd85b858c78 Binary files /dev/null and b/timm/models/layers/__pycache__/test_time_pool.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/trace_utils.cpython-38.pyc b/timm/models/layers/__pycache__/trace_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c576f99e63a2ed749c176018ff838ca1c12057fe Binary files /dev/null and b/timm/models/layers/__pycache__/trace_utils.cpython-38.pyc differ diff --git a/timm/models/layers/__pycache__/weight_init.cpython-310.pyc b/timm/models/layers/__pycache__/weight_init.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab8d184f0e1b632180d08a104afbf2f6f9f699ba Binary files /dev/null and b/timm/models/layers/__pycache__/weight_init.cpython-310.pyc differ diff --git a/timm/models/layers/__pycache__/weight_init.cpython-311.pyc b/timm/models/layers/__pycache__/weight_init.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25d88e11b796ce5447f2376479e5bbf5a4cec323 Binary files /dev/null and b/timm/models/layers/__pycache__/weight_init.cpython-311.pyc differ diff --git a/timm/models/layers/__pycache__/weight_init.cpython-313.pyc b/timm/models/layers/__pycache__/weight_init.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9030ef4893b4f375706d8d9e6f20c6e5b00d67e3 Binary files /dev/null and b/timm/models/layers/__pycache__/weight_init.cpython-313.pyc differ diff --git a/timm/models/layers/__pycache__/weight_init.cpython-38.pyc b/timm/models/layers/__pycache__/weight_init.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b377bb4268f0e3f4b634255b76264efd8e4c75eb Binary files /dev/null and b/timm/models/layers/__pycache__/weight_init.cpython-38.pyc differ diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..e16b3bd3a1898365530c1ffc5154a0a4746a136e --- /dev/null +++ b/timm/models/layers/activations.py @@ -0,0 +1,145 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + NOTE: I don't have a working inplace variant + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + + def forward(self, x): + return mish(x) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + +def hard_mish(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + if inplace: + return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) + else: + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_mish(x, self.inplace) + + +class PReLU(nn.PReLU): + """Applies PReLU (w/ dummy inplace arg) + """ + def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: + super(PReLU, self).__init__(num_parameters=num_parameters, init=init) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.prelu(input, self.weight) + + +def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x) + + +class GELU(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELU, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input) diff --git a/timm/models/layers/activations_jit.py b/timm/models/layers/activations_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a516530ad0abf41f720ac83d02791179bb7b67 --- /dev/null +++ b/timm/models/layers/activations_jit.py @@ -0,0 +1,90 @@ +""" Activations + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described in: https://arxiv.org/abs/1710.05941 + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) + + +@torch.jit.script +def hard_mish_jit(x, inplace: bool = False): + """ Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +class HardMishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishJit, self).__init__() + + def forward(self, x): + return hard_mish_jit(x) diff --git a/timm/models/layers/activations_me.py b/timm/models/layers/activations_me.py new file mode 100644 index 0000000000000000000000000000000000000000..9a12bb7ebbfef02c508801742d38da6b48dd1bb6 --- /dev/null +++ b/timm/models/layers/activations_me.py @@ -0,0 +1,218 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + """ + @staticmethod + def symbolic(g, x): + return g.op("Mul", x, g.op("Sigmoid", x)) + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + @staticmethod + def symbolic(g, self): + input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float))) + hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float))) + return g.op("Mul", self, hardtanh_) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_mish_jit_fwd(x): + return 0.5 * x * (x + 2).clamp(min=0, max=2) + + +@torch.jit.script +def hard_mish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= -2.) + m = torch.where((x >= -2.) & (x <= 0.), x + 1., m) + return grad_output * m + + +class HardMishJitAutoFn(torch.autograd.Function): + """ A memory efficient, jit scripted variant of Hard Mish + Experimental, based on notes by Mish author Diganta Misra at + https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_mish_jit_bwd(x, grad_output) + + +def hard_mish_me(x, inplace: bool = False): + return HardMishJitAutoFn.apply(x) + + +class HardMishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardMishMe, self).__init__() + + def forward(self, x): + return HardMishJitAutoFn.apply(x) + + + diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..ebc6ada8c5b28c7eac5785b0cc2933eb01a15d46 --- /dev/null +++ b/timm/models/layers/adaptive_avgmax_pool.py @@ -0,0 +1,118 @@ +""" PyTorch selectable adaptive pooling +Adaptive pooling with the ability to select the type of pooling from: + * 'avg' - Average pooling + * 'max' - Max pooling + * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 + * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim + +Both a functional and a nn.Module version of the pooling is provided. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def adaptive_pool_feat_mult(pool_type='avg'): + if pool_type == 'catavgmax': + return 2 + else: + return 1 + + +def adaptive_avgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return 0.5 * (x_avg + x_max) + + +def adaptive_catavgmax_pool2d(x, output_size=1): + x_avg = F.adaptive_avg_pool2d(x, output_size) + x_max = F.adaptive_max_pool2d(x, output_size) + return torch.cat((x_avg, x_max), 1) + + +def select_adaptive_pool2d(x, pool_type='avg', output_size=1): + """Selectable global pooling function with dynamic input kernel size + """ + if pool_type == 'avg': + x = F.adaptive_avg_pool2d(x, output_size) + elif pool_type == 'avgmax': + x = adaptive_avgmax_pool2d(x, output_size) + elif pool_type == 'catavgmax': + x = adaptive_catavgmax_pool2d(x, output_size) + elif pool_type == 'max': + x = F.adaptive_max_pool2d(x, output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + return x + + +class FastAdaptiveAvgPool2d(nn.Module): + def __init__(self, flatten=False): + super(FastAdaptiveAvgPool2d, self).__init__() + self.flatten = flatten + + def forward(self, x): + return x.mean((2, 3), keepdim=not self.flatten) + + +class AdaptiveAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_avgmax_pool2d(x, self.output_size) + + +class AdaptiveCatAvgMaxPool2d(nn.Module): + def __init__(self, output_size=1): + super(AdaptiveCatAvgMaxPool2d, self).__init__() + self.output_size = output_size + + def forward(self, x): + return adaptive_catavgmax_pool2d(x, self.output_size) + + +class SelectAdaptivePool2d(nn.Module): + """Selectable global pooling layer with dynamic input kernel size + """ + def __init__(self, output_size=1, pool_type='fast', flatten=False): + super(SelectAdaptivePool2d, self).__init__() + self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing + self.flatten = nn.Flatten(1) if flatten else nn.Identity() + if pool_type == '': + self.pool = nn.Identity() # pass through + elif pool_type == 'fast': + assert output_size == 1 + self.pool = FastAdaptiveAvgPool2d(flatten) + self.flatten = nn.Identity() + elif pool_type == 'avg': + self.pool = nn.AdaptiveAvgPool2d(output_size) + elif pool_type == 'avgmax': + self.pool = AdaptiveAvgMaxPool2d(output_size) + elif pool_type == 'catavgmax': + self.pool = AdaptiveCatAvgMaxPool2d(output_size) + elif pool_type == 'max': + self.pool = nn.AdaptiveMaxPool2d(output_size) + else: + assert False, 'Invalid pool type: %s' % pool_type + + def is_identity(self): + return not self.pool_type + + def forward(self, x): + x = self.pool(x) + x = self.flatten(x) + return x + + def feat_mult(self): + return adaptive_pool_feat_mult(self.pool_type) + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + 'pool_type=' + self.pool_type \ + + ', flatten=' + str(self.flatten) + ')' + diff --git a/timm/models/layers/attention_pool2d.py b/timm/models/layers/attention_pool2d.py new file mode 100644 index 0000000000000000000000000000000000000000..66e49b8a93d7455077866f74421e0027ca73d95a --- /dev/null +++ b/timm/models/layers/attention_pool2d.py @@ -0,0 +1,182 @@ +""" Attention Pool 2D + +Implementations of 2D spatial feature pooling using multi-head attention instead of average pool. + +Based on idea in CLIP by OpenAI, licensed Apache 2.0 +https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +from typing import List, Union, Tuple + +import torch +import torch.nn as nn + +from .helpers import to_2tuple +from .weight_init import trunc_normal_ + + +def rot(x): + return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) + + +def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): + return x * cos_emb + rot(x) * sin_emb + + +def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): + if isinstance(x, torch.Tensor): + x = [x] + return [t * cos_emb + rot(t) * sin_emb for t in x] + + +class RotaryEmbedding(nn.Module): + """ Rotary position embedding + + NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not + been well tested, and will likely change. It will be moved to its own file. + + The following impl/resources were referenced for this impl: + * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py + * https://blog.eleuther.ai/rotary-embeddings/ + """ + def __init__(self, dim, max_freq=4): + super().__init__() + self.dim = dim + self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False) + + def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None): + """ + NOTE: shape arg should include spatial dim only + """ + device = device or self.bands.device + dtype = dtype or self.bands.dtype + if not isinstance(shape, torch.Size): + shape = torch.Size(shape) + N = shape.numel() + grid = torch.stack(torch.meshgrid( + [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1) + emb = grid * math.pi * self.bands + sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1) + cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1) + return sin, cos + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + sin_emb, cos_emb = self.get_embed(x.shape[2:]) + return apply_rot_embed(x, sin_emb, cos_emb) + + +class RotAttentionPool2d(nn.Module): + """ Attention based 2D feature pooling w/ rotary (relative) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + + Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed. + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + + NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from + train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW + """ + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + assert embed_dim % num_heads == 0 + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + self.pos_embed = RotaryEmbedding(self.head_dim) + + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:]) + x = x.reshape(B, -1, N).permute(0, 2, 1) + + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] + + qc, q = q[:, :, :1], q[:, :, 1:] + q = apply_rot_embed(q, sin_emb, cos_emb) + q = torch.cat([qc, q], dim=2) + + kc, k = k[:, :, :1], k[:, :, 1:] + k = apply_rot_embed(k, sin_emb, cos_emb) + k = torch.cat([kc, k], dim=2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] + + +class AttentionPool2d(nn.Module): + """ Attention based 2D feature pooling w/ learned (absolute) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + + It was based on impl in CLIP by OpenAI + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + + NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network. + """ + def __init__( + self, + in_features: int, + feat_size: Union[int, Tuple[int, int]], + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.feat_size = to_2tuple(feat_size) + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim ** -0.5 + + spatial_dim = self.feat_size[0] * self.feat_size[1] + self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) + trunc_normal_(self.pos_embed, std=in_features ** -0.5) + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + assert self.feat_size[0] == H + assert self.feat_size[1] == W + x = x.reshape(B, -1, N).permute(0, 2, 1) + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + x = x + self.pos_embed.unsqueeze(0).to(x.dtype) + + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] diff --git a/timm/models/layers/blur_pool.py b/timm/models/layers/blur_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4ce756e434d577c38a20e2e8de2909777862d4 --- /dev/null +++ b/timm/models/layers/blur_pool.py @@ -0,0 +1,42 @@ +""" +BlurPool layer inspired by + - Kornia's Max_BlurPool2d + - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` + +Hacked together by Chris Ha and Ross Wightman +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .padding import get_padding + + +class BlurPool2d(nn.Module): + r"""Creates a module that computes blurs and downsample a given feature map. + See :cite:`zhang2019shiftinvar` for more details. + Corresponds to the Downsample class, which does blurring and subsampling + + Args: + channels = Number of input channels + filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. + stride (int): downsampling filter stride + + Returns: + torch.Tensor: the transformed tensor. + """ + def __init__(self, channels, filt_size=3, stride=2) -> None: + super(BlurPool2d, self).__init__() + assert filt_size > 1 + self.channels = channels + self.filt_size = filt_size + self.stride = stride + self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 + coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) + blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) + self.register_buffer('filt', blur_filter, persistent=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, self.padding, 'reflect') + return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1]) diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..9604e8a6cfb992c50bc1fc15c54979f30b1d2c94 --- /dev/null +++ b/timm/models/layers/bottleneck_attn.py @@ -0,0 +1,126 @@ +""" Bottleneck Self Attention (Bottleneck Transformers) + +Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + +@misc{2101.11605, +Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani}, +Title = {Bottleneck Transformers for Visual Recognition}, +Year = {2021}, +} + +Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + +This impl is a WIP but given that it is based on the ref gist likely not too far off. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import to_2tuple +from .weight_init import trunc_normal_ + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, heads, height, width, dim) + rel_k: (2 * width - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, 2 * W -1) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, W - 1]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) + x = x_pad[:, :W, W - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + """ + def __init__(self, feat_size, dim_head, scale): + super().__init__() + self.height, self.width = to_2tuple(feat_size) + self.dim_head = dim_head + self.scale = scale + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) + + def forward(self, q): + B, num_heads, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(B * num_heads, self.height, self.width, -1) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. + q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, num_heads, HW, HW) + return rel_logits + + +class BottleneckAttn(nn.Module): + """ Bottleneck Attention + Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + """ + def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False): + super().__init__() + assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.num_heads = num_heads + self.dim_out = dim_out + self.dim_head = dim_out // num_heads + self.scale = self.dim_head ** -0.5 + + self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias) + + # NOTE I'm only supporting relative pos embedding for now + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.pos_embed.height and W == self.pos_embed.width + + x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W + x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) + q, k, v = torch.split(x, self.num_heads, dim=1) + + attn_logits = (q @ k.transpose(-1, -2)) * self.scale + attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W + + attn_out = attn_logits.softmax(dim = -1) + attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + attn_out = self.pool(attn_out) + return attn_out + + diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py new file mode 100644 index 0000000000000000000000000000000000000000..bacf5cf07b695ce6c5fd87facc79f6a5773e6ecf --- /dev/null +++ b/timm/models/layers/cbam.py @@ -0,0 +1,112 @@ +""" CBAM (sort-of) Attention + +Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 + +WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on +some tasks, especially fine-grained it seems. I may end up removing this impl. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn +import torch.nn.functional as F + +from .conv_bn_act import ConvBnAct +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible + + +class ChannelAttn(nn.Module): + """ Original CBAM channel attention module, currently avg + max pool variant only. + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(ChannelAttn, self).__init__() + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) + x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) + return x * self.gate(x_avg + x_max) + + +class LightChannelAttn(ChannelAttn): + """An experimental 'lightweight' that sums avg + max pool first + """ + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightChannelAttn, self).__init__( + channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) + + def forward(self, x): + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) + x_attn = self.fc2(self.act(self.fc1(x_pool))) + return x * F.sigmoid(x_attn) + + +class SpatialAttn(nn.Module): + """ Original CBAM spatial attention module + """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(SpatialAttn, self).__init__() + self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class LightSpatialAttn(nn.Module): + """An experimental 'lightweight' variant that sums avg_pool and max_pool results. + """ + def __init__(self, kernel_size=7, gate_layer='sigmoid'): + super(LightSpatialAttn, self).__init__() + self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) + x_attn = self.conv(x_attn) + return x * self.gate(x_attn) + + +class CbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(CbamModule, self).__init__() + self.channel = ChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + + +class LightCbamModule(nn.Module): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightCbamModule, self).__init__() + self.channel = LightChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = LightSpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..2b74541341ad24bfb97f7ea90ac6470b83a73aa3 --- /dev/null +++ b/timm/models/layers/classifier.py @@ -0,0 +1,56 @@ +""" Classifier head and layer factory + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + +from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .linear import Linear + + +def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): + flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling + if not pool_type: + assert num_classes == 0 or use_conv,\ + 'Pooling can only be disabled if classifier is also removed or conv classifier is used' + flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) + global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) + num_pooled_features = num_features * global_pool.feat_mult() + return global_pool, num_pooled_features + + +def _create_fc(num_features, num_classes, use_conv=False): + if num_classes <= 0: + fc = nn.Identity() # pass-through (no classifier) + elif use_conv: + fc = nn.Conv2d(num_features, num_classes, 1, bias=True) + else: + # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue + fc = Linear(num_features, num_classes, bias=True) + return fc + + +def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): + global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) + fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + return global_pool, fc + + +class ClassifierHead(nn.Module): + """Classifier head w/ configurable global pooling and dropout.""" + + def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): + super(ClassifierHead, self).__init__() + self.drop_rate = drop_rate + self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) + self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) + self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() + + def forward(self, x): + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + x = self.flatten(x) + return x diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..8b4bbca84d6f12e0fb875b4edb435b976fc649d6 --- /dev/null +++ b/timm/models/layers/cond_conv2d.py @@ -0,0 +1,122 @@ +""" PyTorch Conditionally Parameterized Convolution (CondConv) + +Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference +(https://arxiv.org/abs/1904.04971) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .helpers import to_2tuple +from .conv2d_same import conv2d_same +from .padding import get_padding_value + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditionally Parameterized Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = to_2tuple(padding_val) + self.dilation = to_2tuple(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/timm/models/layers/config.py b/timm/models/layers/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f07b9d782ba0597c174dee81097c28280335fdba --- /dev/null +++ b/timm/models/layers/config.py @@ -0,0 +1,115 @@ +""" Model / Layer Config singleton state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. + """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py new file mode 100644 index 0000000000000000000000000000000000000000..75f0f98d4ec1e3f4a0dc004b977815afaa25e7fc --- /dev/null +++ b/timm/models/layers/conv2d_same.py @@ -0,0 +1,42 @@ +""" Conv2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + +from .padding import pad_same, get_padding_value + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py new file mode 100644 index 0000000000000000000000000000000000000000..33005c37b752bd995aeb983ad8480c36b94d0a0c --- /dev/null +++ b/timm/models/layers/conv_bn_act.py @@ -0,0 +1,40 @@ +""" Conv2d + BN + Act + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, + drop_block=None): + super(ConvBnAct, self).__init__() + use_aa = aa_layer is not None + + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = convert_norm_act(norm_layer, act_layer) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) + self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.aa is not None: + x = self.aa(x) + return x diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py new file mode 100644 index 0000000000000000000000000000000000000000..aa557692accff431fe1f9cfb7a5c6d94314b14f6 --- /dev/null +++ b/timm/models/layers/create_act.py @@ -0,0 +1,153 @@ +""" Activation Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Union, Callable, Type + +from .activations import * +from .activations_jit import * +from .activations_me import * +from .config import is_exportable, is_scriptable, is_no_jit + +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. +# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present. +# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used. +_has_silu = 'silu' in dir(torch.nn.functional) +_has_hardswish = 'hardswish' in dir(torch.nn.functional) +_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional) +_has_mish = 'mish' in dir(torch.nn.functional) + + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=F.mish if _has_mish else mish, + relu=F.relu, + relu6=F.relu6, + leaky_relu=F.leaky_relu, + elu=F.elu, + celu=F.celu, + selu=F.selu, + gelu=gelu, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, + hard_swish=F.hardswish if _has_hardswish else hard_swish, + hard_mish=hard_mish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=F.mish if _has_mish else mish_jit, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit, + hard_swish=F.hardswish if _has_hardswish else hard_swish_jit, + hard_mish=hard_mish_jit +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=F.mish if _has_mish else mish_me, + hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me, + hard_swish=F.hardswish if _has_hardswish else hard_swish_me, + hard_mish=hard_mish_me, +) + +_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT) +for a in _ACT_FNS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=nn.Mish if _has_mish else Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + leaky_relu=nn.LeakyReLU, + elu=nn.ELU, + prelu=PReLU, + celu=nn.CELU, + selu=nn.SELU, + gelu=GELU, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, + hard_swish=nn.Hardswish if _has_hardswish else HardSwish, + hard_mish=HardMish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=nn.Mish if _has_mish else MishJit, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit, + hard_mish=HardMishJit +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=nn.Mish if _has_mish else MishMe, + hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe, + hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe, + hard_mish=HardMishMe, +) + +_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT) +for a in _ACT_LAYERS: + a.setdefault('hardsigmoid', a.get('hard_sigmoid')) + a.setdefault('hardswish', a.get('hard_swish')) + + +def get_act_fn(name: Union[Callable, str] = 'relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if isinstance(name, Callable): + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + # If not exporting or scripting the model, first look for a memory-efficient version with + # custom autograd, then fallback + if name in _ACT_FN_ME: + return _ACT_FN_ME[name] + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + if not (is_no_jit() or is_exportable()): + if name in _ACT_FN_JIT: + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if not name: + return None + if isinstance(name, type): + return name + if not (is_no_jit() or is_exportable() or is_scriptable()): + if name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + if not (is_no_jit() or is_exportable()): + if name in _ACT_LAYER_JIT: + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + +def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs): + act_layer = get_act_layer(name) + if act_layer is None: + return None + return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs) diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..3fed646b3378774c3155cf0d01651a043164ef21 --- /dev/null +++ b/timm/models/layers/create_attn.py @@ -0,0 +1,93 @@ +""" Attention Factory + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from functools import partial + +from .bottleneck_attn import BottleneckAttn +from .cbam import CbamModule, LightCbamModule +from .eca import EcaModule, CecaModule +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .halo_attn import HaloAttn +from .involution import Involution +from .lambda_layer import LambdaLayer +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .selective_kernel import SelectiveKernel +from .split_attn import SplitAttn +from .squeeze_excite import SEModule, EffectiveSEModule +from .swin_attn import WindowAttention + + +def get_attn(attn_type): + if isinstance(attn_type, torch.nn.Module): + return attn_type + module_cls = None + if attn_type is not None: + if isinstance(attn_type, str): + attn_type = attn_type.lower() + # Lightweight attention modules (channel and/or coarse spatial). + # Typically added to existing network architecture blocks in addition to existing convolutions. + if attn_type == 'se': + module_cls = SEModule + elif attn_type == 'ese': + module_cls = EffectiveSEModule + elif attn_type == 'eca': + module_cls = EcaModule + elif attn_type == 'ecam': + module_cls = partial(EcaModule, use_mlp=True) + elif attn_type == 'ceca': + module_cls = CecaModule + elif attn_type == 'ge': + module_cls = GatherExcite + elif attn_type == 'gc': + module_cls = GlobalContext + elif attn_type == 'cbam': + module_cls = CbamModule + elif attn_type == 'lcbam': + module_cls = LightCbamModule + + # Attention / attention-like modules w/ significant params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'sk': + module_cls = SelectiveKernel + elif attn_type == 'splat': + module_cls = SplitAttn + + # Self-attention / attention-like modules w/ significant compute and/or params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'lambda': + return LambdaLayer + elif attn_type == 'bottleneck': + return BottleneckAttn + elif attn_type == 'halo': + return HaloAttn + elif attn_type == 'swin': + return WindowAttention + elif attn_type == 'involution': + return Involution + elif attn_type == 'nl': + module_cls = NonLocalAttn + elif attn_type == 'bat': + module_cls = BatNonLocalAttn + + # Woops! + else: + assert False, "Invalid attn module (%s)" % attn_type + elif isinstance(attn_type, bool): + if attn_type: + module_cls = SEModule + else: + module_cls = attn_type + return module_cls + + +def create_attn(attn_type, channels, **kwargs): + module_cls = get_attn(attn_type) + if module_cls is not None: + # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels + return module_cls(channels, **kwargs) + return None diff --git a/timm/models/layers/create_conv2d.py b/timm/models/layers/create_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..3a0cc03a5c8c23fe047d1d3c24782700422e2e6e --- /dev/null +++ b/timm/models/layers/create_conv2d.py @@ -0,0 +1,31 @@ +""" Create Conv2d Factory Method + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + assert 'groups' not in kwargs # MixedConv groups are defined by kernel list + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 + groups = in_channels if depthwise else kwargs.pop('groups', 1) + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) + return m diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5629457dc14b5da3b9673b7e21d7d80f7cda4c --- /dev/null +++ b/timm/models/layers/create_norm_act.py @@ -0,0 +1,83 @@ +""" NormAct (Normalizaiton + Activation Layer) Factory + +Create norm + act combo modules that attempt to be backwards compatible with separate norm + act +isntances in models. Where these are used it will be possible to swap separate BN + act layers with +combined modules like IABN or EvoNorms. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import types +import functools + +import torch +import torch.nn as nn + +from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .norm_act import BatchNormAct2d, GroupNormAct +from .inplace_abn import InplaceAbn + +_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} +_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type + + +def get_norm_act_layer(layer_class): + layer_class = layer_class.replace('_', '').lower() + if layer_class.startswith("batchnorm"): + layer = BatchNormAct2d + elif layer_class.startswith("groupnorm"): + layer = GroupNormAct + elif layer_class == "evonormbatch": + layer = EvoNormBatch2d + elif layer_class == "evonormsample": + layer = EvoNormSample2d + elif layer_class == "iabn" or layer_class == "inplaceabn": + layer = InplaceAbn + else: + assert False, "Invalid norm_act layer (%s)" % layer_class + return layer + + +def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): + layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu + assert len(layer_parts) in (1, 2) + layer = get_norm_act_layer(layer_parts[0]) + #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? + layer_instance = layer(num_features, apply_act=apply_act, **kwargs) + if jit: + layer_instance = torch.jit.script(layer_instance) + return layer_instance + + +def convert_norm_act(norm_layer, act_layer): + assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) + assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) + norm_act_kwargs = {} + + # unbind partial fn, so args can be rebound later + if isinstance(norm_layer, functools.partial): + norm_act_kwargs.update(norm_layer.keywords) + norm_layer = norm_layer.func + + if isinstance(norm_layer, str): + norm_act_layer = get_norm_act_layer(norm_layer) + elif norm_layer in _NORM_ACT_TYPES: + norm_act_layer = norm_layer + elif isinstance(norm_layer, types.FunctionType): + # if function type, must be a lambda/fn that creates a norm_act layer + norm_act_layer = norm_layer + else: + type_name = norm_layer.__name__.lower() + if type_name.startswith('batchnorm'): + norm_act_layer = BatchNormAct2d + elif type_name.startswith('groupnorm'): + norm_act_layer = GroupNormAct + else: + assert False, f"No equivalent norm_act layer for {type_name}" + + if norm_act_layer in _NORM_ACT_REQUIRES_ARG: + # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. + # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types + norm_act_kwargs.setdefault('act_layer', act_layer) + if norm_act_kwargs: + norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args + return norm_act_layer diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..6de9e3f729f7f1ca29d4511f6c64733d3169fbec --- /dev/null +++ b/timm/models/layers/drop.py @@ -0,0 +1,168 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + +Papers: +DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + +Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + +Code: +DropBlock impl inspired by two Tensorflow impl that I liked: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def drop_block_2d( + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. This layer has been tested on a few training + runs with success, but needs further validation and possibly optimization for lower runtime impact. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + # seed_drop_rate, the gamma parameter + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + # Forces the block to be inside the feature map. + w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) + + if batchwise: + # one mask for whole batch, quite a bit faster + uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) + else: + uniform_noise = torch.rand_like(x) + block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) + else: + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +def drop_block_fast_2d( + x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, + gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + + DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid + block mask at edges. + """ + B, C, H, W = x.shape + total_size = W * H + clipped_block_size = min(block_size, min(W, H)) + gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (W - block_size + 1) * (H - block_size + 1)) + + if batchwise: + # one mask for whole batch, quite a bit faster + block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma + else: + # mask per batch element + block_mask = torch.rand_like(x) < gamma + block_mask = F.max_pool2d( + block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) + + if with_noise: + normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) + if inplace: + x.mul_(1. - block_mask).add_(normal_noise * block_mask) + else: + x = x * (1. - block_mask) + normal_noise * block_mask + else: + block_mask = 1 - block_mask + normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) + if inplace: + x.mul_(block_mask * normalize_scale) + else: + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ + def __init__(self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False, + inplace=False, + batchwise=False, + fast=True): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + self.inplace = inplace + self.batchwise = batchwise + self.fast = fast # FIXME finish comparisons of fast vs not + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + if self.fast: + return drop_block_fast_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + else: + return drop_block_2d( + x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py new file mode 100644 index 0000000000000000000000000000000000000000..e29be6ac3c95bb61229cdcdd659ec89d541f1a53 --- /dev/null +++ b/timm/models/layers/eca.py @@ -0,0 +1,145 @@ +""" +ECA module from ECAnet + +paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks +https://arxiv.org/abs/1910.03151 + +Original ECA model borrowed from https://github.com/BangguWu/ECANet + +Modified circular ECA implementation and adaption for use in timm package +by Chris Ha https://github.com/VRandme + +Original License: + +MIT License + +Copyright (c) 2019 BangguWu, Qilong Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import math +from torch import nn +import torch.nn.functional as F + + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class EcaModule(nn.Module): + """Constructs an ECA module. + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + def __init__( + self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', + rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): + super(EcaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + assert kernel_size % 2 == 1 + padding = (kernel_size - 1) // 2 + if use_mlp: + # NOTE 'mlp' mode is a timm experiment, not in paper + assert channels is not None + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) + act_layer = act_layer or nn.ReLU + self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) + self.act = create_act_layer(act_layer) + self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) + else: + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.act = None + self.conv2 = None + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv + y = self.conv(y) + if self.conv2 is not None: + y = self.act(y) + y = self.conv2(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +EfficientChannelAttn = EcaModule # alias + + +class CecaModule(nn.Module): + """Constructs a circular ECA module. + + ECA module where the conv uses circular padding rather than zero padding. + Unlike the spatial dimension, the channels do not have inherent ordering nor + locality. Although this module in essence, applies such an assumption, it is unnecessary + to limit the channels on either "edge" from being circularly adapted to each other. + This will fundamentally increase connectivity and possibly increase performance metrics + (accuracy, robustness), without significantly impacting resource metrics + (parameter size, throughput,latency, etc) + + Args: + channels: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + kernel_size: Adaptive selection of kernel size (default=3) + gamm: used in kernel_size calc, see above + beta: used in kernel_size calc, see above + act_layer: optional non-linearity after conv, enables conv bias, this is an experiment + gate_layer: gating non-linearity to use + """ + + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): + super(CecaModule, self).__init__() + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) + has_act = act_layer is not None + assert kernel_size % 2 == 1 + + # PyTorch circular padding mode is buggy as of pytorch 1.4 + # see https://github.com/pytorch/pytorch/pull/17240 + # implement manual circular padding + self.padding = (kernel_size - 1) // 2 + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + y = x.mean((2, 3)).view(x.shape[0], 1, -1) + # Manually implement circular padding, F.pad does not seemed to be bugged + y = F.pad(y, (self.padding, self.padding), mode='circular') + y = self.conv(y) + y = self.gate(y).view(x.shape[0], -1, 1, 1) + return x * y.expand_as(x) + + +CircularEfficientChannelAttn = CecaModule diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9023afd0e81dc8a76871d03141866217d59f4770 --- /dev/null +++ b/timm/models/layers/evo_norm.py @@ -0,0 +1,83 @@ +"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch + +An attempt at getting decent performing EvoNorms running in PyTorch. +While currently faster than other impl, still quite a ways off the built-in BN +in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). + +Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +import torch.nn as nn + + +class EvoNormBatch2d(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): + super(EvoNormBatch2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + param_shape = (1, num_features, 1, 1) + self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + if apply_act: + self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.apply_act: + nn.init.ones_(self.v) + + def forward(self, x): + assert x.dim() == 4, 'expected 4D input' + x_type = x.dtype + if self.training: + var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + n = x.numel() / x.shape[1] + self.running_var.copy_( + var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) + else: + var = self.running_var + + if self.apply_act: + v = self.v.to(dtype=x_type) + d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) + d = d.max((var + self.eps).sqrt().to(dtype=x_type)) + x = x / d + return x * self.weight + self.bias + + +class EvoNormSample2d(nn.Module): + def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None): + super(EvoNormSample2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.groups = groups + self.eps = eps + param_shape = (1, num_features, 1, 1) + self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) + if apply_act: + self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.apply_act: + nn.init.ones_(self.v) + + def forward(self, x): + assert x.dim() == 4, 'expected 4D input' + B, C, H, W = x.shape + assert C % self.groups == 0 + if self.apply_act: + n = x * (x * self.v).sigmoid() + x = x.reshape(B, self.groups, -1) + x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() + x = x.reshape(B, C, H, W) + return x * self.weight + self.bias diff --git a/timm/models/layers/gather_excite.py b/timm/models/layers/gather_excite.py new file mode 100644 index 0000000000000000000000000000000000000000..2d60dc961e2b5e135d38e290b8fa5820ef0fe18f --- /dev/null +++ b/timm/models/layers/gather_excite.py @@ -0,0 +1,90 @@ +""" Gather-Excite Attention Block + +Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 + +Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet + +I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another +impl that covers all of the cases. + +NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math + +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .create_conv2d import create_conv2d +from .helpers import make_divisible +from .mlp import ConvMlp + + +class GatherExcite(nn.Module): + """ Gather-Excite Attention Module + """ + def __init__( + self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, + rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): + super(GatherExcite, self).__init__() + self.add_maxpool = add_maxpool + act_layer = get_act_layer(act_layer) + self.extent = extent + if extra_params: + self.gather = nn.Sequential() + if extent == 0: + assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' + self.gather.add_module( + 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) + else: + assert extent % 2 == 0 + num_conv = int(math.log2(extent)) + for i in range(num_conv): + self.gather.add_module( + f'conv{i + 1}', + create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) + if i != num_conv - 1: + self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) + else: + self.gather = None + if self.extent == 0: + self.gk = 0 + self.gs = 0 + else: + assert extent % 2 == 0 + self.gk = self.extent * 2 - 1 + self.gs = self.extent + + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + size = x.shape[-2:] + if self.gather is not None: + x_ge = self.gather(x) + else: + if self.extent == 0: + # global extent + x_ge = x.mean(dim=(2, 3), keepdims=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) + else: + x_ge = F.avg_pool2d( + x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) + x_ge = self.mlp(x_ge) + if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: + x_ge = F.interpolate(x_ge, size=size) + return x * self.gate(x_ge) diff --git a/timm/models/layers/global_context.py b/timm/models/layers/global_context.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2c82f3aa75f8fedad49305952667f9a6fd5363 --- /dev/null +++ b/timm/models/layers/global_context.py @@ -0,0 +1,67 @@ +""" Global Context Attention Block + +Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` + - https://arxiv.org/abs/1904.11492 + +Official code consulted as reference: https://github.com/xvjiarui/GCNet + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible +from .mlp import ConvMlp +from .norm import LayerNorm2d + + +class GlobalContext(nn.Module): + + def __init__(self, channels, use_attn=True, fuse_add=True, fuse_scale=False, init_last_zero=False, + rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): + super(GlobalContext, self).__init__() + act_layer = get_act_layer(act_layer) + + self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None + + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + if fuse_add: + self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_add = None + if fuse_scale: + self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_scale = None + + self.gate = create_act_layer(gate_layer) + self.init_last_zero = init_last_zero + self.reset_parameters() + + def reset_parameters(self): + if self.conv_attn is not None: + nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') + if self.mlp_add is not None: + nn.init.zeros_(self.mlp_add.fc2.weight) + + def forward(self, x): + B, C, H, W = x.shape + + if self.conv_attn is not None: + attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) + attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = context.view(B, C, 1, 1) + else: + context = x.mean(dim=(2, 3), keepdim=True) + + if self.mlp_scale is not None: + mlp_x = self.mlp_scale(context) + x = x * self.gate(mlp_x) + if self.mlp_add is not None: + mlp_x = self.mlp_add(context) + x = x + mlp_x + + return x diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..87cae8952cb7318cbec9bc513e7b2010ede7312d --- /dev/null +++ b/timm/models/layers/halo_attn.py @@ -0,0 +1,166 @@ +""" Halo Self Attention + +Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + +@misc{2103.12731, +Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and + Jonathon Shlens}, +Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones}, +Year = {2021}, +} + +Status: +This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me. + +Trying to match the 'H1' variant in the paper, my parameter counts are 2M less and the model +is extremely slow. Something isn't right. However, the models do appear to train and experimental +variants with attn in C4 and/or C5 stages are tolerable speed. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from typing import Tuple, List + +import torch +from torch import nn +import torch.nn.functional as F + +from .weight_init import trunc_normal_ + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, height, width, dim) + rel_k: (2 * window - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + rel_size = rel_k.shape[0] + win_size = (rel_size + 1) // 2 + + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, rel_size) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, rel_size - W]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, rel_size) + x = x_pad[:, :W, win_size - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + """ + def __init__(self, block_size, win_size, dim_head, scale): + """ + Args: + block_size (int): block size + win_size (int): neighbourhood window size + dim_head (int): attention head dim + scale (float): scale factor (for init) + """ + super().__init__() + self.block_size = block_size + self.dim_head = dim_head + self.scale = scale + self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + + def forward(self, q): + B, BB, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(-1, self.block_size, self.block_size, self.dim_head) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. + q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, BB, HW, -1) + return rel_logits + + +class HaloAttn(nn.Module): + """ Halo Attention + + Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` + - https://arxiv.org/abs/2103.12731 + """ + def __init__( + self, dim, dim_out=None, stride=1, num_heads=8, dim_head=16, block_size=8, halo_size=3, qkv_bias=False): + super().__init__() + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.stride = stride + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_qk = num_heads * dim_head + self.dim_v = dim_out + self.block_size = block_size + self.halo_size = halo_size + self.win_size = block_size + halo_size * 2 # neighbourhood window size + self.scale = self.dim_head ** -0.5 + + # FIXME not clear if this stride behaviour is what the paper intended + # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving + # data in unfolded block form. I haven't wrapped my head around how that'd look. + self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias) + self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias) + + self.pos_embed = PosEmbedRel( + block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale) + + def reset_parameters(self): + std = self.q.weight.shape[1] ** -0.5 # fan-in + trunc_normal_(self.q.weight, std=std) + trunc_normal_(self.kv.weight, std=std) + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + assert H % self.block_size == 0 and W % self.block_size == 0 + num_h_blocks = H // self.block_size + num_w_blocks = W // self.block_size + num_blocks = num_h_blocks * num_w_blocks + + q = self.q(x) + q = F.unfold(q, kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) + # B, num_heads * dim_head * block_size ** 2, num_blocks + q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) + # B * num_heads, num_blocks, block_size ** 2, dim_head + + kv = self.kv(x) + # FIXME I 'think' this unfold does what I want it to, but I should investigate + kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + kv = kv.reshape( + B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) + k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) + + attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied? + attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2 + + attn_out = attn_logits.softmax(dim=-1) + attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks + attn_out = F.fold( + attn_out.reshape(B, -1, num_blocks), + (H // self.stride, W // self.stride), + kernel_size=self.block_size // self.stride, stride=self.block_size // self.stride) + # B, dim_out, H // stride, W // stride + return attn_out diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..cc54ca7f8a24de7e1ee0e5d27decf3e88c55ece3 --- /dev/null +++ b/timm/models/layers/helpers.py @@ -0,0 +1,31 @@ +""" Layer/Module Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +from itertools import repeat +import collections.abc + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < round_limit * v: + new_v += divisor + return new_v diff --git a/timm/models/layers/inplace_abn.py b/timm/models/layers/inplace_abn.py new file mode 100644 index 0000000000000000000000000000000000000000..3aae7cf563edfe6c9d2bf1a9f3994d911aacea23 --- /dev/null +++ b/timm/models/layers/inplace_abn.py @@ -0,0 +1,87 @@ +import torch +from torch import nn as nn + +try: + from inplace_abn.functions import inplace_abn, inplace_abn_sync + has_iabn = True +except ImportError: + has_iabn = False + + def inplace_abn(x, weight, bias, running_mean, running_var, + training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): + raise ImportError( + "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") + + def inplace_abn_sync(**kwargs): + inplace_abn(**kwargs) + + +class InplaceAbn(nn.Module): + """Activated Batch Normalization + + This gathers a BatchNorm and an activation function in a single module + + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + act_layer : str or nn.Module type + Name or type of the activation functions, one of: `leaky_relu`, `elu` + act_param : float + Negative slope for the `leaky_relu` activation. + """ + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, + act_layer="leaky_relu", act_param=0.01, drop_block=None): + super(InplaceAbn, self).__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + self.momentum = momentum + if apply_act: + if isinstance(act_layer, str): + assert act_layer in ('leaky_relu', 'elu', 'identity', '') + self.act_name = act_layer if act_layer else 'identity' + else: + # convert act layer passed as type to string + if act_layer == nn.ELU: + self.act_name = 'elu' + elif act_layer == nn.LeakyReLU: + self.act_name = 'leaky_relu' + elif act_layer == nn.Identity: + self.act_name = 'identity' + else: + assert False, f'Invalid act layer {act_layer.__name__} for IABN' + else: + self.act_name = 'identity' + self.act_param = act_param + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.constant_(self.running_mean, 0) + nn.init.constant_(self.running_var, 1) + if self.affine: + nn.init.constant_(self.weight, 1) + nn.init.constant_(self.bias, 0) + + def forward(self, x): + output = inplace_abn( + x, self.weight, self.bias, self.running_mean, self.running_var, + self.training, self.momentum, self.eps, self.act_name, self.act_param) + if isinstance(output, tuple): + output = output[0] + return output diff --git a/timm/models/layers/involution.py b/timm/models/layers/involution.py new file mode 100644 index 0000000000000000000000000000000000000000..ccdeefcbe96cabb9285e08408a447ce8a89435db --- /dev/null +++ b/timm/models/layers/involution.py @@ -0,0 +1,50 @@ +""" PyTorch Involution Layer + +Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py +Paper: `Involution: Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255 +""" +import torch.nn as nn +from .conv_bn_act import ConvBnAct +from .create_conv2d import create_conv2d + + +class Involution(nn.Module): + + def __init__( + self, + channels, + kernel_size=3, + stride=1, + group_size=16, + rd_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(Involution, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.channels = channels + self.group_size = group_size + self.groups = self.channels // self.group_size + self.conv1 = ConvBnAct( + in_channels=channels, + out_channels=channels // rd_ratio, + kernel_size=1, + norm_layer=norm_layer, + act_layer=act_layer) + self.conv2 = self.conv = create_conv2d( + in_channels=channels // rd_ratio, + out_channels=kernel_size**2 * self.groups, + kernel_size=1, + stride=1) + self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity() + self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride) + + def forward(self, x): + weight = self.conv2(self.conv1(self.avgpool(x))) + B, C, H, W = weight.shape + KK = int(self.kernel_size ** 2) + weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2) + out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W) + out = (weight * out).sum(dim=3).view(B, self.channels, H, W) + return out diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1027a18146f3171724a45d82107c77e1297e5c --- /dev/null +++ b/timm/models/layers/lambda_layer.py @@ -0,0 +1,84 @@ +""" Lambda Layer + +Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + +@misc{2102.08602, +Author = {Irwan Bello}, +Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention}, +Year = {2021}, +} + +Status: +This impl is a WIP. Code snippets in the paper were used as reference but +good chance some details are missing/wrong. + +I've only implemented local lambda conv based pos embeddings. + +For a PyTorch impl that includes other embedding options checkout +https://github.com/lucidrains/lambda-networks + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +from torch import nn +import torch.nn.functional as F + +from .weight_init import trunc_normal_ + + +class LambdaLayer(nn.Module): + """Lambda Layer w/ lambda conv position embedding + + Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention` + - https://arxiv.org/abs/2102.08602 + """ + def __init__( + self, + dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): + super().__init__() + self.dim = dim + self.dim_out = dim_out or dim + self.dim_k = dim_head # query depth 'k' + self.num_heads = num_heads + assert self.dim_out % num_heads == 0, ' should be divided by num_heads' + self.dim_v = self.dim_out // num_heads # value depth 'v' + self.r = r # relative position neighbourhood (lambda conv kernel size) + + self.qkv = nn.Conv2d( + dim, + num_heads * dim_head + dim_head + self.dim_v, + kernel_size=1, bias=qkv_bias) + self.norm_q = nn.BatchNorm2d(num_heads * dim_head) + self.norm_v = nn.BatchNorm2d(self.dim_v) + + # NOTE currently only supporting the local lambda convolutions for positional + self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.dim ** -0.5) + trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) + + def forward(self, x): + B, C, H, W = x.shape + M = H * W + + qkv = self.qkv(x) + q, k, v = torch.split(qkv, [ + self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) + q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K + v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V + k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M + + content_lam = k @ v # B, K, V + content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V + + position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K + position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V + position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V + + out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W) # B, C (num_heads * V), H, W + out = self.pool(out) + return out diff --git a/timm/models/layers/linear.py b/timm/models/layers/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..38fe3380b067ea0b275c45ffd689afdeb4598f3c --- /dev/null +++ b/timm/models/layers/linear.py @@ -0,0 +1,19 @@ +""" Linear layer (alternate definition) +""" +import torch +import torch.nn.functional as F +from torch import nn as nn + + +class Linear(nn.Linear): + r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` + + Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting + weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. + """ + def forward(self, input: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None + return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) + else: + return F.linear(input, self.weight, self.bias) diff --git a/timm/models/layers/median_pool.py b/timm/models/layers/median_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..40bd71a7a3840aaebefd2af0a99605b845054cd7 --- /dev/null +++ b/timm/models/layers/median_pool.py @@ -0,0 +1,49 @@ +""" Median Pool +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F +from .helpers import to_2tuple, to_4tuple + + +class MedianPool2d(nn.Module): + """ Median pool (usable as median filter when stride=1) module. + + Args: + kernel_size: size of pooling kernel, int or 2-tuple + stride: pool stride, int or 2-tuple + padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad + same: override padding and enforce same padding, boolean + """ + def __init__(self, kernel_size=3, stride=1, padding=0, same=False): + super(MedianPool2d, self).__init__() + self.k = to_2tuple(kernel_size) + self.stride = to_2tuple(stride) + self.padding = to_4tuple(padding) # convert to l, r, t, b + self.same = same + + def _padding(self, x): + if self.same: + ih, iw = x.size()[2:] + if ih % self.stride[0] == 0: + ph = max(self.k[0] - self.stride[0], 0) + else: + ph = max(self.k[0] - (ih % self.stride[0]), 0) + if iw % self.stride[1] == 0: + pw = max(self.k[1] - self.stride[1], 0) + else: + pw = max(self.k[1] - (iw % self.stride[1]), 0) + pl = pw // 2 + pr = pw - pl + pt = ph // 2 + pb = ph - pt + padding = (pl, pr, pt, pb) + else: + padding = self.padding + return padding + + def forward(self, x): + x = F.pad(x, self._padding(x), mode='reflect') + x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) + x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] + return x diff --git a/timm/models/layers/mixed_conv2d.py b/timm/models/layers/mixed_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..fa0ce565c0a9d348d4e68165960fa77fcf7f70d7 --- /dev/null +++ b/timm/models/layers/mixed_conv2d.py @@ -0,0 +1,51 @@ +""" PyTorch Mixed Convolution + +Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = in_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..05d076527cfb6f15bcf5f2830fa36777abbc5a1e --- /dev/null +++ b/timm/models/layers/mlp.py @@ -0,0 +1,108 @@ +""" MLP module w/ dropout and configurable activation layer + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GluMlp(nn.Module): + """ MLP w/ GLU style gating + See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + assert hidden_features % 2 == 0 + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features // 2, out_features) + self.drop = nn.Dropout(drop) + + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + fc1_mid = self.fc1.bias.shape[0] // 2 + nn.init.ones_(self.fc1.bias[fc1_mid:]) + nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x, gates = x.chunk(2, dim=-1) + x = x * self.act(gates) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GatedMlp(nn.Module): + """ MLP as used in gMLP + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + gate_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + if gate_layer is not None: + assert hidden_features % 2 == 0 + self.gate = gate_layer(hidden_features) + hidden_features = hidden_features // 2 # FIXME base reduction on gate property? + else: + self.gate = nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.gate(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..a537d60e6e575f093b93a146b83fb8e6398f6288 --- /dev/null +++ b/timm/models/layers/non_local_attn.py @@ -0,0 +1,143 @@ +""" Bilinear-Attention-Transform and Non-Local Attention + +Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms` + - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html +Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification +""" +import torch +from torch import nn +from torch.nn import functional as F + +from .conv_bn_act import ConvBnAct +from .helpers import make_divisible + + +class NonLocalAttn(nn.Module): + """Spatial NL block for image classification. + + This was adapted from https://github.com/BA-Transform/BAT-Image-Classification + Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. + """ + + def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): + super(NonLocalAttn, self).__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.scale = in_channels ** -0.5 if use_scale else 1.0 + self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) + self.norm = nn.BatchNorm2d(in_channels) + self.reset_parameters() + + def forward(self, x): + shortcut = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + B, C, H, W = t.size() + t = t.view(B, C, -1).permute(0, 2, 1) + p = p.view(B, C, -1) + g = g.view(B, C, -1).permute(0, 2, 1) + + att = torch.bmm(t, p) * self.scale + att = F.softmax(att, dim=2) + x = torch.bmm(att, g) + + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.z(x) + x = self.norm(x) + shortcut + + return x + + def reset_parameters(self): + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if len(list(m.parameters())) > 1: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + + +class BilinearAttnTransform(nn.Module): + + def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(BilinearAttnTransform, self).__init__() + + self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) + self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) + self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.block_size = block_size + self.groups = groups + self.in_channels = in_channels + + def resize_mat(self, x, t: int): + B, C, block_size, block_size1 = x.shape + assert block_size == block_size1 + if t <= 1: + return x + x = x.view(B * C, -1, 1, 1) + x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) + x = x.view(B * C, block_size, block_size, t, t) + x = torch.cat(torch.split(x, 1, dim=1), dim=3) + x = torch.cat(torch.split(x, 1, dim=2), dim=4) + x = x.view(B, C, block_size * t, block_size * t) + return x + + def forward(self, x): + assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0 + B, C, H, W = x.shape + out = self.conv1(x) + rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) + cp = F.adaptive_max_pool2d(out, (1, self.block_size)) + p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + p = p / p.sum(dim=3, keepdim=True) + q = q / q.sum(dim=2, keepdim=True) + p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + p = p.view(B, C, self.block_size, self.block_size) + q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + q = q.view(B, C, self.block_size, self.block_size) + p = self.resize_mat(p, H // self.block_size) + q = self.resize_mat(q, W // self.block_size) + y = p.matmul(x) + y = y.matmul(q) + + y = self.conv2(y) + return y + + +class BatNonLocalAttn(nn.Module): + """ BAT + Adapted from: https://github.com/BA-Transform/BAT-Image-Classification + """ + + def __init__( + self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): + super().__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) + self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.dropout = nn.Dropout2d(p=drop_rate) + + def forward(self, x): + xl = self.conv1(x) + y = self.ba(xl) + y = self.conv2(y) + y = self.dropout(y) + return y + x diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py new file mode 100644 index 0000000000000000000000000000000000000000..433552b4cec1e901147d61b05ed6c68ea9c3799f --- /dev/null +++ b/timm/models/layers/norm.py @@ -0,0 +1,23 @@ +""" Normalization layers and wrappers +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class GroupNorm(nn.GroupNorm): + def __init__(self, num_channels, num_groups, eps=1e-5, affine=True): + # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN + super().__init__(num_groups, num_channels, eps=eps, affine=affine) + + def forward(self, x): + return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class LayerNorm2d(nn.LayerNorm): + """ Layernorm for channels of '2d' spatial BCHW tensors """ + def __init__(self, num_channels): + super().__init__([num_channels, 1, 1]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py new file mode 100644 index 0000000000000000000000000000000000000000..02cabe88861f96345599b71a4a96edd8d115f6d3 --- /dev/null +++ b/timm/models/layers/norm_act.py @@ -0,0 +1,85 @@ +""" Normalization + Activation Layers +""" +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .create_act import get_act_layer + + +class BatchNormAct2d(nn.BatchNorm2d): + """BatchNorm + Activation + + This module performs BatchNorm + Activation in a manner that will remain backwards + compatible with weights trained with separate bn, act. This is why we inherit from BN + instead of composing it as a .bn member. + """ + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def _forward_jit(self, x): + """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function + """ + # exponential_average_factor is self.momentum set to + # (when it is available) only so that if gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + x = F.batch_norm( + x, self.running_mean, self.running_var, self.weight, self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, self.eps) + return x + + @torch.jit.ignore + def _forward_python(self, x): + return super(BatchNormAct2d, self).forward(x) + + def forward(self, x): + # FIXME cannot call parent forward() and maintain jit.script compatibility? + if torch.jit.is_scripting(): + x = self._forward_jit(x) + else: + x = self._forward_python(x) + x = self.act(x) + return x + + +class GroupNormAct(nn.GroupNorm): + # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args + def __init__(self, num_channels, num_groups, eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + x = self.act(x) + return x diff --git a/timm/models/layers/padding.py b/timm/models/layers/padding.py new file mode 100644 index 0000000000000000000000000000000000000000..34afc37c6c59c8782ad29c7a779f58177011f891 --- /dev/null +++ b/timm/models/layers/padding.py @@ -0,0 +1,56 @@ +""" Padding Helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +from typing import List, Tuple + +import torch.nn.functional as F + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? +def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) + return x + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a2c8594ae45ff0909ad33851578260f21b92b1 --- /dev/null +++ b/timm/models/layers/patch_embed.py @@ -0,0 +1,39 @@ +""" Image to Patch Embedding using Conv2d + +A convolution based approach to patchifying a 2D image w/ embedding projection. + +Based on the impl in https://github.com/google-research/vision_transformer + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from torch import nn as nn + +from .helpers import to_2tuple + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2a1c44713e552be850865ada9623a1c3b1d836 --- /dev/null +++ b/timm/models/layers/pool2d_same.py @@ -0,0 +1,73 @@ +""" AvgPool2d w/ Same Padding + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List, Tuple, Optional + +from .helpers import to_2tuple +from .padding import pad_same, get_padding_value + + +def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + ceil_mode: bool = False, count_include_pad: bool = True): + # FIXME how to deal with count_include_pad vs not for external padding? + x = pad_same(x, kernel_size, stride) + return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + +class AvgPool2dSame(nn.AvgPool2d): + """ Tensorflow like 'SAME' wrapper for 2D average pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride) + return F.avg_pool2d( + x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) + + +def max_pool2d_same( + x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + dilation: List[int] = (1, 1), ceil_mode: bool = False): + x = pad_same(x, kernel_size, stride, value=-float('inf')) + return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) + + +class MaxPool2dSame(nn.MaxPool2d): + """ Tensorflow like 'SAME' wrapper for 2D max pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) + + def forward(self, x): + x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) + return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) + + +def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): + stride = stride or kernel_size + padding = kwargs.pop('padding', '') + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) + if is_dynamic: + if pool_type == 'avg': + return AvgPool2dSame(kernel_size, stride=stride, **kwargs) + elif pool_type == 'max': + return MaxPool2dSame(kernel_size, stride=stride, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' + else: + if pool_type == 'avg': + return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + elif pool_type == 'max': + return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) + else: + assert False, f'Unsupported pool type {pool_type}' diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..f28b8d2e9ad49740081d4e1da5287e45f5ee76b8 --- /dev/null +++ b/timm/models/layers/selective_kernel.py @@ -0,0 +1,119 @@ +""" Selective Kernel Convolution/Attention + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import nn as nn + +from .conv_bn_act import ConvBnAct +from .helpers import make_divisible + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + """ Selective Kernel Attention Module + + Selective Kernel attention mechanism factored out into its own module. + + """ + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = x.sum(1).mean((2, 3), keepdim=True) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernel(nn.Module): + + def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, + rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): + """ Selective Kernel Convolution Module + + As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. + + Largest change is the input split, which divides the input channels across each convolution path, this can + be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps + the parameter count from ballooning when the convolutions themselves don't have groups, but still provides + a noteworthy increase in performance over similar param count models without this attention layer. -Ross W + + Args: + in_channels (int): module input (feature) channel count + out_channels (int): module output (feature) channel count + kernel_size (int, list): kernel size for each convolution branch + stride (int): stride for convolutions + dilation (int): dilation for module as a whole, impacts dilation of each branch + groups (int): number of groups for each branch + rd_ratio (int, float): reduction factor for attention features + keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations + split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, + can be viewed as grouping by path, output expands to module out_channels count + drop_block (nn.Module): drop block module + act_layer (nn.Module): activation layer to use + norm_layer (nn.Module): batchnorm/norm layer to use + """ + super(SelectiveKernel, self).__init__() + out_channels = out_channels or in_channels + kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + self.drop_block = drop_block + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x diff --git a/timm/models/layers/separable_conv.py b/timm/models/layers/separable_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddcb4e62409492f898ab963027a9c2229b72f64 --- /dev/null +++ b/timm/models/layers/separable_conv.py @@ -0,0 +1,73 @@ +""" Depthwise Separable Conv Modules + +Basic DWS convs. Other variations of DWS exist with batch norm or activations between the +DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + +from .create_conv2d import create_conv2d +from .create_norm_act import convert_norm_act + + +class SeparableConvBnAct(nn.Module): + """ Separable Conv w/ trailing Norm and Activation + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, + apply_act=True, drop_block=None): + super(SeparableConvBnAct, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + norm_act_layer = convert_norm_act(norm_layer, act_layer) + self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + if self.bn is not None: + x = self.bn(x) + return x + + +class SeparableConv2d(nn.Module): + """ Separable Conv + """ + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, + channel_multiplier=1.0, pw_kernel_size=1): + super(SeparableConv2d, self).__init__() + + self.conv_dw = create_conv2d( + in_channels, int(in_channels * channel_multiplier), kernel_size, + stride=stride, dilation=dilation, padding=padding, depthwise=True) + + self.conv_pw = create_conv2d( + int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + + def forward(self, x): + x = self.conv_dw(x) + x = self.conv_pw(x) + return x diff --git a/timm/models/layers/space_to_depth.py b/timm/models/layers/space_to_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e8e0b2a486d51fe3e4ab0472d89b7f1b92e1dc --- /dev/null +++ b/timm/models/layers/space_to_depth.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + + +class SpaceToDepth(nn.Module): + def __init__(self, block_size=4): + super().__init__() + assert block_size == 4 + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) + return x + + +@torch.jit.script +class SpaceToDepthJit(object): + def __call__(self, x: torch.Tensor): + # assuming hard-coded that block_size==4 for acceleration + N, C, H, W = x.size() + x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) + return x + + +class SpaceToDepthModule(nn.Module): + def __init__(self, no_jit=False): + super().__init__() + if not no_jit: + self.op = SpaceToDepthJit() + else: + self.op = SpaceToDepth() + + def forward(self, x): + return self.op(x) + + +class DepthToSpace(nn.Module): + + def __init__(self, block_size): + super().__init__() + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) + x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) + return x diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..dde601befa933727e169d9b84b035cf1f035e67c --- /dev/null +++ b/timm/models/layers/split_attn.py @@ -0,0 +1,85 @@ +""" Split Attention Conv2d (for ResNeSt Models) + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, performance, and consistency with timm by Ross Wightman +""" +import torch +import torch.nn.functional as F +from torch import nn + +from .helpers import make_divisible + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttn(nn.Module): + """Split-Attention (aka Splat) + """ + def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, + dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): + super(SplitAttn, self).__init__() + out_channels = out_channels or in_channels + self.radix = radix + self.drop_block = drop_block + mid_chs = out_channels * radix + if rd_channels is None: + attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) + else: + attn_chs = rd_channels * radix + + padding = kernel_size // 2 if padding is None else padding + self.conv = nn.Conv2d( + in_channels, mid_chs, kernel_size, stride, padding, dilation, + groups=groups * radix, bias=bias, **kwargs) + self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) + self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) + self.rsoftmax = RadixSoftmax(radix, groups) + + def forward(self, x): + x = self.conv(x) + x = self.bn0(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = x.sum(dim=1) + else: + x_gap = x + x_gap = x_gap.mean((2, 3), keepdim=True) + x_gap = self.fc1(x_gap) + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + x_attn = self.fc2(x_gap) + + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/timm/models/layers/split_batchnorm.py b/timm/models/layers/split_batchnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..830781b335161f8d6dd74c9458070bb1fa88a918 --- /dev/null +++ b/timm/models/layers/split_batchnorm.py @@ -0,0 +1,75 @@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. + +This allows easily removing the auxiliary BN layers after training to efficiently +achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, +'Disentangled Learning via An Auxiliary BN' + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +import torch.nn as nn + + +class SplitBatchNorm2d(torch.nn.BatchNorm2d): + + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, + track_running_stats=True, num_splits=2): + super().__init__(num_features, eps, momentum, affine, track_running_stats) + assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' + self.num_splits = num_splits + self.aux_bn = nn.ModuleList([ + nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) + + def forward(self, input: torch.Tensor): + if self.training: # aux BN only relevant while training + split_size = input.shape[0] // self.num_splits + assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" + split_input = input.split(split_size) + x = [super().forward(split_input[0])] + for i, a in enumerate(self.aux_bn): + x.append(a(split_input[i + 1])) + return torch.cat(x, dim=0) + else: + return super().forward(input) + + +def convert_splitbn_model(module, num_splits=2): + """ + Recursively traverse module and its children to replace all instances of + ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. + Args: + module (torch.nn.Module): input module + num_splits: number of separate batchnorm layers to split input across + Example:: + >>> # model is an instance of torch.nn.Module + >>> model = timm.models.convert_splitbn_model(model, num_splits=2) + """ + mod = module + if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): + return module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + mod = SplitBatchNorm2d( + module.num_features, module.eps, module.momentum, module.affine, + module.track_running_stats, num_splits=num_splits) + mod.running_mean = module.running_mean + mod.running_var = module.running_var + mod.num_batches_tracked = module.num_batches_tracked + if module.affine: + mod.weight.data = module.weight.data.clone().detach() + mod.bias.data = module.bias.data.clone().detach() + for aux in mod.aux_bn: + aux.running_mean = module.running_mean.clone() + aux.running_var = module.running_var.clone() + aux.num_batches_tracked = module.num_batches_tracked.clone() + if module.affine: + aux.weight.data = module.weight.data.clone().detach() + aux.bias.data = module.bias.data.clone().detach() + for name, child in module.named_children(): + mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) + del module + return mod diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py new file mode 100644 index 0000000000000000000000000000000000000000..e5da29ef166de27705cc160f729b6e3b45061c59 --- /dev/null +++ b/timm/models/layers/squeeze_excite.py @@ -0,0 +1,74 @@ +""" Squeeze-and-Excitation Channel Attention + +An SE implementation originally based on PyTorch SE-Net impl. +Has since evolved with additional functionality / configuration. + +Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 + +Also included is Effective Squeeze-Excitation (ESE). +Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class SEModule(nn.Module): + """ SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): + super(SEModule, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = create_act_layer(act_layer, inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +SqueezeExcite = SEModule # alias + + +class EffectiveSEModule(nn.Module): + """ 'Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): + super(EffectiveSEModule, self).__init__() + self.add_maxpool = add_maxpool + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc(x_se) + return x * self.gate(x_se) + + +EffectiveSqueezeExcite = EffectiveSEModule # alias diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccc16e1197a41440add454a40ed3146ed0b6211 --- /dev/null +++ b/timm/models/layers/std_conv.py @@ -0,0 +1,133 @@ +""" Convolution with Weight Standardization (StdConv and ScaledStdConv) + +StdConv: +@article{weightstandardization, + author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille}, + title = {Weight Standardization}, + journal = {arXiv preprint arXiv:1903.10520}, + year = {2019}, +} +Code: https://github.com/joe-siyuan-qiao/WeightStandardization + +ScaledStdConv: +Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Hacked together by / copyright Ross Wightman, 2021. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .padding import get_padding, get_padding_value, pad_same + + +class StdConv2d(nn.Conv2d): + """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=False, eps=1e-6): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class StdConv2dSame(nn.Conv2d): + """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model. + + Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` - + https://arxiv.org/abs/1903.10520v2 + """ + def __init__( + self, in_channel, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=False, eps=1e-6): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +class ScaledStdConv2d(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization. + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=None, + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + if padding is None: + padding = get_padding(kernel_size, stride, dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in) + self.eps = eps + + def forward(self, x): + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class ScaledStdConv2dSame(nn.Conv2d): + """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support + + Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` - + https://arxiv.org/abs/2101.08692 + + NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor. + """ + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding='SAME', + dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0): + padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation) + super().__init__( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, + groups=groups, bias=bias) + self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init)) + self.scale = gamma * self.weight[0].numel() ** -0.5 + self.same_pad = is_dynamic + self.eps = eps + + def forward(self, x): + if self.same_pad: + x = pad_same(x, self.kernel_size, self.stride, self.dilation) + weight = F.batch_norm( + self.weight.view(1, self.out_channels, -1), None, None, + weight=(self.gain * self.scale).view(-1), + training=True, momentum=0., eps=self.eps).reshape_as(self.weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) diff --git a/timm/models/layers/swin_attn.py b/timm/models/layers/swin_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..02131bbc4dec3f726a23da0444bec108f9c3903a --- /dev/null +++ b/timm/models/layers/swin_attn.py @@ -0,0 +1,182 @@ +""" Shifted Window Attn + +This is a WIP experiment to apply windowed attention from the Swin Transformer +to a stand-alone module for use as an attn block in conv nets. + +Based on original swin window code at https://github.com/microsoft/Swin-Transformer +Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf +""" +from typing import Optional + +import torch +import torch.nn as nn + +from .drop import DropPath +from .helpers import to_2tuple +from .weight_init import trunc_normal_ + + +def window_partition(x, win_size: int): + """ + Args: + x: (B, H, W, C) + win_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // win_size, win_size, W // win_size, win_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) + return windows + + +def window_reverse(windows, win_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + win_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / win_size / win_size)) + x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + win_size (int): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + """ + + def __init__( + self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8, + qkv_bias=True, attn_drop=0.): + + super().__init__() + self.dim_out = dim_out or dim + self.feat_size = to_2tuple(feat_size) + self.win_size = win_size + self.shift_size = shift_size or win_size // 2 + if min(self.feat_size) <= win_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.win_size = min(self.feat_size) + assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size" + self.num_heads = num_heads + head_dim = self.dim_out // num_heads + self.scale = head_dim ** -0.5 + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.feat_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = ( + slice(0, -self.win_size), + slice(-self.win_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.win_size) # num_win, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.win_size * self.win_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + self.register_buffer("attn_mask", attn_mask) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + # 2 * Wh - 1 * 2 * Ww - 1, nH + torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads)) + trunc_normal_(self.relative_position_bias_table, std=.02) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.win_size) + coords_w = torch.arange(self.win_size) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.win_size - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.win_size - 1 + relative_coords[:, :, 0] *= 2 * self.win_size - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.softmax = nn.Softmax(dim=-1) + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self, x): + B, C, H, W = x.shape + x = x.permute(0, 2, 3, 1) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + win_size_sq = self.win_size * self.win_size + x_windows = window_partition(shifted_x, self.win_size) # num_win * B, window_size, window_size, C + x_windows = x_windows.view(-1, win_size_sq, C) # num_win * B, window_size*window_size, C + BW, N, _ = x_windows.shape + + qkv = self.qkv(x_windows) + qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh * Ww, Wh * Ww + attn = attn + relative_position_bias.unsqueeze(0) + if self.attn_mask is not None: + num_win = self.attn_mask.shape[0] + attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(BW, N, self.dim_out) + + # merge windows + x = x.view(-1, self.win_size, self.win_size, self.dim_out) + shifted_x = window_reverse(x, self.win_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2) + x = self.pool(x) + return x + + diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..98c0bf53a74eb954a25b96d84712ef974eb8ea3b --- /dev/null +++ b/timm/models/layers/test_time_pool.py @@ -0,0 +1,52 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import logging +from torch import nn +import torch.nn.functional as F + +from .adaptive_avgmax_pool import adaptive_avgmax_pool2d + + +_logger = logging.getLogger(__name__) + + +class TestTimePoolHead(nn.Module): + def __init__(self, base, original_pool=7): + super(TestTimePoolHead, self).__init__() + self.base = base + self.original_pool = original_pool + base_fc = self.base.get_classifier() + if isinstance(base_fc, nn.Conv2d): + self.fc = base_fc + else: + self.fc = nn.Conv2d( + self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) + self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) + self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) + self.base.reset_classifier(0) # delete original fc layer + + def forward(self, x): + x = self.base.forward_features(x) + x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) + x = self.fc(x) + x = adaptive_avgmax_pool2d(x, 1) + return x.view(x.size(0), -1) + + +def apply_test_time_pool(model, config, use_test_size=True): + test_time_pool = False + if not hasattr(model, 'default_cfg') or not model.default_cfg: + return model, False + if use_test_size and 'test_input_size' in model.default_cfg: + df_input_size = model.default_cfg['test_input_size'] + else: + df_input_size = model.default_cfg['input_size'] + if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: + _logger.info('Target input size %s > pretrained default %s, using test time pooling' % + (str(config['input_size'][-2:]), str(df_input_size[-2:]))) + model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) + test_time_pool = True + return model, test_time_pool diff --git a/timm/models/layers/trace_utils.py b/timm/models/layers/trace_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..83970729e628b525d24162f5df37ee5bc253438f --- /dev/null +++ b/timm/models/layers/trace_utils.py @@ -0,0 +1,13 @@ +try: + from torch import _assert +except ImportError: + def _assert(condition: bool, message: str): + assert condition, message + + +def _float_to_int(x: float) -> int: + """ + Symbolic tracing helper to substitute for inbuilt `int`. + Hint: Inbuilt `int` can't accept an argument of type `Proxy` + """ + return int(x) diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..305a2fd067e7104e58b9b5ff70d96e89a06050af --- /dev/null +++ b/timm/models/layers/weight_init.py @@ -0,0 +1,89 @@ +import torch +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') diff --git a/timm/models/levit.py b/timm/models/levit.py new file mode 100644 index 0000000000000000000000000000000000000000..9987e4ba987ea66d99b9627f59e55526f9ed8655 --- /dev/null +++ b/timm/models/levit.py @@ -0,0 +1,563 @@ +""" LeViT + +Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference` + - https://arxiv.org/abs/2104.01136 + +@article{graham2021levit, + title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference}, + author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze}, + journal={arXiv preprint arXiv:22104.01136}, + year={2021} +} + +Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow. + +This version combines both conv/linear models and fixes torchscript compatibility. + +Modifications by/coyright Copyright 2021 Ross Wightman +""" + +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools +from copy import deepcopy +from functools import partial +from typing import Dict + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_ntuple, get_act_layer +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'), + **kwargs + } + + +default_cfgs = dict( + levit_128s=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' + ), + levit_128=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' + ), + levit_192=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' + ), + levit_256=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' + ), + levit_384=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' + ), +) + +model_cfgs = dict( + levit_128s=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), + levit_128=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)), + levit_192=dict( + embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)), + levit_256=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), + levit_384=dict( + embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), +) + +__all__ = ['Levit'] + + +@register_model +def levit_128s(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_128(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_192(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_256(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +@register_model +def levit_384(pretrained=False, use_conv=False, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs) + + +class ConvNorm(nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = nn.BatchNorm2d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.Conv2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class LinearNorm(nn.Sequential): + def __init__(self, a, b, bn_weight_init=1, resolution=-100000): + super().__init__() + self.add_module('c', nn.Linear(a, b, bias=False)) + bn = nn.BatchNorm1d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + l, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[:, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + def forward(self, x): + x = self.c(x) + return self.bn(x.flatten(0, 1)).reshape_as(x) + + +class NormLinear(nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', nn.BatchNorm1d(a)) + l = nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + nn.init.constant_(l.bias, 0) + self.add_module('l', l) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b16(in_chs, out_chs, activation, resolution=224): + return nn.Sequential( + ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Subsample(nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class Attention(nn.Module): + ab: Dict[str, torch.Tensor] + + def __init__( + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): + super().__init__() + + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm + h = self.dh + nh_kd * 2 + self.qkv = ln_layer(dim, h, resolution=resolution) + self.proj = nn.Sequential( + act_layer(), + ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = {} + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] + + def forward(self, x): # x (B,C,H,W) + if self.use_conv: + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + else: + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class AttentionSubsample(nn.Module): + ab: Dict[str, torch.Tensor] + + def __init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = self.d * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + self.use_conv = use_conv + if self.use_conv: + ln_layer = ConvNorm + sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0) + else: + ln_layer = LinearNorm + sub_layer = partial(Subsample, resolution=resolution) + + h = self.dh + nh_kd + self.kv = ln_layer(in_dim, h, resolution=resolution) + self.q = nn.Sequential( + sub_layer(stride=stride), + ln_layer(in_dim, nh_kd, resolution=resolution_)) + self.proj = nn.Sequential( + act_layer(), + ln_layer(self.dh, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + self.ab = {} # per-device attention_biases cache + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] + + def forward(self, x): + if self.use_conv: + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + else: + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +class Levit(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + + NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems + w/ train scripts that don't take tuple outputs, + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=(192,), + key_dim=64, + depth=(12,), + num_heads=(3,), + attn_ratio=2, + mlp_ratio=2, + hybrid_backbone=None, + down_ops=None, + act_layer='hard_swish', + attn_act_layer='hard_swish', + distillation=True, + use_conv=False, + drop_rate=0., + drop_path_rate=0.): + super().__init__() + act_layer = get_act_layer(act_layer) + attn_act_layer = get_act_layer(attn_act_layer) + if isinstance(img_size, tuple): + # FIXME origin impl passes single img/res dim through whole hierarchy, + # not sure this model will be used enough to spend time fixing it. + assert img_size[0] == img_size[1] + img_size = img_size[0] + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + N = len(embed_dim) + assert len(depth) == len(num_heads) == N + key_dim = to_ntuple(N)(key_dim) + attn_ratio = to_ntuple(N)(attn_ratio) + mlp_ratio = to_ntuple(N)(mlp_ratio) + down_ops = down_ops or ( + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), + ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), + ('',) + ) + self.distillation = distillation + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm + + self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) + + self.blocks = [] + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention( + ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, + resolution=resolution, use_conv=use_conv), + drop_path_rate)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(nn.Sequential( + ln_layer(ed, h, resolution=resolution), + act_layer(), + ln_layer(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path_rate)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_, use_conv=use_conv)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(nn.Sequential( + ln_layer(embed_dim[i + 1], h, resolution=resolution), + act_layer(), + ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path_rate)) + self.blocks = nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distillation: + self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def get_classifier(self): + if self.head_dist is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool='', distillation=None): + self.num_classes = num_classes + self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + if distillation is not None: + self.distillation = distillation + if self.distillation: + self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + else: + self.head_dist = None + + def forward_features(self, x): + x = self.patch_embed(x) + if not self.use_conv: + x = x.flatten(2).transpose(1, 2) + x = self.blocks(x) + x = x.mean((-2, -1)) if self.use_conv else x.mean(1) + return x + + def forward(self, x): + x = self.forward_features(x) + if self.head_dist is not None: + x, x_dist = self.head(x), self.head_dist(x) + if self.training and not torch.jit.is_scripting(): + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 + else: + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + D = model.state_dict() + for k in state_dict.keys(): + if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2: + state_dict[k] = state_dict[k][:, :, None, None] + return state_dict + + +def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model_cfg = dict(**model_cfgs[variant], **kwargs) + model = build_model_with_cfg( + Levit, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **model_cfg) + #if fuse: + # utils.replace_batchnorm(model) + return model + diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py new file mode 100644 index 0000000000000000000000000000000000000000..f128b9c916a68839a11fbd3ed6409743ac7a6a4c --- /dev/null +++ b/timm/models/mlp_mixer.py @@ -0,0 +1,625 @@ +""" MLP-Mixer, ResMLP, and gMLP in PyTorch + +This impl originally based on MLP-Mixer paper. + +Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py + +Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + +@article{tolstikhin2021, + title={MLP-Mixer: An all-MLP Architecture for Vision}, + author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, + Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey}, + journal={arXiv preprint arXiv:2105.01601}, + year={2021} +} + +Also supporting ResMlp, and a preliminary (not verified) implementations of gMLP + +Code: https://github.com/facebookresearch/deit +Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 +@misc{touvron2021resmlp, + title={ResMLP: Feedforward networks for image classification with data-efficient training}, + author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and + Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and HervĂ© JĂ©gou}, + year={2021}, + eprint={2105.03404}, +} + +Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 +@misc{liu2021pay, + title={Pay Attention to MLPs}, + author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le}, + year={2021}, + eprint={2105.08050}, +} + +A thank you to paper authors for releasing code and weights. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply +from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'stem.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + mixer_s32_224=_cfg(), + mixer_s16_224=_cfg(), + mixer_b32_224=_cfg(), + mixer_b16_224=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth', + ), + mixer_b16_224_in21k=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth', + num_classes=21843 + ), + mixer_l32_224=_cfg(), + mixer_l16_224=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth', + ), + mixer_l16_224_in21k=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', + num_classes=21843 + ), + + # Mixer ImageNet-21K-P pretraining + mixer_b16_224_miil_in21k=_cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + mixer_b16_224_miil=_cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), + + gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + gmixer_24_224=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_12_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth', + #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_36_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_big_24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_12_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_24_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_36_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_big_24_distilled_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_big_24_224_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + gmlp_ti16_224=_cfg(), + gmlp_s16_224=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth', + ), + gmlp_b16_224=_cfg(), +) + + +class MixerBlock(nn.Module): + """ Residual Block w/ token mixing and channel MLPs + Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + def __init__( + self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): + super().__init__() + tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)] + self.norm1 = norm_layer(dim) + self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) + x = x + self.drop_path(self.mlp_channels(self.norm2(x))) + return x + + +class Affine(nn.Module): + def __init__(self, dim): + super().__init__() + self.alpha = nn.Parameter(torch.ones((1, 1, dim))) + self.beta = nn.Parameter(torch.zeros((1, 1, dim))) + + def forward(self, x): + return torch.addcmul(self.beta, self.alpha, x) + + +class ResBlock(nn.Module): + """ Residual MLP block w/ LayerScale and Affine 'norm' + + Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + def __init__( + self, dim, seq_len, mlp_ratio=4, mlp_layer=Mlp, norm_layer=Affine, + act_layer=nn.GELU, init_values=1e-4, drop=0., drop_path=0.): + super().__init__() + channel_dim = int(dim * mlp_ratio) + self.norm1 = norm_layer(dim) + self.linear_tokens = nn.Linear(seq_len, seq_len) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop) + self.ls1 = nn.Parameter(init_values * torch.ones(dim)) + self.ls2 = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) + x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x))) + return x + + +class SpatialGatingUnit(nn.Module): + """ Spatial Gating Unit + + Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm): + super().__init__() + gate_dim = dim // 2 + self.norm = norm_layer(gate_dim) + self.proj = nn.Linear(seq_len, seq_len) + + def init_weights(self): + # special init for the projection gate, called as override by base model init + nn.init.normal_(self.proj.weight, std=1e-6) + nn.init.ones_(self.proj.bias) + + def forward(self, x): + u, v = x.chunk(2, dim=-1) + v = self.norm(v) + v = self.proj(v.transpose(-1, -2)) + return u * v.transpose(-1, -2) + + +class SpatialGatingBlock(nn.Module): + """ Residual Block w/ Spatial Gating + + Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + def __init__( + self, dim, seq_len, mlp_ratio=4, mlp_layer=GatedMlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): + super().__init__() + channel_dim = int(dim * mlp_ratio) + self.norm = norm_layer(dim) + sgu = partial(SpatialGatingUnit, seq_len=seq_len) + self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.mlp_channels(self.norm(x))) + return x + + +class MlpMixer(nn.Module): + + def __init__( + self, + num_classes=1000, + img_size=224, + in_chans=3, + patch_size=16, + num_blocks=8, + embed_dim=512, + mlp_ratio=(0.5, 4.0), + block_layer=MixerBlock, + mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + drop_rate=0., + drop_path_rate=0., + nlhb=False, + stem_norm=False, + ): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.stem = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None) + # FIXME drop_path (stochastic depth scaling rule or all the same?) + self.blocks = nn.Sequential(*[ + block_layer( + embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, + act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) + for _ in range(num_blocks)]) + self.norm = norm_layer(embed_dim) + self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(nlhb=nlhb) + + def init_weights(self, nlhb=False): + head_bias = -math.log(self.num_classes) if nlhb else 0. + named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.norm(x) + x = x.mean(dim=1) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): + """ Mixer weight initialization (trying to match Flax defaults) + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + if flax: + # Flax defaults + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # like MLP init in vit (my original init) + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + # NOTE if a parent module contains init_weights method, it can override the init of the + # child modules as this will be called in depth-first order. + module.init_weights() + + +def checkpoint_filter_fn(state_dict, model): + """ Remap checkpoints if needed """ + if 'patch_embed.proj.weight' in state_dict: + # Remap FB ResMlp models -> timm + out_dict = {} + for k, v in state_dict.items(): + k = k.replace('patch_embed.', 'stem.') + k = k.replace('attn.', 'linear_tokens.') + k = k.replace('mlp.', 'mlp_channels.') + k = k.replace('gamma_', 'ls') + if k.endswith('.alpha') or k.endswith('.beta'): + v = v.reshape(1, 1, -1) + out_dict[k] = v + return out_dict + return state_dict + + +def _create_mixer(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for MLP-Mixer models.') + + model = build_model_with_cfg( + MlpMixer, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def mixer_s32_224(pretrained=False, **kwargs): + """ Mixer-S/32 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs) + model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_s16_224(pretrained=False, **kwargs): + """ Mixer-S/16 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs) + model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b32_224(pretrained=False, **kwargs): + """ Mixer-B/32 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b16_224(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b16_224_in21k(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_l32_224(pretrained=False, **kwargs): + """ Mixer-L/32 224x224. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs) + model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_l16_224(pretrained=False, **kwargs): + """ Mixer-L/16 224x224. ImageNet-1k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs) + model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_l16_224_in21k(pretrained=False, **kwargs): + """ Mixer-L/16 224x224. ImageNet-21k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ + model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs) + model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b16_224_miil(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args) + return model + + +@register_model +def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): + """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) + model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmixer_12_224(pretrained=False, **kwargs): + """ Glu-Mixer-12 224x224 + Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer + """ + model_args = dict( + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0), + mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) + model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmixer_24_224(pretrained=False, **kwargs): + """ Glu-Mixer-24 224x224 + Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer + """ + model_args = dict( + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0), + mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) + model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_12_224(pretrained=False, **kwargs): + """ ResMLP-12 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_24_224(pretrained=False, **kwargs): + """ ResMLP-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_36_224(pretrained=False, **kwargs): + """ ResMLP-36 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_big_24_224(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_12_distilled_224(pretrained=False, **kwargs): + """ ResMLP-12 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_24_distilled_224(pretrained=False, **kwargs): + """ ResMLP-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_36_distilled_224(pretrained=False, **kwargs): + """ ResMLP-36 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_big_24_distilled_224(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs): + """ ResMLP-B-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4, + block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_ti16_224(pretrained=False, **kwargs): + """ gMLP-Tiny + Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_s16_224(pretrained=False, **kwargs): + """ gMLP-Small + Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_b16_224(pretrained=False, **kwargs): + """ gMLP-Base + Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args) + return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..f810eb8281510b3c3445ce27809cee613626aff6 --- /dev/null +++ b/timm/models/mobilenetv3.py @@ -0,0 +1,562 @@ + +""" MobileNet V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2021 Ross Wightman +""" +from functools import partial +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT +from .features import FeatureInfo, FeatureHooks +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid +from .registry import register_model + +__all__ = ['MobileNetV3', 'MobileNetV3Features'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv_stem', 'classifier': 'classifier', + **kwargs + } + + +default_cfgs = { + 'mobilenetv3_large_075': _cfg(url=''), + 'mobilenetv3_large_100': _cfg( + interpolation='bicubic', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'), + 'mobilenetv3_large_100_miil': _cfg( + interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1), + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_1k_miil_78_0.pth'), + 'mobilenetv3_large_100_miil_in21k': _cfg( + interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1), + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_in21k_miil.pth', num_classes=11221), + 'mobilenetv3_small_075': _cfg(url=''), + 'mobilenetv3_small_100': _cfg(url=''), + + 'mobilenetv3_rw': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + interpolation='bicubic'), + + 'tf_mobilenetv3_large_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_large_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_075': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_100': _cfg( + url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + 'tf_mobilenetv3_small_minimal_100': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), + + 'fbnetv3_b': _cfg(), + 'fbnetv3_d': _cfg(), + 'fbnetv3_g': _cfg(), +} + + +class MobileNetV3(nn.Module): + """ MobiletNet-V3 + + Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific + 'efficient head', where global pooling is done before the head convolution without a final batch-norm + layer before the classifier. + + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True, + round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'): + super(MobileNetV3, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.num_classes = num_classes + self.num_features = num_features + self.drop_rate = drop_rate + + # Stem + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = builder.features + head_chs = builder.in_chs + + # Head + Pooling + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + num_pooled_chs = head_chs * self.global_pool.feat_mult() + self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + efficientnet_init_weights(self) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([self.global_pool, self.conv_head, self.act2]) + layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.classifier + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + # cannot meaningfully change pooling of efficient head after creation + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.flatten(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +class MobileNetV3Features(nn.Module): + """ MobileNetV3 Feature Extractor + + A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation + and object detection models. + """ + + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, + stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True, + act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): + super(MobileNetV3Features, self).__init__() + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite + self.drop_rate = drop_rate + + # Stem + stem_size = round_chs_fn(stem_size) + self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size) + self.act1 = act_layer(inplace=True) + + # Middle stages (IR/ER/DS Blocks) + builder = EfficientNetBuilder( + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, + drop_path_rate=drop_path_rate, feature_location=feature_location) + self.blocks = nn.Sequential(*builder(stem_size, block_args)) + self.feature_info = FeatureInfo(builder.features, out_indices) + self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} + + efficientnet_init_weights(self) + + # Register feature extraction hooks with FeatureHooks helper + self.feature_hooks = None + if feature_location != 'bottleneck': + hooks = self.feature_info.get_dicts(keys=('module', 'hook_type')) + self.feature_hooks = FeatureHooks(hooks, self.named_modules()) + + def forward(self, x) -> List[torch.Tensor]: + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + if self.feature_hooks is None: + features = [] + if 0 in self._stage_out_idx: + features.append(x) # add stem out + for i, b in enumerate(self.blocks): + x = b(x) + if i + 1 in self._stage_out_idx: + features.append(x) + return features + else: + self.blocks(x) + out = self.feature_hooks.get_output(x.device) + return list(out.values()) + + +def _create_mnv3(variant, pretrained=False, **kwargs): + features_only = False + model_cls = MobileNetV3 + kwargs_filter = None + if kwargs.pop('features_only', False): + features_only = True + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') + model_cls = MobileNetV3Features + model = build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'), + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model. + + Ref impl: ? + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = resolve_act_layer(kwargs, 'hard_swish') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = resolve_act_layer(kwargs, 'relu') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = resolve_act_layer(kwargs, 'hard_swish') + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels) + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=act_layer, + se_layer=se_layer, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNetV3 + Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining` + - https://arxiv.org/abs/2006.02049 + FIXME untested, this is a preliminary impl of some FBNet-V3 variants. + """ + vl = variant.split('_')[-1] + if vl in ('a', 'b'): + stem_size = 16 + arch_def = [ + ['ds_r2_k3_s1_e1_c16'], + ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'], + ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'], + ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], + ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'], + ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'], + ['cn_r1_k1_s1_c1344'], + ] + elif vl == 'd': + stem_size = 24 + arch_def = [ + ['ds_r2_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'], + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'], + ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'], + ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'], + ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'], + ['cn_r1_k1_s1_c1440'], + ] + elif vl == 'g': + stem_size = 32 + arch_def = [ + ['ds_r3_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'], + ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'], + ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'], + ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'], + ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'], + ['cn_r1_k1_s1_c1728'], + ] + else: + raise NotImplemented + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95) + se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn) + act_layer = resolve_act_layer(kwargs, 'hard_swish') + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1984, + head_bias=False, + stem_size=stem_size, + round_chs_fn=round_chs_fn, + se_from_exp=False, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=act_layer, + se_layer=se_layer, + **kwargs, + ) + model = _create_mnv3(variant, pretrained, **model_kwargs) + return model + + +@register_model +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100_miil(pretrained=False, **kwargs): + """ MobileNet V3 + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs): + """ MobileNet V3, 21k pretraining + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet V3 """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_b(pretrained=False, **kwargs): + """ FBNetV3-B """ + model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_d(pretrained=False, **kwargs): + """ FBNetV3-D """ + model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs) + return model + + +@register_model +def fbnetv3_g(pretrained=False, **kwargs): + """ FBNetV3-G """ + model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs) + return model diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2afe82c3f374dd4790bc940289c8e3794497fbbc --- /dev/null +++ b/timm/models/nasnet.py @@ -0,0 +1,567 @@ +""" NasNet-A (Large) + nasnetalarge implementation grabbed from Cadene's pretrained models + https://github.com/Cadene/pretrained-models.pytorch +""" +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier +from .registry import register_model + +__all__ = ['NASNetALarge'] + +default_cfgs = { + 'nasnetalarge': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'crop_pct': 0.911, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv0.conv', + 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights + }, +} + + +class ActConvBn(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): + super(ActConvBn, self).__init__() + self.act = nn.ReLU() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) + + def forward(self, x): + x = self.act(x) + x = self.conv(x) + x = self.bn(x) + return x + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = create_conv2d( + in_channels, in_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + self.pointwise_conv2d = create_conv2d( + in_channels, out_channels, kernel_size=1, padding=0) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False): + super(BranchSeparables, self).__init__() + middle_channels = out_channels if stem_cell else in_channels + self.act_1 = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1) + self.act_2 = nn.ReLU(inplace=True) + self.separable_2 = SeparableConv2d( + middle_channels, out_channels, kernel_size, stride=1, padding=pad_type) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1) + + def forward(self, x): + x = self.act_1(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.act_2(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class CellStem0(nn.Module): + def __init__(self, stem_size, num_channels=42, pad_type=''): + super(CellStem0, self).__init__() + self.num_channels = num_channels + self.stem_size = stem_size + self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1) + + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x): + x1 = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x1) + x_comb_iter_0_right = self.comb_iter_0_right(x) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x1) + x_comb_iter_1_right = self.comb_iter_1_right(x) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x1) + x_comb_iter_2_right = self.comb_iter_2_right(x) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x1) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class CellStem1(nn.Module): + + def __init__(self, stem_size, num_channels, pad_type=''): + super(CellStem1, self).__init__() + self.num_channels = num_channels + self.stem_size = stem_size + self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1) + + self.act = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) + + self.path_2 = nn.Sequential() + self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1) + + self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x_conv0, x_stem_0): + x_left = self.conv_1x1(x_stem_0) + + x_relu = self.act(x_conv0) + # path 1 + x_path1 = self.path_1(x_relu) + # path 2 + x_path2 = self.path_2(x_relu) + # final path + x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_right) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_left) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_left) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class FirstCell(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(FirstCell, self).__init__() + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1) + + self.act = nn.ReLU() + self.path_1 = nn.Sequential() + self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + + self.path_2 = nn.Sequential() + self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1))) + self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)) + self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False)) + + self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + def forward(self, x, x_prev): + x_relu = self.act(x_prev) + x_path1 = self.path_1(x_relu) + x_path2 = self.path_2(x_relu) + x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NormalCell(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(NormalCell, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) + + self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_left) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_left + + x_comb_iter_3_left = self.comb_iter_3_left(x_left) + x_comb_iter_3_right = self.comb_iter_3_right(x_left) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_right + + x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell0(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(ReductionCell0, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class ReductionCell1(nn.Module): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(ReductionCell1, self).__init__() + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type) + self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type) + + self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type) + self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type) + + self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type) + self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type) + + def forward(self, x, x_prev): + x_left = self.conv_prev_1x1(x_prev) + x_right = self.conv_1x1(x) + + x_comb_iter_0_left = self.comb_iter_0_left(x_right) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_left) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_left) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0) + x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1 + + x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0) + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class NASNetALarge(nn.Module): + """NASNetALarge (6 @ 4032) """ + + def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2, + num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'): + super(NASNetALarge, self).__init__() + self.num_classes = num_classes + self.stem_size = stem_size + self.num_features = num_features + self.channel_multiplier = channel_multiplier + self.drop_rate = drop_rate + assert output_stride == 32 + + channels = self.num_features // 24 + # 24 is default value for the architecture + + self.conv0 = ConvBnAct( + in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) + + self.cell_stem_0 = CellStem0( + self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type) + self.cell_stem_1 = CellStem1( + self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type) + + self.cell_0 = FirstCell( + in_chs_left=channels, out_chs_left=channels // 2, + in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_1 = NormalCell( + in_chs_left=2 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_2 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_3 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_4 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + self.cell_5 = NormalCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type) + + self.reduction_cell_0 = ReductionCell0( + in_chs_left=6 * channels, out_chs_left=2 * channels, + in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_6 = FirstCell( + in_chs_left=6 * channels, out_chs_left=channels, + in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_7 = NormalCell( + in_chs_left=8 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_8 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_9 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_10 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + self.cell_11 = NormalCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type) + + self.reduction_cell_1 = ReductionCell1( + in_chs_left=12 * channels, out_chs_left=4 * channels, + in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_12 = FirstCell( + in_chs_left=12 * channels, out_chs_left=2 * channels, + in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_13 = NormalCell( + in_chs_left=16 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_14 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_15 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_16 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.cell_17 = NormalCell( + in_chs_left=24 * channels, out_chs_left=4 * channels, + in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type) + self.act = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=96, reduction=2, module='conv0'), + dict(num_chs=168, reduction=4, module='cell_stem_1.conv_1x1.act'), + dict(num_chs=1008, reduction=8, module='reduction_cell_0.conv_1x1.act'), + dict(num_chs=2016, reduction=16, module='reduction_cell_1.conv_1x1.act'), + dict(num_chs=4032, reduction=32, module='act'), + ] + + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x_conv0 = self.conv0(x) + + x_stem_0 = self.cell_stem_0(x_conv0) + x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0) + + x_cell_0 = self.cell_0(x_stem_1, x_stem_0) + x_cell_1 = self.cell_1(x_cell_0, x_stem_1) + x_cell_2 = self.cell_2(x_cell_1, x_cell_0) + x_cell_3 = self.cell_3(x_cell_2, x_cell_1) + x_cell_4 = self.cell_4(x_cell_3, x_cell_2) + x_cell_5 = self.cell_5(x_cell_4, x_cell_3) + + x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4) + x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4) + x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0) + x_cell_8 = self.cell_8(x_cell_7, x_cell_6) + x_cell_9 = self.cell_9(x_cell_8, x_cell_7) + x_cell_10 = self.cell_10(x_cell_9, x_cell_8) + x_cell_11 = self.cell_11(x_cell_10, x_cell_9) + + x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10) + x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10) + x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1) + x_cell_14 = self.cell_14(x_cell_13, x_cell_12) + x_cell_15 = self.cell_15(x_cell_14, x_cell_13) + x_cell_16 = self.cell_16(x_cell_15, x_cell_14) + x_cell_17 = self.cell_17(x_cell_16, x_cell_15) + x = self.act(x_cell_17) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + +def _create_nasnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + NASNetALarge, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model + **kwargs) + + +@register_model +def nasnetalarge(pretrained=False, **kwargs): + """NASNet-A large model architecture. + """ + model_kwargs = dict(pad_type='same', **kwargs) + return _create_nasnet('nasnetalarge', pretrained, **model_kwargs) diff --git a/timm/models/nest.py b/timm/models/nest.py new file mode 100644 index 0000000000000000000000000000000000000000..22cf609916393761380575da8844b6780a6adf26 --- /dev/null +++ b/timm/models/nest.py @@ -0,0 +1,465 @@ +""" Nested Transformer (NesT) in PyTorch + +A PyTorch implement of Aggregating Nested Transformers as described in: + +'Aggregating Nested Transformers' + - https://arxiv.org/abs/2105.12723 + +The official Jax code is released and available at https://github.com/google-research/nested-transformer. The weights +have been converted with convert/convert_nest_flax.py + +Acknowledgments: +* The paper authors for sharing their research, code, and model weights +* Ross Wightman's existing code off which I based this + +Copyright 2021 Alexander Soare +""" + +import collections.abc +import logging +import math +from functools import partial + +import torch +import torch.nn.functional as F +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg, named_apply +from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_ +from .layers import _assert +from .layers import create_conv2d, create_pool2d, to_ntuple +from .registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': [14, 14], + 'crop_pct': .875, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # (weights from official Google JAX impl) + 'nest_base': _cfg(), + 'nest_small': _cfg(), + 'nest_tiny': _cfg(), + 'jx_nest_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth'), + 'jx_nest_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth'), + 'jx_nest_tiny': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth'), +} + + +class Attention(nn.Module): + """ + This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with + an extra "image block" dim + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + """ + x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim) + """ + B, T, N, C = x.shape + # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head) + qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # (B, H, T, N, C'), permute -> (B, T, N, C', H) + x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x # (B, T, N, C) + + +class TransformerLayer(nn.Module): + """ + This is much like `.vision_transformer.Block` but: + - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks") + - Uses modified Attention layer that handles the "block" dimension + """ + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + y = self.norm1(x) + x = x + self.drop_path(self.attn(y)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class ConvPool(nn.Module): + def __init__(self, in_channels, out_channels, norm_layer, pad_type=''): + super().__init__() + self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True) + self.norm = norm_layer(out_channels) + self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=pad_type) + + def forward(self, x): + """ + x is expected to have shape (B, C, H, W) + """ + _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims') + _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims') + x = self.conv(x) + # Layer norm done over channel dim only + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = self.pool(x) + return x # (B, C, H//2, W//2) + + +def blockify(x, block_size: int): + """image to blocks + Args: + x (Tensor): with shape (B, H, W, C) + block_size (int): edge length of a single square block in units of H, W + """ + B, H, W, C = x.shape + _assert(H % block_size == 0, '`block_size` must divide input height evenly') + _assert(W % block_size == 0, '`block_size` must divide input width evenly') + grid_height = H // block_size + grid_width = W // block_size + x = x.reshape(B, grid_height, block_size, grid_width, block_size, C) + x = x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C) + return x # (B, T, N, C) + + +@register_notrace_function # reason: int receives Proxy +def deblockify(x, block_size: int): + """blocks to image + Args: + x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block + block_size (int): edge length of a single square block in units of desired H, W + """ + B, T, _, C = x.shape + grid_size = int(math.sqrt(T)) + height = width = grid_size * block_size + x = x.reshape(B, grid_size, grid_size, block_size, block_size, C) + x = x.transpose(2, 3).reshape(B, height, width, C) + return x # (B, H, W, C) + + +class NestLevel(nn.Module): + """ Single hierarchical level of a Nested Transformer + """ + def __init__( + self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None, + mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[], + norm_layer=None, act_layer=None, pad_type=''): + super().__init__() + self.block_size = block_size + self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim)) + + if prev_embed_dim is not None: + self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type) + else: + self.pool = nn.Identity() + + # Transformer encoder + if len(drop_path_rates): + assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers' + self.transformer_encoder = nn.Sequential(*[ + TransformerLayer( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i], + norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + + def forward(self, x): + """ + expects x as (B, C, H, W) + """ + x = self.pool(x) + x = x.permute(0, 2, 3, 1) # (B, H', W', C), switch to channels last for transformer + x = blockify(x, self.block_size) # (B, T, N, C') + x = x + self.pos_embed + x = self.transformer_encoder(x) # (B, T, N, C') + x = deblockify(x, self.block_size) # (B, H', W', C') + # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage + return x.permute(0, 3, 1, 2) # (B, C, H', W') + + +class Nest(nn.Module): + """ Nested Transformer (NesT) + + A PyTorch impl of : `Aggregating Nested Transformers` + - https://arxiv.org/abs/2105.12723 + """ + + def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512), + num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None, + pad_type='', weight_init='', global_pool='avg'): + """ + Args: + img_size (int, tuple): input image size + in_chans (int): number of input channels + patch_size (int): patch size + num_levels (int): number of block hierarchies (T_d in the paper) + embed_dims (int, tuple): embedding dimensions of each level + num_heads (int, tuple): number of attention heads for each level + depths (int, tuple): number of transformer layers for each level + num_classes (int): number of classes for classification head + mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer for transformer layers + act_layer: (nn.Module): activation layer in MLP of transformer layers + pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME + weight_init: (str): weight init scheme + global_pool: (str): type of pooling operation to apply to final feature map + + Notes: + - Default values follow NesT-B from the original Jax code. + - `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`. + - For those following the paper, Table A1 may have errors! + - https://github.com/google-research/nested-transformer/issues/2 + """ + super().__init__() + + for param_name in ['embed_dims', 'num_heads', 'depths']: + param_value = locals()[param_name] + if isinstance(param_value, collections.abc.Sequence): + assert len(param_value) == num_levels, f'Require `len({param_name}) == num_levels`' + + embed_dims = to_ntuple(num_levels)(embed_dims) + num_heads = to_ntuple(num_levels)(num_heads) + depths = to_ntuple(num_levels)(depths) + self.num_classes = num_classes + self.num_features = embed_dims[-1] + self.feature_info = [] + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.drop_rate = drop_rate + self.num_levels = num_levels + if isinstance(img_size, collections.abc.Sequence): + assert img_size[0] == img_size[1], 'Model only handles square inputs' + img_size = img_size[0] + assert img_size % patch_size == 0, '`patch_size` must divide `img_size` evenly' + self.patch_size = patch_size + + # Number of blocks at each level + self.num_blocks = (4 ** torch.arange(num_levels)).flip(0).tolist() + assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \ + 'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`' + + # Block edge size in units of patches + # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the + # number of blocks along edge of image + self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0])) + + # Patch embedding + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False) + self.num_patches = self.patch_embed.num_patches + self.seq_length = self.num_patches // self.num_blocks[0] + + # Build up each hierarchical level + levels = [] + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + prev_dim = None + curr_stride = 4 + for i in range(len(self.num_blocks)): + dim = embed_dims[i] + levels.append(NestLevel( + self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim, + mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type)) + self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')] + prev_dim = dim + curr_stride *= 2 + self.levels = nn.Sequential(*levels) + + # Final normalization layer + self.norm = norm_layer(embed_dims[-1]) + + # Classifier + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + for level in self.levels: + trunc_normal_(level.pos_embed, std=.02, a=-2, b=2) + named_apply(partial(_init_nest_weights, head_bias=head_bias), self) + + @torch.jit.ignore + def no_weight_decay(self): + return {f'level.{i}.pos_embed' for i in range(len(self.levels))} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.head = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + """ x shape (B, C, H, W) + """ + x = self.patch_embed(x) + x = self.levels(x) + # Layer norm done over channel dim only (to NHWC and back) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + def forward(self, x): + """ x shape (B, C, H, W) + """ + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.head(x) + + +def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.): + """ NesT weight initialization + Can replicate Jax implementation. Otherwise follows vision_transformer.py + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + trunc_normal_(module.weight, std=.02, a=-2, b=2) + nn.init.constant_(module.bias, head_bias) + else: + trunc_normal_(module.weight, std=.02, a=-2, b=2) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=.02, a=-2, b=2) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +def resize_pos_embed(posemb, posemb_new): + """ + Rescale the grid of position embeddings when loading from state_dict + Expected shape of position embeddings is (1, T, N, C), and considers only square images + """ + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + seq_length_old = posemb.shape[2] + num_blocks_new, seq_length_new = posemb_new.shape[1:3] + size_new = int(math.sqrt(num_blocks_new*seq_length_new)) + # First change to (1, C, H, W) + posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2) + posemb = F.interpolate(posemb, size=[size_new, size_new], mode='bicubic', align_corners=False) + # Now change to new (1, T, N, C) + posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new))) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ resize positional embeddings of pretrained weights """ + pos_embed_keys = [k for k in state_dict.keys() if k.startswith('pos_embed_')] + for k in pos_embed_keys: + if state_dict[k].shape != getattr(model, k).shape: + state_dict[k] = resize_pos_embed(state_dict[k], getattr(model, k)) + return state_dict + + +def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs): + default_cfg = default_cfg or default_cfgs[variant] + model = build_model_with_cfg( + Nest, variant, pretrained, + default_cfg=default_cfg, + feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True), + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + +@register_model +def nest_base(pretrained=False, **kwargs): + """ Nest-B @ 224x224 + """ + model_kwargs = dict( + embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs) + model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def nest_small(pretrained=False, **kwargs): + """ Nest-S @ 224x224 + """ + model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs) + model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def nest_tiny(pretrained=False, **kwargs): + """ Nest-T @ 224x224 + """ + model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs) + model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def jx_nest_base(pretrained=False, **kwargs): + """ Nest-B @ 224x224, Pretrained weights converted from official Jax impl. + """ + kwargs['pad_type'] = 'same' + model_kwargs = dict(embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs) + model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def jx_nest_small(pretrained=False, **kwargs): + """ Nest-S @ 224x224, Pretrained weights converted from official Jax impl. + """ + kwargs['pad_type'] = 'same' + model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs) + model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def jx_nest_tiny(pretrained=False, **kwargs): + """ Nest-T @ 224x224, Pretrained weights converted from official Jax impl. + """ + kwargs['pad_type'] = 'same' + model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs) + model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0f2b211155dc1e304cf076506929817c78d913 --- /dev/null +++ b/timm/models/nfnet.py @@ -0,0 +1,966 @@ +""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models + +Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + +Paper: `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + +Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets + +Status: +* These models are a work in progress, experiments ongoing. +* Pretrained weights for two models so far, more to come. +* Model details updated to closer match official JAX code now that it's released +* NF-ResNet, NF-RegNet-B, and NFNet-F models supported + +Hacked together by / copyright Ross Wightman, 2021. +""" +import math +from dataclasses import dataclass, field +from collections import OrderedDict +from typing import Tuple, Optional +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\ + get_act_layer, get_act_fn, get_attn, make_divisible + + +def _dcfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + dm_nfnet_f0=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth', + pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9), + dm_nfnet_f1=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91), + dm_nfnet_f2=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92), + dm_nfnet_f3=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94), + dm_nfnet_f4=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth', + pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951), + dm_nfnet_f5=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth', + pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954), + dm_nfnet_f6=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth', + pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956), + + nfnet_f0=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + nfnet_f1=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)), + nfnet_f2=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)), + nfnet_f3=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)), + nfnet_f4=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)), + nfnet_f5=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)), + nfnet_f6=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)), + nfnet_f7=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), + + nfnet_f0s=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)), + nfnet_f1s=_dcfg( + url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)), + nfnet_f2s=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)), + nfnet_f3s=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)), + nfnet_f4s=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)), + nfnet_f5s=_dcfg( + url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)), + nfnet_f6s=_dcfg( + url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)), + nfnet_f7s=_dcfg( + url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)), + + nfnet_l0=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), + eca_nfnet_l0=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth', + hf_hub='timm/eca_nfnet_l0', + pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0), + eca_nfnet_l1=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), + eca_nfnet_l2=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), + eca_nfnet_l3=_dcfg( + url='', + pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0), + + nf_regnet_b0=_dcfg( + url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), + nf_regnet_b1=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'), # NOT to paper spec + nf_regnet_b2=_dcfg( + url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'), + nf_regnet_b3=_dcfg( + url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'), + nf_regnet_b4=_dcfg( + url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'), + nf_regnet_b5=_dcfg( + url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'), + + nf_resnet26=_dcfg(url='', first_conv='stem.conv'), + nf_resnet50=_dcfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth', + pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'), + nf_resnet101=_dcfg(url='', first_conv='stem.conv'), + + nf_seresnet26=_dcfg(url='', first_conv='stem.conv'), + nf_seresnet50=_dcfg(url='', first_conv='stem.conv'), + nf_seresnet101=_dcfg(url='', first_conv='stem.conv'), + + nf_ecaresnet26=_dcfg(url='', first_conv='stem.conv'), + nf_ecaresnet50=_dcfg(url='', first_conv='stem.conv'), + nf_ecaresnet101=_dcfg(url='', first_conv='stem.conv'), +) + + +@dataclass +class NfCfg: + depths: Tuple[int, int, int, int] + channels: Tuple[int, int, int, int] + alpha: float = 0.2 + stem_type: str = '3x3' + stem_chs: Optional[int] = None + group_size: Optional[int] = None + attn_layer: Optional[str] = None + attn_kwargs: dict = None + attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used + width_factor: float = 1.0 + bottle_ratio: float = 0.5 + num_features: int = 0 # num out_channels for final conv, no final_conv if 0 + ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal + reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle + extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models + gamma_in_act: bool = False + same_padding: bool = False + std_conv_eps: float = 1e-5 + skipinit: bool = False # disabled by default, non-trivial performance impact + zero_init_fc: bool = False + act_layer: str = 'silu' + + +def _nfres_cfg( + depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None): + attn_kwargs = attn_kwargs or {} + cfg = NfCfg( + depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25, + group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + return cfg + + +def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): + num_features = 1280 * channels[-1] // 440 + attn_kwargs = dict(rd_ratio=0.5) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25, + num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs) + return cfg + + +def _nfnet_cfg( + depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2., + act_layer='gelu', attn_layer='se', attn_kwargs=None): + num_features = int(channels[-1] * feat_mult) + attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5) + cfg = NfCfg( + depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size, + bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer, + attn_layer=attn_layer, attn_kwargs=attn_kwargs) + return cfg + + +def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True): + cfg = NfCfg( + depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128, + bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit, + num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5)) + return cfg + + + +model_cfgs = dict( + # NFNet-F models w/ GELU compatible with DeepMind weights + dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)), + dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)), + dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)), + dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)), + dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)), + dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)), + dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)), + + # NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU) + nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)), + nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)), + nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)), + nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)), + nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)), + nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)), + nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)), + nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)), + + # NFNet-F models w/ SiLU (much faster in PyTorch) + nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'), + nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'), + nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'), + nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'), + nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'), + nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'), + nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'), + nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'), + + # Experimental 'light' versions of NFNet-F that are little leaner + nfnet_l0=_nfnet_cfg( + depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'), + eca_nfnet_l0=_nfnet_cfg( + depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l1=_nfnet_cfg( + depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l2=_nfnet_cfg( + depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l3=_nfnet_cfg( + depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + + # EffNet influenced RegNet defs. + # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. + nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)), + nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)), + nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)), + nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)), + nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)), + nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)), + # FIXME add B6-B8 + + # ResNet (preact, D style deep stem/avg down) defs + nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)), + nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), + nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), + + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + + nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), + nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()), + +) + + +class GammaAct(nn.Module): + def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False): + super().__init__() + self.act_fn = get_act_fn(act_type) + self.gamma = gamma + self.inplace = inplace + + def forward(self, x): + return self.act_fn(x, inplace=self.inplace).mul_(self.gamma) + + +def act_with_gamma(act_type, gamma: float = 1.): + def _create(inplace=False): + return GammaAct(act_type, gamma=gamma, inplace=inplace) + return _create + + +class DownsampleAvg(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d): + """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation.""" + super(DownsampleAvg, self).__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + self.conv = conv_layer(in_chs, out_chs, 1, stride=1) + + def forward(self, x): + return self.conv(self.pool(x)) + + +class NormFreeBlock(nn.Module): + """Normalization-Free pre-activation block. + """ + + def __init__( + self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None, + alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False, + skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.): + super().__init__() + first_dilation = first_dilation or dilation + out_chs = out_chs or in_chs + # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet + mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div) + groups = 1 if not group_size else mid_chs // group_size + if group_size and group_size % ch_div == 0: + mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error + self.alpha = alpha + self.beta = beta + self.attn_gain = attn_gain + + if in_chs != out_chs or stride != 1 or dilation != first_dilation: + self.downsample = DownsampleAvg( + in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer) + else: + self.downsample = None + + self.act1 = act_layer() + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.act2 = act_layer(inplace=True) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + if extra_conv: + self.act2b = act_layer(inplace=True) + self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups) + else: + self.act2b = None + self.conv2b = None + if reg and attn_layer is not None: + self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3 + else: + self.attn = None + self.act3 = act_layer() + self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.) + if not reg and attn_layer is not None: + self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3 + else: + self.attn_last = None + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None + + def forward(self, x): + out = self.act1(x) * self.beta + + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(out) + + # residual branch + out = self.conv1(out) + out = self.conv2(self.act2(out)) + if self.conv2b is not None: + out = self.conv2b(self.act2b(out)) + if self.attn is not None: + out = self.attn_gain * self.attn(out) + out = self.conv3(self.act3(out)) + if self.attn_last is not None: + out = self.attn_gain * self.attn_last(out) + out = self.drop_path(out) + + if self.skipinit_gain is not None: + out.mul_(self.skipinit_gain) # this slows things down more than expected, TBD + out = out * self.alpha + shortcut + return out + + +def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None, preact_feature=True): + stem_stride = 2 + stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv') + stem = OrderedDict() + assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool') + if 'deep' in stem_type: + if 'quad' in stem_type: + # 4 deep conv stack as in NFNet-F models + assert not 'pool' in stem_type + stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs) + strides = (2, 1, 1, 2) + stem_stride = 4 + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv3') + else: + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py + else: + stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets + strides = (2, 1, 1) + stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2') + last_idx = len(stem_chs) - 1 + for i, (c, s) in enumerate(zip(stem_chs, strides)): + stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s) + if i != last_idx: + stem[f'act{i + 2}'] = act_layer(inplace=True) + in_chs = c + elif '3x3' in stem_type: + # 3x3 stem conv as in RegNet + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2) + else: + # 7x7 stem conv as in ResNet + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + + if 'pool' in stem_type: + stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1) + stem_stride = 4 + + return nn.Sequential(stem), stem_stride, stem_feature + + +# from https://github.com/deepmind/deepmind-research/tree/master/nfnets +_nonlin_gamma = dict( + identity=1.0, + celu=1.270926833152771, + elu=1.2716004848480225, + gelu=1.7015043497085571, + leaky_relu=1.70590341091156, + log_sigmoid=1.9193484783172607, + log_softmax=1.0002083778381348, + relu=1.7139588594436646, + relu6=1.7131484746932983, + selu=1.0008515119552612, + sigmoid=4.803835391998291, + silu=1.7881293296813965, + softsign=2.338853120803833, + softplus=1.9203323125839233, + tanh=1.5939117670059204, +) + + +class NormFreeNet(nn.Module): + """ Normalization-Free Network + + As described in : + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + and + `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171 + + This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and + the (preact) ResNet models described earlier in the paper. + + There are a few differences: + * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy), + this changes channel dim and param counts slightly from the paper models + * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance + impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl. + * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but + apply it in each activation. This is slightly slower, numerically different, but matches official impl. + * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput + for what it is/does. Approx 8-10% throughput loss. + """ + def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + drop_rate=0., drop_path_rate=0.): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})." + conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d + if cfg.gamma_in_act: + act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer]) + conv_layer = partial(conv_layer, eps=cfg.std_conv_eps) + else: + act_layer = get_act_layer(cfg.act_layer) + conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps) + attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None + + stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) + self.stem, stem_stride, stem_feat = create_stem( + in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) + + self.feature_info = [stem_feat] + drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + prev_chs = stem_chs + net_stride = stem_stride + dilation = 1 + expected_var = 1.0 + stages = [] + for stage_idx, stage_depth in enumerate(cfg.depths): + stride = 1 if stage_idx == 0 and stem_stride > 2 else 2 + if net_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + net_stride *= stride + first_dilation = 1 if dilation in (1, 2) else 2 + + blocks = [] + for block_idx in range(cfg.depths[stage_idx]): + first_block = block_idx == 0 and stage_idx == 0 + out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div) + blocks += [NormFreeBlock( + in_chs=prev_chs, out_chs=out_chs, + alpha=cfg.alpha, + beta=1. / expected_var ** 0.5, + stride=stride if block_idx == 0 else 1, + dilation=dilation, + first_dilation=first_dilation, + group_size=cfg.group_size, + bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio, + ch_div=cfg.ch_div, + reg=cfg.reg, + extra_conv=cfg.extra_conv, + skipinit=cfg.skipinit, + attn_layer=attn_layer, + attn_gain=cfg.attn_gain, + act_layer=act_layer, + conv_layer=conv_layer, + drop_path_rate=drop_path_rates[stage_idx][block_idx], + )] + if block_idx == 0: + expected_var = 1. # expected var is reset after first block of each stage + expected_var += cfg.alpha ** 2 # Even if reset occurs, increment expected variance + first_dilation = dilation + prev_chs = out_chs + self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')] + stages += [nn.Sequential(*blocks)] + self.stages = nn.Sequential(*stages) + + if cfg.num_features: + # The paper NFRegNet models have an EfficientNet-like final head convolution. + self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div) + self.final_conv = conv_layer(prev_chs, self.num_features, 1) + self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module=f'final_conv') + else: + self.num_features = prev_chs + self.final_conv = nn.Identity() + self.final_act = act_layer(inplace=cfg.num_features > 0) + + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + for n, m in self.named_modules(): + if 'fc' in n and isinstance(m, nn.Linear): + if cfg.zero_init_fc: + nn.init.zeros_(m.weight) + else: + nn.init.normal_(m.weight, 0., .01) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear') + if m.bias is not None: + nn.init.zeros_(m.bias) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.final_conv(x) + x = self.final_act(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_normfreenet(variant, pretrained=False, **kwargs): + model_cfg = model_cfgs[variant] + feature_cfg = dict(flatten_sequential=True) + return build_model_with_cfg( + NormFreeNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfg, + feature_cfg=feature_cfg, + **kwargs) + + +@register_model +def dm_nfnet_f0(pretrained=False, **kwargs): + """ NFNet-F0 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f1(pretrained=False, **kwargs): + """ NFNet-F1 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f2(pretrained=False, **kwargs): + """ NFNet-F2 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f3(pretrained=False, **kwargs): + """ NFNet-F3 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f4(pretrained=False, **kwargs): + """ NFNet-F4 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f5(pretrained=False, **kwargs): + """ NFNet-F5 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs) + + +@register_model +def dm_nfnet_f6(pretrained=False, **kwargs): + """ NFNet-F6 (DeepMind weight compatible) + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f0(pretrained=False, **kwargs): + """ NFNet-F0 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1(pretrained=False, **kwargs): + """ NFNet-F1 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2(pretrained=False, **kwargs): + """ NFNet-F2 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3(pretrained=False, **kwargs): + """ NFNet-F3 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f4(pretrained=False, **kwargs): + """ NFNet-F4 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f5(pretrained=False, **kwargs): + """ NFNet-F5 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6(pretrained=False, **kwargs): + """ NFNet-F6 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f7(pretrained=False, **kwargs): + """ NFNet-F7 + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f0s(pretrained=False, **kwargs): + """ NFNet-F0 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f1s(pretrained=False, **kwargs): + """ NFNet-F1 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f2s(pretrained=False, **kwargs): + """ NFNet-F2 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f3s(pretrained=False, **kwargs): + """ NFNet-F3 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f4s(pretrained=False, **kwargs): + """ NFNet-F4 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f5s(pretrained=False, **kwargs): + """ NFNet-F5 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f6s(pretrained=False, **kwargs): + """ NFNet-F6 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_f7s(pretrained=False, **kwargs): + """ NFNet-F7 w/ SiLU + `High-Performance Large-Scale Image Recognition Without Normalization` + - https://arxiv.org/abs/2102.06171 + """ + return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs) + + +@register_model +def nfnet_l0(pretrained=False, **kwargs): + """ NFNet-L0b w/ SiLU + My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio + """ + return _create_normfreenet('nfnet_l0', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l0(pretrained=False, **kwargs): + """ ECA-NFNet-L0 w/ SiLU + My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l0', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l1(pretrained=False, **kwargs): + """ ECA-NFNet-L1 w/ SiLU + My experimental 'light' model w/ F1 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l2(pretrained=False, **kwargs): + """ ECA-NFNet-L2 w/ SiLU + My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l3(pretrained=False, **kwargs): + """ ECA-NFNet-L3 w/ SiLU + My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b0(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B0 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b1(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B1 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b2(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B2 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b3(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B3 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b4(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B4 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs) + + +@register_model +def nf_regnet_b5(pretrained=False, **kwargs): + """ Normalization-Free RegNet-B5 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs) + + +@register_model +def nf_resnet26(pretrained=False, **kwargs): + """ Normalization-Free ResNet-26 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs) + + +@register_model +def nf_resnet50(pretrained=False, **kwargs): + """ Normalization-Free ResNet-50 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs) + + +@register_model +def nf_resnet101(pretrained=False, **kwargs): + """ Normalization-Free ResNet-101 + `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + """ + return _create_normfreenet('nf_resnet101', pretrained=pretrained, **kwargs) + + +@register_model +def nf_seresnet26(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet26 + """ + return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs) + + +@register_model +def nf_seresnet50(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet50 + """ + return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs) + + +@register_model +def nf_seresnet101(pretrained=False, **kwargs): + """ Normalization-Free SE-ResNet101 + """ + return _create_normfreenet('nf_seresnet101', pretrained=pretrained, **kwargs) + + +@register_model +def nf_ecaresnet26(pretrained=False, **kwargs): + """ Normalization-Free ECA-ResNet26 + """ + return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs) + + +@register_model +def nf_ecaresnet50(pretrained=False, **kwargs): + """ Normalization-Free ECA-ResNet50 + """ + return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs) + + +@register_model +def nf_ecaresnet101(pretrained=False, **kwargs): + """ Normalization-Free ECA-ResNet101 + """ + return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs) diff --git a/timm/models/pit.py b/timm/models/pit.py new file mode 100644 index 0000000000000000000000000000000000000000..460824e2d65ae403caea62c4fe8ac48a2a0f78e9 --- /dev/null +++ b/timm/models/pit.py @@ -0,0 +1,384 @@ +""" Pooling-based Vision Transformer (PiT) in PyTorch + +A PyTorch implement of Pooling-based Vision Transformers as described in +'Rethinking Spatial Dimensions of Vision Transformers' - https://arxiv.org/abs/2103.16302 + +This code was adapted from the original version at https://github.com/naver-ai/pit, original copyright below. + +Modifications for timm by / Copyright 2020 Ross Wightman +""" +# PiT +# Copyright 2021-present NAVER Corp. +# Apache License v2.0 + +import math +import re +from copy import deepcopy +from functools import partial +from typing import Tuple + +import torch +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import trunc_normal_, to_2tuple +from .registry import register_model +from .vision_transformer import Block + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.conv', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # deit models (FB weights) + 'pit_ti_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_730.pth'), + 'pit_xs_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_781.pth'), + 'pit_s_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_809.pth'), + 'pit_b_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'), + 'pit_ti_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth', + classifier=('head', 'head_dist')), + 'pit_xs_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth', + classifier=('head', 'head_dist')), + 'pit_s_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth', + classifier=('head', 'head_dist')), + 'pit_b_distilled_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth', + classifier=('head', 'head_dist')), +} + + +class SequentialTuple(nn.Sequential): + """ This module exists to work around torchscript typing issues list -> list""" + def __init__(self, *args): + super(SequentialTuple, self).__init__(*args) + + def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + for module in self: + x = module(x) + return x + + +class Transformer(nn.Module): + def __init__( + self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None): + super(Transformer, self).__init__() + self.layers = nn.ModuleList([]) + embed_dim = base_dim * heads + + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, + num_heads=heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_prob[i], + norm_layer=partial(nn.LayerNorm, eps=1e-6) + ) + for i in range(depth)]) + + self.pool = pool + + def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + x, cls_tokens = x + B, C, H, W = x.shape + token_length = cls_tokens.shape[1] + + x = x.flatten(2).transpose(1, 2) + x = torch.cat((cls_tokens, x), dim=1) + + x = self.blocks(x) + + cls_tokens = x[:, :token_length] + x = x[:, token_length:] + x = x.transpose(1, 2).reshape(B, C, H, W) + + if self.pool is not None: + x, cls_tokens = self.pool(x, cls_tokens) + return x, cls_tokens + + +class ConvHeadPooling(nn.Module): + def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'): + super(ConvHeadPooling, self).__init__() + + self.conv = nn.Conv2d( + in_feature, out_feature, kernel_size=stride + 1, padding=stride // 2, stride=stride, + padding_mode=padding_mode, groups=in_feature) + self.fc = nn.Linear(in_feature, out_feature) + + def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]: + + x = self.conv(x) + cls_token = self.fc(cls_token) + + return x, cls_token + + +class ConvEmbedding(nn.Module): + def __init__(self, in_channels, out_channels, patch_size, stride, padding): + super(ConvEmbedding, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True) + + def forward(self, x): + x = self.conv(x) + return x + + +class PoolingVisionTransformer(nn.Module): + """ Pooling-based Vision Transformer + + A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers' + - https://arxiv.org/abs/2103.16302 + """ + def __init__(self, img_size, patch_size, stride, base_dims, depth, heads, + mlp_ratio, num_classes=1000, in_chans=3, distilled=False, + attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0): + super(PoolingVisionTransformer, self).__init__() + + padding = 0 + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + height = math.floor((img_size[0] + 2 * padding - patch_size[0]) / stride + 1) + width = math.floor((img_size[1] + 2 * padding - patch_size[1]) / stride + 1) + + self.base_dims = base_dims + self.heads = heads + self.num_classes = num_classes + self.num_tokens = 2 if distilled else 1 + + self.patch_size = patch_size + self.pos_embed = nn.Parameter(torch.randn(1, base_dims[0] * heads[0], height, width)) + self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding) + + self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0])) + self.pos_drop = nn.Dropout(p=drop_rate) + + transformers = [] + # stochastic depth decay rule + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)] + for stage in range(len(depth)): + pool = None + if stage < len(heads) - 1: + pool = ConvHeadPooling( + base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2) + transformers += [Transformer( + base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool, + drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage]) + ] + self.transformers = SequentialTuple(*transformers) + self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) + self.num_features = self.embed_dim = base_dims[-1] * heads[-1] + + # Classifier head + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + if self.head_dist is not None: + return self.head, self.head_dist + else: + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.head_dist is not None: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.pos_drop(x + self.pos_embed) + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + x, cls_tokens = self.transformers((x, cls_tokens)) + cls_tokens = self.norm(cls_tokens) + if self.head_dist is not None: + return cls_tokens[:, 0], cls_tokens[:, 1] + else: + return cls_tokens[:, 0] + + def forward(self, x): + x = self.forward_features(x) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + return x, x_dist + else: + return (x + x_dist) / 2 + else: + return self.head(x) + + +def checkpoint_filter_fn(state_dict, model): + """ preprocess checkpoints """ + out_dict = {} + p_blocks = re.compile(r'pools\.(\d)\.') + for k, v in state_dict.items(): + # FIXME need to update resize for PiT impl + # if k == 'pos_embed' and v.shape != model.pos_embed.shape: + # # To resize pos embedding when using model at different size from pretrained weights + # v = resize_pos_embed(v, model.pos_embed) + k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', k) + out_dict[k] = v + return out_dict + + +def _create_pit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + PoolingVisionTransformer, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def pit_b_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=14, + stride=7, + base_dims=[64, 64, 64], + depth=[3, 6, 4], + heads=[4, 8, 16], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_b_224', pretrained, **model_kwargs) + + +@register_model +def pit_s_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[3, 6, 12], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_s_224', pretrained, **model_kwargs) + + +@register_model +def pit_xs_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_xs_224', pretrained, **model_kwargs) + + +@register_model +def pit_ti_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[32, 32, 32], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + **kwargs + ) + return _create_pit('pit_ti_224', pretrained, **model_kwargs) + + +@register_model +def pit_b_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=14, + stride=7, + base_dims=[64, 64, 64], + depth=[3, 6, 4], + heads=[4, 8, 16], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs) + + +@register_model +def pit_s_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[3, 6, 12], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs) + + +@register_model +def pit_xs_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[48, 48, 48], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs) + + +@register_model +def pit_ti_distilled_224(pretrained, **kwargs): + model_kwargs = dict( + patch_size=16, + stride=8, + base_dims=[32, 32, 32], + depth=[2, 6, 4], + heads=[2, 4, 8], + mlp_ratio=4, + distilled=True, + **kwargs + ) + return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs) \ No newline at end of file diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py new file mode 100644 index 0000000000000000000000000000000000000000..999181563a40b58c751b2ff56a631ae7508047e9 --- /dev/null +++ b/timm/models/pnasnet.py @@ -0,0 +1,350 @@ +""" + pnasnet5large implementation grabbed from Cadene's pretrained models + Additional credit to https://github.com/creafz + + https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py + +""" +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier +from .registry import register_model + +__all__ = ['PNASNet5Large'] + +default_cfgs = { + 'pnasnet5large': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'crop_pct': 0.911, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv_0.conv', + 'classifier': 'last_linear', + 'label_offset': 1, # 1001 classes in pretrained weights + }, +} + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = create_conv2d( + in_channels, in_channels, kernel_size=kernel_size, + stride=stride, padding=padding, groups=in_channels) + self.pointwise_conv2d = create_conv2d( + in_channels, out_channels, kernel_size=1, padding=padding) + + def forward(self, x): + x = self.depthwise_conv2d(x) + x = self.pointwise_conv2d(x) + return x + + +class BranchSeparables(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''): + super(BranchSeparables, self).__init__() + middle_channels = out_channels if stem_cell else in_channels + self.act_1 = nn.ReLU() + self.separable_1 = SeparableConv2d( + in_channels, middle_channels, kernel_size, stride=stride, padding=padding) + self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) + self.act_2 = nn.ReLU() + self.separable_2 = SeparableConv2d( + middle_channels, out_channels, kernel_size, stride=1, padding=padding) + self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act_1(x) + x = self.separable_1(x) + x = self.bn_sep_1(x) + x = self.act_2(x) + x = self.separable_2(x) + x = self.bn_sep_2(x) + return x + + +class ActConvBn(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''): + super(ActConvBn, self).__init__() + self.act = nn.ReLU() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act(x) + x = self.conv(x) + x = self.bn(x) + return x + + +class FactorizedReduction(nn.Module): + + def __init__(self, in_channels, out_channels, padding=''): + super(FactorizedReduction, self).__init__() + self.act = nn.ReLU() + self.path_1 = nn.Sequential(OrderedDict([ + ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), + ])) + self.path_2 = nn.Sequential(OrderedDict([ + ('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift + ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), + ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)), + ])) + self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x): + x = self.act(x) + x_path1 = self.path_1(x) + x_path2 = self.path_2(x) + out = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) + return out + + +class CellBase(nn.Module): + + def cell_forward(self, x_left, x_right): + x_comb_iter_0_left = self.comb_iter_0_left(x_left) + x_comb_iter_0_right = self.comb_iter_0_right(x_left) + x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right + + x_comb_iter_1_left = self.comb_iter_1_left(x_right) + x_comb_iter_1_right = self.comb_iter_1_right(x_right) + x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right + + x_comb_iter_2_left = self.comb_iter_2_left(x_right) + x_comb_iter_2_right = self.comb_iter_2_right(x_right) + x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right + + x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2) + x_comb_iter_3_right = self.comb_iter_3_right(x_right) + x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right + + x_comb_iter_4_left = self.comb_iter_4_left(x_left) + if self.comb_iter_4_right is not None: + x_comb_iter_4_right = self.comb_iter_4_right(x_right) + else: + x_comb_iter_4_right = x_right + x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right + + x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1) + return x_out + + +class CellStem0(CellBase): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''): + super(CellStem0, self).__init__() + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables( + in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type) + self.comb_iter_0_right = nn.Sequential(OrderedDict([ + ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)), + ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)), + ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)), + ])) + + self.comb_iter_1_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type) + self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type) + + self.comb_iter_2_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type) + self.comb_iter_2_right = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type) + + self.comb_iter_3_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, padding=pad_type) + self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables( + in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type) + self.comb_iter_4_right = ActConvBn( + out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type) + + def forward(self, x_left): + x_right = self.conv_1x1(x_left) + x_out = self.cell_forward(x_left, x_right) + return x_out + + +class Cell(CellBase): + + def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='', + is_reduction=False, match_prev_layer_dims=False): + super(Cell, self).__init__() + + # If `is_reduction` is set to `True` stride 2 is used for + # convolution and pooling layers to reduce the spatial size of + # the output of a cell approximately by a factor of 2. + stride = 2 if is_reduction else 1 + + # If `match_prev_layer_dimensions` is set to `True` + # `FactorizedReduction` is used to reduce the spatial size + # of the left input of a cell approximately by a factor of 2. + self.match_prev_layer_dimensions = match_prev_layer_dims + if match_prev_layer_dims: + self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type) + else: + self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type) + self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type) + + self.comb_iter_0_left = BranchSeparables( + out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type) + self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type) + + self.comb_iter_1_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type) + self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type) + + self.comb_iter_2_left = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type) + self.comb_iter_2_right = BranchSeparables( + out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type) + + self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3) + self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type) + + self.comb_iter_4_left = BranchSeparables( + out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type) + if is_reduction: + self.comb_iter_4_right = ActConvBn( + out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type) + else: + self.comb_iter_4_right = None + + def forward(self, x_left, x_right): + x_left = self.conv_prev_1x1(x_left) + x_right = self.conv_1x1(x_right) + x_out = self.cell_forward(x_left, x_right) + return x_out + + +class PNASNet5Large(nn.Module): + def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''): + super(PNASNet5Large, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + self.num_features = 4320 + assert output_stride == 32 + + self.conv_0 = ConvBnAct( + in_chans, 96, kernel_size=3, stride=2, padding=0, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) + + self.cell_stem_0 = CellStem0( + in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type) + + self.cell_stem_1 = Cell( + in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type, + match_prev_layer_dims=True, is_reduction=True) + self.cell_0 = Cell( + in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type, + match_prev_layer_dims=True) + self.cell_1 = Cell( + in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + self.cell_2 = Cell( + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + self.cell_3 = Cell( + in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type) + + self.cell_4 = Cell( + in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type, + is_reduction=True) + self.cell_5 = Cell( + in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type, + match_prev_layer_dims=True) + self.cell_6 = Cell( + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type) + self.cell_7 = Cell( + in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type) + + self.cell_8 = Cell( + in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type, + is_reduction=True) + self.cell_9 = Cell( + in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type, + match_prev_layer_dims=True) + self.cell_10 = Cell( + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type) + self.cell_11 = Cell( + in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type) + self.act = nn.ReLU() + self.feature_info = [ + dict(num_chs=96, reduction=2, module='conv_0'), + dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'), + dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'), + dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'), + dict(num_chs=4320, reduction=32, module='act'), + ] + + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x_conv_0 = self.conv_0(x) + x_stem_0 = self.cell_stem_0(x_conv_0) + x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0) + x_cell_0 = self.cell_0(x_stem_0, x_stem_1) + x_cell_1 = self.cell_1(x_stem_1, x_cell_0) + x_cell_2 = self.cell_2(x_cell_0, x_cell_1) + x_cell_3 = self.cell_3(x_cell_1, x_cell_2) + x_cell_4 = self.cell_4(x_cell_2, x_cell_3) + x_cell_5 = self.cell_5(x_cell_3, x_cell_4) + x_cell_6 = self.cell_6(x_cell_4, x_cell_5) + x_cell_7 = self.cell_7(x_cell_5, x_cell_6) + x_cell_8 = self.cell_8(x_cell_6, x_cell_7) + x_cell_9 = self.cell_9(x_cell_7, x_cell_8) + x_cell_10 = self.cell_10(x_cell_8, x_cell_9) + x_cell_11 = self.cell_11(x_cell_9, x_cell_10) + x = self.act(x_cell_11) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0: + x = F.dropout(x, self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + +def _create_pnasnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + PNASNet5Large, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model + **kwargs) + + +@register_model +def pnasnet5large(pretrained=False, **kwargs): + r"""PNASNet-5 model architecture from the + `"Progressive Neural Architecture Search" + `_ paper. + """ + model_kwargs = dict(pad_type='same', **kwargs) + return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs) diff --git a/timm/models/pruned/ecaresnet101d_pruned.txt b/timm/models/pruned/ecaresnet101d_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..2589b2f9dd3f0d1e02e1d5ddc1fbcd5c143e02c6 --- /dev/null +++ b/timm/models/pruned/ecaresnet101d_pruned.txt @@ -0,0 +1 @@ +conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[45, 64, 1, 1]***layer1.0.bn1.weight:[45]***layer1.0.conv2.weight:[25, 45, 3, 3]***layer1.0.bn2.weight:[25]***layer1.0.conv3.weight:[26, 25, 1, 1]***layer1.0.bn3.weight:[26]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[26, 64, 1, 1]***layer1.0.downsample.2.weight:[26]***layer1.1.conv1.weight:[53, 26, 1, 1]***layer1.1.bn1.weight:[53]***layer1.1.conv2.weight:[20, 53, 3, 3]***layer1.1.bn2.weight:[20]***layer1.1.conv3.weight:[26, 20, 1, 1]***layer1.1.bn3.weight:[26]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[60, 26, 1, 1]***layer1.2.bn1.weight:[60]***layer1.2.conv2.weight:[27, 60, 3, 3]***layer1.2.bn2.weight:[27]***layer1.2.conv3.weight:[26, 27, 1, 1]***layer1.2.bn3.weight:[26]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[81, 26, 1, 1]***layer2.0.bn1.weight:[81]***layer2.0.conv2.weight:[24, 81, 3, 3]***layer2.0.bn2.weight:[24]***layer2.0.conv3.weight:[142, 24, 1, 1]***layer2.0.bn3.weight:[142]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[142, 26, 1, 1]***layer2.0.downsample.2.weight:[142]***layer2.1.conv1.weight:[93, 142, 1, 1]***layer2.1.bn1.weight:[93]***layer2.1.conv2.weight:[49, 93, 3, 3]***layer2.1.bn2.weight:[49]***layer2.1.conv3.weight:[142, 49, 1, 1]***layer2.1.bn3.weight:[142]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[102, 142, 1, 1]***layer2.2.bn1.weight:[102]***layer2.2.conv2.weight:[54, 102, 3, 3]***layer2.2.bn2.weight:[54]***layer2.2.conv3.weight:[142, 54, 1, 1]***layer2.2.bn3.weight:[142]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[122, 142, 1, 1]***layer2.3.bn1.weight:[122]***layer2.3.conv2.weight:[78, 122, 3, 3]***layer2.3.bn2.weight:[78]***layer2.3.conv3.weight:[142, 78, 1, 1]***layer2.3.bn3.weight:[142]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[101, 142, 1, 1]***layer3.0.bn1.weight:[101]***layer3.0.conv2.weight:[25, 101, 3, 3]***layer3.0.bn2.weight:[25]***layer3.0.conv3.weight:[278, 25, 1, 1]***layer3.0.bn3.weight:[278]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[278, 142, 1, 1]***layer3.0.downsample.2.weight:[278]***layer3.1.conv1.weight:[239, 278, 1, 1]***layer3.1.bn1.weight:[239]***layer3.1.conv2.weight:[160, 239, 3, 3]***layer3.1.bn2.weight:[160]***layer3.1.conv3.weight:[278, 160, 1, 1]***layer3.1.bn3.weight:[278]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[234, 278, 1, 1]***layer3.2.bn1.weight:[234]***layer3.2.conv2.weight:[156, 234, 3, 3]***layer3.2.bn2.weight:[156]***layer3.2.conv3.weight:[278, 156, 1, 1]***layer3.2.bn3.weight:[278]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[250, 278, 1, 1]***layer3.3.bn1.weight:[250]***layer3.3.conv2.weight:[176, 250, 3, 3]***layer3.3.bn2.weight:[176]***layer3.3.conv3.weight:[278, 176, 1, 1]***layer3.3.bn3.weight:[278]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[253, 278, 1, 1]***layer3.4.bn1.weight:[253]***layer3.4.conv2.weight:[191, 253, 3, 3]***layer3.4.bn2.weight:[191]***layer3.4.conv3.weight:[278, 191, 1, 1]***layer3.4.bn3.weight:[278]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[251, 278, 1, 1]***layer3.5.bn1.weight:[251]***layer3.5.conv2.weight:[175, 251, 3, 3]***layer3.5.bn2.weight:[175]***layer3.5.conv3.weight:[278, 175, 1, 1]***layer3.5.bn3.weight:[278]***layer3.5.se.conv.weight:[1, 1, 5]***layer3.6.conv1.weight:[230, 278, 1, 1]***layer3.6.bn1.weight:[230]***layer3.6.conv2.weight:[128, 230, 3, 3]***layer3.6.bn2.weight:[128]***layer3.6.conv3.weight:[278, 128, 1, 1]***layer3.6.bn3.weight:[278]***layer3.6.se.conv.weight:[1, 1, 5]***layer3.7.conv1.weight:[244, 278, 1, 1]***layer3.7.bn1.weight:[244]***layer3.7.conv2.weight:[154, 244, 3, 3]***layer3.7.bn2.weight:[154]***layer3.7.conv3.weight:[278, 154, 1, 1]***layer3.7.bn3.weight:[278]***layer3.7.se.conv.weight:[1, 1, 5]***layer3.8.conv1.weight:[244, 278, 1, 1]***layer3.8.bn1.weight:[244]***layer3.8.conv2.weight:[159, 244, 3, 3]***layer3.8.bn2.weight:[159]***layer3.8.conv3.weight:[278, 159, 1, 1]***layer3.8.bn3.weight:[278]***layer3.8.se.conv.weight:[1, 1, 5]***layer3.9.conv1.weight:[238, 278, 1, 1]***layer3.9.bn1.weight:[238]***layer3.9.conv2.weight:[97, 238, 3, 3]***layer3.9.bn2.weight:[97]***layer3.9.conv3.weight:[278, 97, 1, 1]***layer3.9.bn3.weight:[278]***layer3.9.se.conv.weight:[1, 1, 5]***layer3.10.conv1.weight:[244, 278, 1, 1]***layer3.10.bn1.weight:[244]***layer3.10.conv2.weight:[149, 244, 3, 3]***layer3.10.bn2.weight:[149]***layer3.10.conv3.weight:[278, 149, 1, 1]***layer3.10.bn3.weight:[278]***layer3.10.se.conv.weight:[1, 1, 5]***layer3.11.conv1.weight:[253, 278, 1, 1]***layer3.11.bn1.weight:[253]***layer3.11.conv2.weight:[181, 253, 3, 3]***layer3.11.bn2.weight:[181]***layer3.11.conv3.weight:[278, 181, 1, 1]***layer3.11.bn3.weight:[278]***layer3.11.se.conv.weight:[1, 1, 5]***layer3.12.conv1.weight:[245, 278, 1, 1]***layer3.12.bn1.weight:[245]***layer3.12.conv2.weight:[119, 245, 3, 3]***layer3.12.bn2.weight:[119]***layer3.12.conv3.weight:[278, 119, 1, 1]***layer3.12.bn3.weight:[278]***layer3.12.se.conv.weight:[1, 1, 5]***layer3.13.conv1.weight:[255, 278, 1, 1]***layer3.13.bn1.weight:[255]***layer3.13.conv2.weight:[216, 255, 3, 3]***layer3.13.bn2.weight:[216]***layer3.13.conv3.weight:[278, 216, 1, 1]***layer3.13.bn3.weight:[278]***layer3.13.se.conv.weight:[1, 1, 5]***layer3.14.conv1.weight:[256, 278, 1, 1]***layer3.14.bn1.weight:[256]***layer3.14.conv2.weight:[201, 256, 3, 3]***layer3.14.bn2.weight:[201]***layer3.14.conv3.weight:[278, 201, 1, 1]***layer3.14.bn3.weight:[278]***layer3.14.se.conv.weight:[1, 1, 5]***layer3.15.conv1.weight:[253, 278, 1, 1]***layer3.15.bn1.weight:[253]***layer3.15.conv2.weight:[149, 253, 3, 3]***layer3.15.bn2.weight:[149]***layer3.15.conv3.weight:[278, 149, 1, 1]***layer3.15.bn3.weight:[278]***layer3.15.se.conv.weight:[1, 1, 5]***layer3.16.conv1.weight:[254, 278, 1, 1]***layer3.16.bn1.weight:[254]***layer3.16.conv2.weight:[141, 254, 3, 3]***layer3.16.bn2.weight:[141]***layer3.16.conv3.weight:[278, 141, 1, 1]***layer3.16.bn3.weight:[278]***layer3.16.se.conv.weight:[1, 1, 5]***layer3.17.conv1.weight:[256, 278, 1, 1]***layer3.17.bn1.weight:[256]***layer3.17.conv2.weight:[190, 256, 3, 3]***layer3.17.bn2.weight:[190]***layer3.17.conv3.weight:[278, 190, 1, 1]***layer3.17.bn3.weight:[278]***layer3.17.se.conv.weight:[1, 1, 5]***layer3.18.conv1.weight:[256, 278, 1, 1]***layer3.18.bn1.weight:[256]***layer3.18.conv2.weight:[217, 256, 3, 3]***layer3.18.bn2.weight:[217]***layer3.18.conv3.weight:[278, 217, 1, 1]***layer3.18.bn3.weight:[278]***layer3.18.se.conv.weight:[1, 1, 5]***layer3.19.conv1.weight:[255, 278, 1, 1]***layer3.19.bn1.weight:[255]***layer3.19.conv2.weight:[156, 255, 3, 3]***layer3.19.bn2.weight:[156]***layer3.19.conv3.weight:[278, 156, 1, 1]***layer3.19.bn3.weight:[278]***layer3.19.se.conv.weight:[1, 1, 5]***layer3.20.conv1.weight:[256, 278, 1, 1]***layer3.20.bn1.weight:[256]***layer3.20.conv2.weight:[155, 256, 3, 3]***layer3.20.bn2.weight:[155]***layer3.20.conv3.weight:[278, 155, 1, 1]***layer3.20.bn3.weight:[278]***layer3.20.se.conv.weight:[1, 1, 5]***layer3.21.conv1.weight:[256, 278, 1, 1]***layer3.21.bn1.weight:[256]***layer3.21.conv2.weight:[232, 256, 3, 3]***layer3.21.bn2.weight:[232]***layer3.21.conv3.weight:[278, 232, 1, 1]***layer3.21.bn3.weight:[278]***layer3.21.se.conv.weight:[1, 1, 5]***layer3.22.conv1.weight:[256, 278, 1, 1]***layer3.22.bn1.weight:[256]***layer3.22.conv2.weight:[214, 256, 3, 3]***layer3.22.bn2.weight:[214]***layer3.22.conv3.weight:[278, 214, 1, 1]***layer3.22.bn3.weight:[278]***layer3.22.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[499, 278, 1, 1]***layer4.0.bn1.weight:[499]***layer4.0.conv2.weight:[289, 499, 3, 3]***layer4.0.bn2.weight:[289]***layer4.0.conv3.weight:[2042, 289, 1, 1]***layer4.0.bn3.weight:[2042]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2042, 278, 1, 1]***layer4.0.downsample.2.weight:[2042]***layer4.1.conv1.weight:[512, 2042, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[512, 512, 3, 3]***layer4.1.bn2.weight:[512]***layer4.1.conv3.weight:[2042, 512, 1, 1]***layer4.1.bn3.weight:[2042]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2042, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[502, 512, 3, 3]***layer4.2.bn2.weight:[502]***layer4.2.conv3.weight:[2042, 502, 1, 1]***layer4.2.bn3.weight:[2042]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2042]***layer1_2_conv3_M.weight:[256, 26]***layer2_3_conv3_M.weight:[512, 142]***layer3_22_conv3_M.weight:[1024, 278]***layer4_2_conv3_M.weight:[2048, 2042] \ No newline at end of file diff --git a/timm/models/pruned/ecaresnet50d_pruned.txt b/timm/models/pruned/ecaresnet50d_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a8b2bf50e0631dce74d66a1a98e26cae10572a7 --- /dev/null +++ b/timm/models/pruned/ecaresnet50d_pruned.txt @@ -0,0 +1 @@ +conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022] \ No newline at end of file diff --git a/timm/models/pruned/efficientnet_b1_pruned.txt b/timm/models/pruned/efficientnet_b1_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..0972b527612b283fd242cc5eaeb6e767ea106c66 --- /dev/null +++ b/timm/models/pruned/efficientnet_b1_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]***bn1.bias:[32]***bn1.running_mean:[32]***bn1.running_var:[32]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[32, 1, 3, 3]***blocks.0.0.bn1.weight:[32]***blocks.0.0.bn1.bias:[32]***blocks.0.0.bn1.running_mean:[32]***blocks.0.0.bn1.running_var:[32]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[8, 32, 1, 1]***blocks.0.0.se.conv_reduce.bias:[8]***blocks.0.0.se.conv_expand.weight:[32, 8, 1, 1]***blocks.0.0.se.conv_expand.bias:[32]***blocks.0.0.conv_pw.weight:[16, 32, 1, 1]***blocks.0.0.bn2.weight:[16]***blocks.0.0.bn2.bias:[16]***blocks.0.0.bn2.running_mean:[16]***blocks.0.0.bn2.running_var:[16]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[16, 1, 3, 3]***blocks.0.1.bn1.weight:[16]***blocks.0.1.bn1.bias:[16]***blocks.0.1.bn1.running_mean:[16]***blocks.0.1.bn1.running_var:[16]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[4, 16, 1, 1]***blocks.0.1.se.conv_reduce.bias:[4]***blocks.0.1.se.conv_expand.weight:[16, 4, 1, 1]***blocks.0.1.se.conv_expand.bias:[16]***blocks.0.1.conv_pw.weight:[16, 16, 1, 1]***blocks.0.1.bn2.weight:[16]***blocks.0.1.bn2.bias:[16]***blocks.0.1.bn2.running_mean:[16]***blocks.0.1.bn2.running_var:[16]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[48, 16, 1, 1]***blocks.1.0.bn1.weight:[48]***blocks.1.0.bn1.bias:[48]***blocks.1.0.bn1.running_mean:[48]***blocks.1.0.bn1.running_var:[48]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[48, 1, 3, 3]***blocks.1.0.bn2.weight:[48]***blocks.1.0.bn2.bias:[48]***blocks.1.0.bn2.running_mean:[48]***blocks.1.0.bn2.running_var:[48]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[4, 48, 1, 1]***blocks.1.0.se.conv_reduce.bias:[4]***blocks.1.0.se.conv_expand.weight:[48, 4, 1, 1]***blocks.1.0.se.conv_expand.bias:[48]***blocks.1.0.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.0.bn3.weight:[12]***blocks.1.0.bn3.bias:[12]***blocks.1.0.bn3.running_mean:[12]***blocks.1.0.bn3.running_var:[12]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[62, 12, 1, 1]***blocks.1.1.bn1.weight:[62]***blocks.1.1.bn1.bias:[62]***blocks.1.1.bn1.running_mean:[62]***blocks.1.1.bn1.running_var:[62]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[62, 1, 3, 3]***blocks.1.1.bn2.weight:[62]***blocks.1.1.bn2.bias:[62]***blocks.1.1.bn2.running_mean:[62]***blocks.1.1.bn2.running_var:[62]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[6, 62, 1, 1]***blocks.1.1.se.conv_reduce.bias:[6]***blocks.1.1.se.conv_expand.weight:[62, 6, 1, 1]***blocks.1.1.se.conv_expand.bias:[62]***blocks.1.1.conv_pwl.weight:[12, 62, 1, 1]***blocks.1.1.bn3.weight:[12]***blocks.1.1.bn3.bias:[12]***blocks.1.1.bn3.running_mean:[12]***blocks.1.1.bn3.running_var:[12]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[48, 12, 1, 1]***blocks.1.2.bn1.weight:[48]***blocks.1.2.bn1.bias:[48]***blocks.1.2.bn1.running_mean:[48]***blocks.1.2.bn1.running_var:[48]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[48, 1, 3, 3]***blocks.1.2.bn2.weight:[48]***blocks.1.2.bn2.bias:[48]***blocks.1.2.bn2.running_mean:[48]***blocks.1.2.bn2.running_var:[48]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[6, 48, 1, 1]***blocks.1.2.se.conv_reduce.bias:[6]***blocks.1.2.se.conv_expand.weight:[48, 6, 1, 1]***blocks.1.2.se.conv_expand.bias:[48]***blocks.1.2.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.2.bn3.weight:[12]***blocks.1.2.bn3.bias:[12]***blocks.1.2.bn3.running_mean:[12]***blocks.1.2.bn3.running_var:[12]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[70, 12, 1, 1]***blocks.2.0.bn1.weight:[70]***blocks.2.0.bn1.bias:[70]***blocks.2.0.bn1.running_mean:[70]***blocks.2.0.bn1.running_var:[70]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[70, 1, 5, 5]***blocks.2.0.bn2.weight:[70]***blocks.2.0.bn2.bias:[70]***blocks.2.0.bn2.running_mean:[70]***blocks.2.0.bn2.running_var:[70]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[6, 70, 1, 1]***blocks.2.0.se.conv_reduce.bias:[6]***blocks.2.0.se.conv_expand.weight:[70, 6, 1, 1]***blocks.2.0.se.conv_expand.bias:[70]***blocks.2.0.conv_pwl.weight:[35, 70, 1, 1]***blocks.2.0.bn3.weight:[35]***blocks.2.0.bn3.bias:[35]***blocks.2.0.bn3.running_mean:[35]***blocks.2.0.bn3.running_var:[35]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[61, 35, 1, 1]***blocks.2.1.bn1.weight:[61]***blocks.2.1.bn1.bias:[61]***blocks.2.1.bn1.running_mean:[61]***blocks.2.1.bn1.running_var:[61]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[61, 1, 5, 5]***blocks.2.1.bn2.weight:[61]***blocks.2.1.bn2.bias:[61]***blocks.2.1.bn2.running_mean:[61]***blocks.2.1.bn2.running_var:[61]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[10, 61, 1, 1]***blocks.2.1.se.conv_reduce.bias:[10]***blocks.2.1.se.conv_expand.weight:[61, 10, 1, 1]***blocks.2.1.se.conv_expand.bias:[61]***blocks.2.1.conv_pwl.weight:[35, 61, 1, 1]***blocks.2.1.bn3.weight:[35]***blocks.2.1.bn3.bias:[35]***blocks.2.1.bn3.running_mean:[35]***blocks.2.1.bn3.running_var:[35]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[51, 35, 1, 1]***blocks.2.2.bn1.weight:[51]***blocks.2.2.bn1.bias:[51]***blocks.2.2.bn1.running_mean:[51]***blocks.2.2.bn1.running_var:[51]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[51, 1, 5, 5]***blocks.2.2.bn2.weight:[51]***blocks.2.2.bn2.bias:[51]***blocks.2.2.bn2.running_mean:[51]***blocks.2.2.bn2.running_var:[51]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[10, 51, 1, 1]***blocks.2.2.se.conv_reduce.bias:[10]***blocks.2.2.se.conv_expand.weight:[51, 10, 1, 1]***blocks.2.2.se.conv_expand.bias:[51]***blocks.2.2.conv_pwl.weight:[35, 51, 1, 1]***blocks.2.2.bn3.weight:[35]***blocks.2.2.bn3.bias:[35]***blocks.2.2.bn3.running_mean:[35]***blocks.2.2.bn3.running_var:[35]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[175, 35, 1, 1]***blocks.3.0.bn1.weight:[175]***blocks.3.0.bn1.bias:[175]***blocks.3.0.bn1.running_mean:[175]***blocks.3.0.bn1.running_var:[175]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[175, 1, 3, 3]***blocks.3.0.bn2.weight:[175]***blocks.3.0.bn2.bias:[175]***blocks.3.0.bn2.running_mean:[175]***blocks.3.0.bn2.running_var:[175]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[10, 175, 1, 1]***blocks.3.0.se.conv_reduce.bias:[10]***blocks.3.0.se.conv_expand.weight:[175, 10, 1, 1]***blocks.3.0.se.conv_expand.bias:[175]***blocks.3.0.conv_pwl.weight:[74, 175, 1, 1]***blocks.3.0.bn3.weight:[74]***blocks.3.0.bn3.bias:[74]***blocks.3.0.bn3.running_mean:[74]***blocks.3.0.bn3.running_var:[74]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[188, 74, 1, 1]***blocks.3.1.bn1.weight:[188]***blocks.3.1.bn1.bias:[188]***blocks.3.1.bn1.running_mean:[188]***blocks.3.1.bn1.running_var:[188]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[188, 1, 3, 3]***blocks.3.1.bn2.weight:[188]***blocks.3.1.bn2.bias:[188]***blocks.3.1.bn2.running_mean:[188]***blocks.3.1.bn2.running_var:[188]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[20, 188, 1, 1]***blocks.3.1.se.conv_reduce.bias:[20]***blocks.3.1.se.conv_expand.weight:[188, 20, 1, 1]***blocks.3.1.se.conv_expand.bias:[188]***blocks.3.1.conv_pwl.weight:[74, 188, 1, 1]***blocks.3.1.bn3.weight:[74]***blocks.3.1.bn3.bias:[74]***blocks.3.1.bn3.running_mean:[74]***blocks.3.1.bn3.running_var:[74]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[137, 74, 1, 1]***blocks.3.2.bn1.weight:[137]***blocks.3.2.bn1.bias:[137]***blocks.3.2.bn1.running_mean:[137]***blocks.3.2.bn1.running_var:[137]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[137, 1, 3, 3]***blocks.3.2.bn2.weight:[137]***blocks.3.2.bn2.bias:[137]***blocks.3.2.bn2.running_mean:[137]***blocks.3.2.bn2.running_var:[137]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[20, 137, 1, 1]***blocks.3.2.se.conv_reduce.bias:[20]***blocks.3.2.se.conv_expand.weight:[137, 20, 1, 1]***blocks.3.2.se.conv_expand.bias:[137]***blocks.3.2.conv_pwl.weight:[74, 137, 1, 1]***blocks.3.2.bn3.weight:[74]***blocks.3.2.bn3.bias:[74]***blocks.3.2.bn3.running_mean:[74]***blocks.3.2.bn3.running_var:[74]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[164, 74, 1, 1]***blocks.3.3.bn1.weight:[164]***blocks.3.3.bn1.bias:[164]***blocks.3.3.bn1.running_mean:[164]***blocks.3.3.bn1.running_var:[164]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[164, 1, 3, 3]***blocks.3.3.bn2.weight:[164]***blocks.3.3.bn2.bias:[164]***blocks.3.3.bn2.running_mean:[164]***blocks.3.3.bn2.running_var:[164]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[20, 164, 1, 1]***blocks.3.3.se.conv_reduce.bias:[20]***blocks.3.3.se.conv_expand.weight:[164, 20, 1, 1]***blocks.3.3.se.conv_expand.bias:[164]***blocks.3.3.conv_pwl.weight:[74, 164, 1, 1]***blocks.3.3.bn3.weight:[74]***blocks.3.3.bn3.bias:[74]***blocks.3.3.bn3.running_mean:[74]***blocks.3.3.bn3.running_var:[74]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[399, 74, 1, 1]***blocks.4.0.bn1.weight:[399]***blocks.4.0.bn1.bias:[399]***blocks.4.0.bn1.running_mean:[399]***blocks.4.0.bn1.running_var:[399]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[399, 1, 5, 5]***blocks.4.0.bn2.weight:[399]***blocks.4.0.bn2.bias:[399]***blocks.4.0.bn2.running_mean:[399]***blocks.4.0.bn2.running_var:[399]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[20, 399, 1, 1]***blocks.4.0.se.conv_reduce.bias:[20]***blocks.4.0.se.conv_expand.weight:[399, 20, 1, 1]***blocks.4.0.se.conv_expand.bias:[399]***blocks.4.0.conv_pwl.weight:[67, 399, 1, 1]***blocks.4.0.bn3.weight:[67]***blocks.4.0.bn3.bias:[67]***blocks.4.0.bn3.running_mean:[67]***blocks.4.0.bn3.running_var:[67]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[201, 67, 1, 1]***blocks.4.1.bn1.weight:[201]***blocks.4.1.bn1.bias:[201]***blocks.4.1.bn1.running_mean:[201]***blocks.4.1.bn1.running_var:[201]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[201, 1, 5, 5]***blocks.4.1.bn2.weight:[201]***blocks.4.1.bn2.bias:[201]***blocks.4.1.bn2.running_mean:[201]***blocks.4.1.bn2.running_var:[201]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[28, 201, 1, 1]***blocks.4.1.se.conv_reduce.bias:[28]***blocks.4.1.se.conv_expand.weight:[201, 28, 1, 1]***blocks.4.1.se.conv_expand.bias:[201]***blocks.4.1.conv_pwl.weight:[67, 201, 1, 1]***blocks.4.1.bn3.weight:[67]***blocks.4.1.bn3.bias:[67]***blocks.4.1.bn3.running_mean:[67]***blocks.4.1.bn3.running_var:[67]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[160, 67, 1, 1]***blocks.4.2.bn1.weight:[160]***blocks.4.2.bn1.bias:[160]***blocks.4.2.bn1.running_mean:[160]***blocks.4.2.bn1.running_var:[160]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[160, 1, 5, 5]***blocks.4.2.bn2.weight:[160]***blocks.4.2.bn2.bias:[160]***blocks.4.2.bn2.running_mean:[160]***blocks.4.2.bn2.running_var:[160]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[28, 160, 1, 1]***blocks.4.2.se.conv_reduce.bias:[28]***blocks.4.2.se.conv_expand.weight:[160, 28, 1, 1]***blocks.4.2.se.conv_expand.bias:[160]***blocks.4.2.conv_pwl.weight:[67, 160, 1, 1]***blocks.4.2.bn3.weight:[67]***blocks.4.2.bn3.bias:[67]***blocks.4.2.bn3.running_mean:[67]***blocks.4.2.bn3.running_var:[67]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[213, 67, 1, 1]***blocks.4.3.bn1.weight:[213]***blocks.4.3.bn1.bias:[213]***blocks.4.3.bn1.running_mean:[213]***blocks.4.3.bn1.running_var:[213]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[213, 1, 5, 5]***blocks.4.3.bn2.weight:[213]***blocks.4.3.bn2.bias:[213]***blocks.4.3.bn2.running_mean:[213]***blocks.4.3.bn2.running_var:[213]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[28, 213, 1, 1]***blocks.4.3.se.conv_reduce.bias:[28]***blocks.4.3.se.conv_expand.weight:[213, 28, 1, 1]***blocks.4.3.se.conv_expand.bias:[213]***blocks.4.3.conv_pwl.weight:[67, 213, 1, 1]***blocks.4.3.bn3.weight:[67]***blocks.4.3.bn3.bias:[67]***blocks.4.3.bn3.running_mean:[67]***blocks.4.3.bn3.running_var:[67]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[637, 67, 1, 1]***blocks.5.0.bn1.weight:[637]***blocks.5.0.bn1.bias:[637]***blocks.5.0.bn1.running_mean:[637]***blocks.5.0.bn1.running_var:[637]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[637, 1, 5, 5]***blocks.5.0.bn2.weight:[637]***blocks.5.0.bn2.bias:[637]***blocks.5.0.bn2.running_mean:[637]***blocks.5.0.bn2.running_var:[637]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[27, 637, 1, 1]***blocks.5.0.se.conv_reduce.bias:[27]***blocks.5.0.se.conv_expand.weight:[637, 27, 1, 1]***blocks.5.0.se.conv_expand.bias:[637]***blocks.5.0.conv_pwl.weight:[192, 637, 1, 1]***blocks.5.0.bn3.weight:[192]***blocks.5.0.bn3.bias:[192]***blocks.5.0.bn3.running_mean:[192]***blocks.5.0.bn3.running_var:[192]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[806, 192, 1, 1]***blocks.5.1.bn1.weight:[806]***blocks.5.1.bn1.bias:[806]***blocks.5.1.bn1.running_mean:[806]***blocks.5.1.bn1.running_var:[806]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[806, 1, 5, 5]***blocks.5.1.bn2.weight:[806]***blocks.5.1.bn2.bias:[806]***blocks.5.1.bn2.running_mean:[806]***blocks.5.1.bn2.running_var:[806]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[48, 806, 1, 1]***blocks.5.1.se.conv_reduce.bias:[48]***blocks.5.1.se.conv_expand.weight:[806, 48, 1, 1]***blocks.5.1.se.conv_expand.bias:[806]***blocks.5.1.conv_pwl.weight:[192, 806, 1, 1]***blocks.5.1.bn3.weight:[192]***blocks.5.1.bn3.bias:[192]***blocks.5.1.bn3.running_mean:[192]***blocks.5.1.bn3.running_var:[192]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[798, 192, 1, 1]***blocks.5.2.bn1.weight:[798]***blocks.5.2.bn1.bias:[798]***blocks.5.2.bn1.running_mean:[798]***blocks.5.2.bn1.running_var:[798]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[798, 1, 5, 5]***blocks.5.2.bn2.weight:[798]***blocks.5.2.bn2.bias:[798]***blocks.5.2.bn2.running_mean:[798]***blocks.5.2.bn2.running_var:[798]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[48, 798, 1, 1]***blocks.5.2.se.conv_reduce.bias:[48]***blocks.5.2.se.conv_expand.weight:[798, 48, 1, 1]***blocks.5.2.se.conv_expand.bias:[798]***blocks.5.2.conv_pwl.weight:[192, 798, 1, 1]***blocks.5.2.bn3.weight:[192]***blocks.5.2.bn3.bias:[192]***blocks.5.2.bn3.running_mean:[192]***blocks.5.2.bn3.running_var:[192]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[891, 192, 1, 1]***blocks.5.3.bn1.weight:[891]***blocks.5.3.bn1.bias:[891]***blocks.5.3.bn1.running_mean:[891]***blocks.5.3.bn1.running_var:[891]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[891, 1, 5, 5]***blocks.5.3.bn2.weight:[891]***blocks.5.3.bn2.bias:[891]***blocks.5.3.bn2.running_mean:[891]***blocks.5.3.bn2.running_var:[891]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[48, 891, 1, 1]***blocks.5.3.se.conv_reduce.bias:[48]***blocks.5.3.se.conv_expand.weight:[891, 48, 1, 1]***blocks.5.3.se.conv_expand.bias:[891]***blocks.5.3.conv_pwl.weight:[192, 891, 1, 1]***blocks.5.3.bn3.weight:[192]***blocks.5.3.bn3.bias:[192]***blocks.5.3.bn3.running_mean:[192]***blocks.5.3.bn3.running_var:[192]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[990, 192, 1, 1]***blocks.5.4.bn1.weight:[990]***blocks.5.4.bn1.bias:[990]***blocks.5.4.bn1.running_mean:[990]***blocks.5.4.bn1.running_var:[990]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[990, 1, 5, 5]***blocks.5.4.bn2.weight:[990]***blocks.5.4.bn2.bias:[990]***blocks.5.4.bn2.running_mean:[990]***blocks.5.4.bn2.running_var:[990]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[48, 990, 1, 1]***blocks.5.4.se.conv_reduce.bias:[48]***blocks.5.4.se.conv_expand.weight:[990, 48, 1, 1]***blocks.5.4.se.conv_expand.bias:[990]***blocks.5.4.conv_pwl.weight:[192, 990, 1, 1]***blocks.5.4.bn3.weight:[192]***blocks.5.4.bn3.bias:[192]***blocks.5.4.bn3.running_mean:[192]***blocks.5.4.bn3.running_var:[192]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1152, 192, 1, 1]***blocks.6.0.bn1.weight:[1152]***blocks.6.0.bn1.bias:[1152]***blocks.6.0.bn1.running_mean:[1152]***blocks.6.0.bn1.running_var:[1152]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1152, 1, 3, 3]***blocks.6.0.bn2.weight:[1152]***blocks.6.0.bn2.bias:[1152]***blocks.6.0.bn2.running_mean:[1152]***blocks.6.0.bn2.running_var:[1152]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[48, 1152, 1, 1]***blocks.6.0.se.conv_reduce.bias:[48]***blocks.6.0.se.conv_expand.weight:[1152, 48, 1, 1]***blocks.6.0.se.conv_expand.bias:[1152]***blocks.6.0.conv_pwl.weight:[320, 1152, 1, 1]***blocks.6.0.bn3.weight:[320]***blocks.6.0.bn3.bias:[320]***blocks.6.0.bn3.running_mean:[320]***blocks.6.0.bn3.running_var:[320]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[1912, 320, 1, 1]***blocks.6.1.bn1.weight:[1912]***blocks.6.1.bn1.bias:[1912]***blocks.6.1.bn1.running_mean:[1912]***blocks.6.1.bn1.running_var:[1912]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[1912, 1, 3, 3]***blocks.6.1.bn2.weight:[1912]***blocks.6.1.bn2.bias:[1912]***blocks.6.1.bn2.running_mean:[1912]***blocks.6.1.bn2.running_var:[1912]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[80, 1912, 1, 1]***blocks.6.1.se.conv_reduce.bias:[80]***blocks.6.1.se.conv_expand.weight:[1912, 80, 1, 1]***blocks.6.1.se.conv_expand.bias:[1912]***blocks.6.1.conv_pwl.weight:[320, 1912, 1, 1]***blocks.6.1.bn3.weight:[320]***blocks.6.1.bn3.bias:[320]***blocks.6.1.bn3.running_mean:[320]***blocks.6.1.bn3.running_var:[320]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1280, 320, 1, 1]***bn2.weight:[1280]***bn2.bias:[1280]***bn2.running_mean:[1280]***bn2.running_var:[1280]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1280]***classifier.bias:[1000] \ No newline at end of file diff --git a/timm/models/pruned/efficientnet_b2_pruned.txt b/timm/models/pruned/efficientnet_b2_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..6e3fadee3e9f92eaade96afd8691a5e4437551ee --- /dev/null +++ b/timm/models/pruned/efficientnet_b2_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]***bn1.bias:[32]***bn1.running_mean:[32]***bn1.running_var:[32]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[32, 1, 3, 3]***blocks.0.0.bn1.weight:[32]***blocks.0.0.bn1.bias:[32]***blocks.0.0.bn1.running_mean:[32]***blocks.0.0.bn1.running_var:[32]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[8, 32, 1, 1]***blocks.0.0.se.conv_reduce.bias:[8]***blocks.0.0.se.conv_expand.weight:[32, 8, 1, 1]***blocks.0.0.se.conv_expand.bias:[32]***blocks.0.0.conv_pw.weight:[16, 32, 1, 1]***blocks.0.0.bn2.weight:[16]***blocks.0.0.bn2.bias:[16]***blocks.0.0.bn2.running_mean:[16]***blocks.0.0.bn2.running_var:[16]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[16, 1, 3, 3]***blocks.0.1.bn1.weight:[16]***blocks.0.1.bn1.bias:[16]***blocks.0.1.bn1.running_mean:[16]***blocks.0.1.bn1.running_var:[16]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[4, 16, 1, 1]***blocks.0.1.se.conv_reduce.bias:[4]***blocks.0.1.se.conv_expand.weight:[16, 4, 1, 1]***blocks.0.1.se.conv_expand.bias:[16]***blocks.0.1.conv_pw.weight:[16, 16, 1, 1]***blocks.0.1.bn2.weight:[16]***blocks.0.1.bn2.bias:[16]***blocks.0.1.bn2.running_mean:[16]***blocks.0.1.bn2.running_var:[16]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[54, 16, 1, 1]***blocks.1.0.bn1.weight:[54]***blocks.1.0.bn1.bias:[54]***blocks.1.0.bn1.running_mean:[54]***blocks.1.0.bn1.running_var:[54]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[54, 1, 3, 3]***blocks.1.0.bn2.weight:[54]***blocks.1.0.bn2.bias:[54]***blocks.1.0.bn2.running_mean:[54]***blocks.1.0.bn2.running_var:[54]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[4, 54, 1, 1]***blocks.1.0.se.conv_reduce.bias:[4]***blocks.1.0.se.conv_expand.weight:[54, 4, 1, 1]***blocks.1.0.se.conv_expand.bias:[54]***blocks.1.0.conv_pwl.weight:[17, 54, 1, 1]***blocks.1.0.bn3.weight:[17]***blocks.1.0.bn3.bias:[17]***blocks.1.0.bn3.running_mean:[17]***blocks.1.0.bn3.running_var:[17]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[69, 17, 1, 1]***blocks.1.1.bn1.weight:[69]***blocks.1.1.bn1.bias:[69]***blocks.1.1.bn1.running_mean:[69]***blocks.1.1.bn1.running_var:[69]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[69, 1, 3, 3]***blocks.1.1.bn2.weight:[69]***blocks.1.1.bn2.bias:[69]***blocks.1.1.bn2.running_mean:[69]***blocks.1.1.bn2.running_var:[69]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[6, 69, 1, 1]***blocks.1.1.se.conv_reduce.bias:[6]***blocks.1.1.se.conv_expand.weight:[69, 6, 1, 1]***blocks.1.1.se.conv_expand.bias:[69]***blocks.1.1.conv_pwl.weight:[17, 69, 1, 1]***blocks.1.1.bn3.weight:[17]***blocks.1.1.bn3.bias:[17]***blocks.1.1.bn3.running_mean:[17]***blocks.1.1.bn3.running_var:[17]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[61, 17, 1, 1]***blocks.1.2.bn1.weight:[61]***blocks.1.2.bn1.bias:[61]***blocks.1.2.bn1.running_mean:[61]***blocks.1.2.bn1.running_var:[61]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[61, 1, 3, 3]***blocks.1.2.bn2.weight:[61]***blocks.1.2.bn2.bias:[61]***blocks.1.2.bn2.running_mean:[61]***blocks.1.2.bn2.running_var:[61]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[6, 61, 1, 1]***blocks.1.2.se.conv_reduce.bias:[6]***blocks.1.2.se.conv_expand.weight:[61, 6, 1, 1]***blocks.1.2.se.conv_expand.bias:[61]***blocks.1.2.conv_pwl.weight:[17, 61, 1, 1]***blocks.1.2.bn3.weight:[17]***blocks.1.2.bn3.bias:[17]***blocks.1.2.bn3.running_mean:[17]***blocks.1.2.bn3.running_var:[17]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[86, 17, 1, 1]***blocks.2.0.bn1.weight:[86]***blocks.2.0.bn1.bias:[86]***blocks.2.0.bn1.running_mean:[86]***blocks.2.0.bn1.running_var:[86]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[86, 1, 5, 5]***blocks.2.0.bn2.weight:[86]***blocks.2.0.bn2.bias:[86]***blocks.2.0.bn2.running_mean:[86]***blocks.2.0.bn2.running_var:[86]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[6, 86, 1, 1]***blocks.2.0.se.conv_reduce.bias:[6]***blocks.2.0.se.conv_expand.weight:[86, 6, 1, 1]***blocks.2.0.se.conv_expand.bias:[86]***blocks.2.0.conv_pwl.weight:[42, 86, 1, 1]***blocks.2.0.bn3.weight:[42]***blocks.2.0.bn3.bias:[42]***blocks.2.0.bn3.running_mean:[42]***blocks.2.0.bn3.running_var:[42]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[72, 42, 1, 1]***blocks.2.1.bn1.weight:[72]***blocks.2.1.bn1.bias:[72]***blocks.2.1.bn1.running_mean:[72]***blocks.2.1.bn1.running_var:[72]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[72, 1, 5, 5]***blocks.2.1.bn2.weight:[72]***blocks.2.1.bn2.bias:[72]***blocks.2.1.bn2.running_mean:[72]***blocks.2.1.bn2.running_var:[72]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[12, 72, 1, 1]***blocks.2.1.se.conv_reduce.bias:[12]***blocks.2.1.se.conv_expand.weight:[72, 12, 1, 1]***blocks.2.1.se.conv_expand.bias:[72]***blocks.2.1.conv_pwl.weight:[42, 72, 1, 1]***blocks.2.1.bn3.weight:[42]***blocks.2.1.bn3.bias:[42]***blocks.2.1.bn3.running_mean:[42]***blocks.2.1.bn3.running_var:[42]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[98, 42, 1, 1]***blocks.2.2.bn1.weight:[98]***blocks.2.2.bn1.bias:[98]***blocks.2.2.bn1.running_mean:[98]***blocks.2.2.bn1.running_var:[98]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[98, 1, 5, 5]***blocks.2.2.bn2.weight:[98]***blocks.2.2.bn2.bias:[98]***blocks.2.2.bn2.running_mean:[98]***blocks.2.2.bn2.running_var:[98]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[12, 98, 1, 1]***blocks.2.2.se.conv_reduce.bias:[12]***blocks.2.2.se.conv_expand.weight:[98, 12, 1, 1]***blocks.2.2.se.conv_expand.bias:[98]***blocks.2.2.conv_pwl.weight:[42, 98, 1, 1]***blocks.2.2.bn3.weight:[42]***blocks.2.2.bn3.bias:[42]***blocks.2.2.bn3.running_mean:[42]***blocks.2.2.bn3.running_var:[42]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[245, 42, 1, 1]***blocks.3.0.bn1.weight:[245]***blocks.3.0.bn1.bias:[245]***blocks.3.0.bn1.running_mean:[245]***blocks.3.0.bn1.running_var:[245]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[245, 1, 3, 3]***blocks.3.0.bn2.weight:[245]***blocks.3.0.bn2.bias:[245]***blocks.3.0.bn2.running_mean:[245]***blocks.3.0.bn2.running_var:[245]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[12, 245, 1, 1]***blocks.3.0.se.conv_reduce.bias:[12]***blocks.3.0.se.conv_expand.weight:[245, 12, 1, 1]***blocks.3.0.se.conv_expand.bias:[245]***blocks.3.0.conv_pwl.weight:[85, 245, 1, 1]***blocks.3.0.bn3.weight:[85]***blocks.3.0.bn3.bias:[85]***blocks.3.0.bn3.running_mean:[85]***blocks.3.0.bn3.running_var:[85]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[274, 85, 1, 1]***blocks.3.1.bn1.weight:[274]***blocks.3.1.bn1.bias:[274]***blocks.3.1.bn1.running_mean:[274]***blocks.3.1.bn1.running_var:[274]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[274, 1, 3, 3]***blocks.3.1.bn2.weight:[274]***blocks.3.1.bn2.bias:[274]***blocks.3.1.bn2.running_mean:[274]***blocks.3.1.bn2.running_var:[274]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[22, 274, 1, 1]***blocks.3.1.se.conv_reduce.bias:[22]***blocks.3.1.se.conv_expand.weight:[274, 22, 1, 1]***blocks.3.1.se.conv_expand.bias:[274]***blocks.3.1.conv_pwl.weight:[85, 274, 1, 1]***blocks.3.1.bn3.weight:[85]***blocks.3.1.bn3.bias:[85]***blocks.3.1.bn3.running_mean:[85]***blocks.3.1.bn3.running_var:[85]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[254, 85, 1, 1]***blocks.3.2.bn1.weight:[254]***blocks.3.2.bn1.bias:[254]***blocks.3.2.bn1.running_mean:[254]***blocks.3.2.bn1.running_var:[254]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[254, 1, 3, 3]***blocks.3.2.bn2.weight:[254]***blocks.3.2.bn2.bias:[254]***blocks.3.2.bn2.running_mean:[254]***blocks.3.2.bn2.running_var:[254]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[22, 254, 1, 1]***blocks.3.2.se.conv_reduce.bias:[22]***blocks.3.2.se.conv_expand.weight:[254, 22, 1, 1]***blocks.3.2.se.conv_expand.bias:[254]***blocks.3.2.conv_pwl.weight:[85, 254, 1, 1]***blocks.3.2.bn3.weight:[85]***blocks.3.2.bn3.bias:[85]***blocks.3.2.bn3.running_mean:[85]***blocks.3.2.bn3.running_var:[85]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[292, 85, 1, 1]***blocks.3.3.bn1.weight:[292]***blocks.3.3.bn1.bias:[292]***blocks.3.3.bn1.running_mean:[292]***blocks.3.3.bn1.running_var:[292]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[292, 1, 3, 3]***blocks.3.3.bn2.weight:[292]***blocks.3.3.bn2.bias:[292]***blocks.3.3.bn2.running_mean:[292]***blocks.3.3.bn2.running_var:[292]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[22, 292, 1, 1]***blocks.3.3.se.conv_reduce.bias:[22]***blocks.3.3.se.conv_expand.weight:[292, 22, 1, 1]***blocks.3.3.se.conv_expand.bias:[292]***blocks.3.3.conv_pwl.weight:[85, 292, 1, 1]***blocks.3.3.bn3.weight:[85]***blocks.3.3.bn3.bias:[85]***blocks.3.3.bn3.running_mean:[85]***blocks.3.3.bn3.running_var:[85]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[502, 85, 1, 1]***blocks.4.0.bn1.weight:[502]***blocks.4.0.bn1.bias:[502]***blocks.4.0.bn1.running_mean:[502]***blocks.4.0.bn1.running_var:[502]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[502, 1, 5, 5]***blocks.4.0.bn2.weight:[502]***blocks.4.0.bn2.bias:[502]***blocks.4.0.bn2.running_mean:[502]***blocks.4.0.bn2.running_var:[502]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[22, 502, 1, 1]***blocks.4.0.se.conv_reduce.bias:[22]***blocks.4.0.se.conv_expand.weight:[502, 22, 1, 1]***blocks.4.0.se.conv_expand.bias:[502]***blocks.4.0.conv_pwl.weight:[116, 502, 1, 1]***blocks.4.0.bn3.weight:[116]***blocks.4.0.bn3.bias:[116]***blocks.4.0.bn3.running_mean:[116]***blocks.4.0.bn3.running_var:[116]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[315, 116, 1, 1]***blocks.4.1.bn1.weight:[315]***blocks.4.1.bn1.bias:[315]***blocks.4.1.bn1.running_mean:[315]***blocks.4.1.bn1.running_var:[315]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[315, 1, 5, 5]***blocks.4.1.bn2.weight:[315]***blocks.4.1.bn2.bias:[315]***blocks.4.1.bn2.running_mean:[315]***blocks.4.1.bn2.running_var:[315]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[30, 315, 1, 1]***blocks.4.1.se.conv_reduce.bias:[30]***blocks.4.1.se.conv_expand.weight:[315, 30, 1, 1]***blocks.4.1.se.conv_expand.bias:[315]***blocks.4.1.conv_pwl.weight:[116, 315, 1, 1]***blocks.4.1.bn3.weight:[116]***blocks.4.1.bn3.bias:[116]***blocks.4.1.bn3.running_mean:[116]***blocks.4.1.bn3.running_var:[116]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[354, 116, 1, 1]***blocks.4.2.bn1.weight:[354]***blocks.4.2.bn1.bias:[354]***blocks.4.2.bn1.running_mean:[354]***blocks.4.2.bn1.running_var:[354]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[354, 1, 5, 5]***blocks.4.2.bn2.weight:[354]***blocks.4.2.bn2.bias:[354]***blocks.4.2.bn2.running_mean:[354]***blocks.4.2.bn2.running_var:[354]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[30, 354, 1, 1]***blocks.4.2.se.conv_reduce.bias:[30]***blocks.4.2.se.conv_expand.weight:[354, 30, 1, 1]***blocks.4.2.se.conv_expand.bias:[354]***blocks.4.2.conv_pwl.weight:[116, 354, 1, 1]***blocks.4.2.bn3.weight:[116]***blocks.4.2.bn3.bias:[116]***blocks.4.2.bn3.running_mean:[116]***blocks.4.2.bn3.running_var:[116]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[443, 116, 1, 1]***blocks.4.3.bn1.weight:[443]***blocks.4.3.bn1.bias:[443]***blocks.4.3.bn1.running_mean:[443]***blocks.4.3.bn1.running_var:[443]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[443, 1, 5, 5]***blocks.4.3.bn2.weight:[443]***blocks.4.3.bn2.bias:[443]***blocks.4.3.bn2.running_mean:[443]***blocks.4.3.bn2.running_var:[443]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[30, 443, 1, 1]***blocks.4.3.se.conv_reduce.bias:[30]***blocks.4.3.se.conv_expand.weight:[443, 30, 1, 1]***blocks.4.3.se.conv_expand.bias:[443]***blocks.4.3.conv_pwl.weight:[116, 443, 1, 1]***blocks.4.3.bn3.weight:[116]***blocks.4.3.bn3.bias:[116]***blocks.4.3.bn3.running_mean:[116]***blocks.4.3.bn3.running_var:[116]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[719, 116, 1, 1]***blocks.5.0.bn1.weight:[719]***blocks.5.0.bn1.bias:[719]***blocks.5.0.bn1.running_mean:[719]***blocks.5.0.bn1.running_var:[719]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[719, 1, 5, 5]***blocks.5.0.bn2.weight:[719]***blocks.5.0.bn2.bias:[719]***blocks.5.0.bn2.running_mean:[719]***blocks.5.0.bn2.running_var:[719]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[30, 719, 1, 1]***blocks.5.0.se.conv_reduce.bias:[30]***blocks.5.0.se.conv_expand.weight:[719, 30, 1, 1]***blocks.5.0.se.conv_expand.bias:[719]***blocks.5.0.conv_pwl.weight:[208, 719, 1, 1]***blocks.5.0.bn3.weight:[208]***blocks.5.0.bn3.bias:[208]***blocks.5.0.bn3.running_mean:[208]***blocks.5.0.bn3.running_var:[208]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[1148, 208, 1, 1]***blocks.5.1.bn1.weight:[1148]***blocks.5.1.bn1.bias:[1148]***blocks.5.1.bn1.running_mean:[1148]***blocks.5.1.bn1.running_var:[1148]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[1148, 1, 5, 5]***blocks.5.1.bn2.weight:[1148]***blocks.5.1.bn2.bias:[1148]***blocks.5.1.bn2.running_mean:[1148]***blocks.5.1.bn2.running_var:[1148]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[52, 1148, 1, 1]***blocks.5.1.se.conv_reduce.bias:[52]***blocks.5.1.se.conv_expand.weight:[1148, 52, 1, 1]***blocks.5.1.se.conv_expand.bias:[1148]***blocks.5.1.conv_pwl.weight:[208, 1148, 1, 1]***blocks.5.1.bn3.weight:[208]***blocks.5.1.bn3.bias:[208]***blocks.5.1.bn3.running_mean:[208]***blocks.5.1.bn3.running_var:[208]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[1160, 208, 1, 1]***blocks.5.2.bn1.weight:[1160]***blocks.5.2.bn1.bias:[1160]***blocks.5.2.bn1.running_mean:[1160]***blocks.5.2.bn1.running_var:[1160]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[1160, 1, 5, 5]***blocks.5.2.bn2.weight:[1160]***blocks.5.2.bn2.bias:[1160]***blocks.5.2.bn2.running_mean:[1160]***blocks.5.2.bn2.running_var:[1160]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[52, 1160, 1, 1]***blocks.5.2.se.conv_reduce.bias:[52]***blocks.5.2.se.conv_expand.weight:[1160, 52, 1, 1]***blocks.5.2.se.conv_expand.bias:[1160]***blocks.5.2.conv_pwl.weight:[208, 1160, 1, 1]***blocks.5.2.bn3.weight:[208]***blocks.5.2.bn3.bias:[208]***blocks.5.2.bn3.running_mean:[208]***blocks.5.2.bn3.running_var:[208]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[1182, 208, 1, 1]***blocks.5.3.bn1.weight:[1182]***blocks.5.3.bn1.bias:[1182]***blocks.5.3.bn1.running_mean:[1182]***blocks.5.3.bn1.running_var:[1182]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[1182, 1, 5, 5]***blocks.5.3.bn2.weight:[1182]***blocks.5.3.bn2.bias:[1182]***blocks.5.3.bn2.running_mean:[1182]***blocks.5.3.bn2.running_var:[1182]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[52, 1182, 1, 1]***blocks.5.3.se.conv_reduce.bias:[52]***blocks.5.3.se.conv_expand.weight:[1182, 52, 1, 1]***blocks.5.3.se.conv_expand.bias:[1182]***blocks.5.3.conv_pwl.weight:[208, 1182, 1, 1]***blocks.5.3.bn3.weight:[208]***blocks.5.3.bn3.bias:[208]***blocks.5.3.bn3.running_mean:[208]***blocks.5.3.bn3.running_var:[208]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[1228, 208, 1, 1]***blocks.5.4.bn1.weight:[1228]***blocks.5.4.bn1.bias:[1228]***blocks.5.4.bn1.running_mean:[1228]***blocks.5.4.bn1.running_var:[1228]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[1228, 1, 5, 5]***blocks.5.4.bn2.weight:[1228]***blocks.5.4.bn2.bias:[1228]***blocks.5.4.bn2.running_mean:[1228]***blocks.5.4.bn2.running_var:[1228]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[52, 1228, 1, 1]***blocks.5.4.se.conv_reduce.bias:[52]***blocks.5.4.se.conv_expand.weight:[1228, 52, 1, 1]***blocks.5.4.se.conv_expand.bias:[1228]***blocks.5.4.conv_pwl.weight:[208, 1228, 1, 1]***blocks.5.4.bn3.weight:[208]***blocks.5.4.bn3.bias:[208]***blocks.5.4.bn3.running_mean:[208]***blocks.5.4.bn3.running_var:[208]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1248, 208, 1, 1]***blocks.6.0.bn1.weight:[1248]***blocks.6.0.bn1.bias:[1248]***blocks.6.0.bn1.running_mean:[1248]***blocks.6.0.bn1.running_var:[1248]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1248, 1, 3, 3]***blocks.6.0.bn2.weight:[1248]***blocks.6.0.bn2.bias:[1248]***blocks.6.0.bn2.running_mean:[1248]***blocks.6.0.bn2.running_var:[1248]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[52, 1248, 1, 1]***blocks.6.0.se.conv_reduce.bias:[52]***blocks.6.0.se.conv_expand.weight:[1248, 52, 1, 1]***blocks.6.0.se.conv_expand.bias:[1248]***blocks.6.0.conv_pwl.weight:[352, 1248, 1, 1]***blocks.6.0.bn3.weight:[352]***blocks.6.0.bn3.bias:[352]***blocks.6.0.bn3.running_mean:[352]***blocks.6.0.bn3.running_var:[352]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[2112, 352, 1, 1]***blocks.6.1.bn1.weight:[2112]***blocks.6.1.bn1.bias:[2112]***blocks.6.1.bn1.running_mean:[2112]***blocks.6.1.bn1.running_var:[2112]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[2112, 1, 3, 3]***blocks.6.1.bn2.weight:[2112]***blocks.6.1.bn2.bias:[2112]***blocks.6.1.bn2.running_mean:[2112]***blocks.6.1.bn2.running_var:[2112]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[88, 2112, 1, 1]***blocks.6.1.se.conv_reduce.bias:[88]***blocks.6.1.se.conv_expand.weight:[2112, 88, 1, 1]***blocks.6.1.se.conv_expand.bias:[2112]***blocks.6.1.conv_pwl.weight:[352, 2112, 1, 1]***blocks.6.1.bn3.weight:[352]***blocks.6.1.bn3.bias:[352]***blocks.6.1.bn3.running_mean:[352]***blocks.6.1.bn3.running_var:[352]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1408, 352, 1, 1]***bn2.weight:[1408]***bn2.bias:[1408]***bn2.running_mean:[1408]***bn2.running_var:[1408]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1408]***classifier.bias:[1000] \ No newline at end of file diff --git a/timm/models/pruned/efficientnet_b3_pruned.txt b/timm/models/pruned/efficientnet_b3_pruned.txt new file mode 100644 index 0000000000000000000000000000000000000000..489781736de08e5cf40bf76528a735fff4a3f61c --- /dev/null +++ b/timm/models/pruned/efficientnet_b3_pruned.txt @@ -0,0 +1 @@ +conv_stem.weight:[40, 3, 3, 3]***bn1.weight:[40]***bn1.bias:[40]***bn1.running_mean:[40]***bn1.running_var:[40]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[40, 1, 3, 3]***blocks.0.0.bn1.weight:[40]***blocks.0.0.bn1.bias:[40]***blocks.0.0.bn1.running_mean:[40]***blocks.0.0.bn1.running_var:[40]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[10, 40, 1, 1]***blocks.0.0.se.conv_reduce.bias:[10]***blocks.0.0.se.conv_expand.weight:[40, 10, 1, 1]***blocks.0.0.se.conv_expand.bias:[40]***blocks.0.0.conv_pw.weight:[24, 40, 1, 1]***blocks.0.0.bn2.weight:[24]***blocks.0.0.bn2.bias:[24]***blocks.0.0.bn2.running_mean:[24]***blocks.0.0.bn2.running_var:[24]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[24, 1, 3, 3]***blocks.0.1.bn1.weight:[24]***blocks.0.1.bn1.bias:[24]***blocks.0.1.bn1.running_mean:[24]***blocks.0.1.bn1.running_var:[24]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[6, 24, 1, 1]***blocks.0.1.se.conv_reduce.bias:[6]***blocks.0.1.se.conv_expand.weight:[24, 6, 1, 1]***blocks.0.1.se.conv_expand.bias:[24]***blocks.0.1.conv_pw.weight:[24, 24, 1, 1]***blocks.0.1.bn2.weight:[24]***blocks.0.1.bn2.bias:[24]***blocks.0.1.bn2.running_mean:[24]***blocks.0.1.bn2.running_var:[24]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[27, 24, 1, 1]***blocks.1.0.bn1.weight:[27]***blocks.1.0.bn1.bias:[27]***blocks.1.0.bn1.running_mean:[27]***blocks.1.0.bn1.running_var:[27]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[27, 1, 3, 3]***blocks.1.0.bn2.weight:[27]***blocks.1.0.bn2.bias:[27]***blocks.1.0.bn2.running_mean:[27]***blocks.1.0.bn2.running_var:[27]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[6, 27, 1, 1]***blocks.1.0.se.conv_reduce.bias:[6]***blocks.1.0.se.conv_expand.weight:[27, 6, 1, 1]***blocks.1.0.se.conv_expand.bias:[27]***blocks.1.0.conv_pwl.weight:[12, 27, 1, 1]***blocks.1.0.bn3.weight:[12]***blocks.1.0.bn3.bias:[12]***blocks.1.0.bn3.running_mean:[12]***blocks.1.0.bn3.running_var:[12]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[49, 12, 1, 1]***blocks.1.1.bn1.weight:[49]***blocks.1.1.bn1.bias:[49]***blocks.1.1.bn1.running_mean:[49]***blocks.1.1.bn1.running_var:[49]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[49, 1, 3, 3]***blocks.1.1.bn2.weight:[49]***blocks.1.1.bn2.bias:[49]***blocks.1.1.bn2.running_mean:[49]***blocks.1.1.bn2.running_var:[49]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[8, 49, 1, 1]***blocks.1.1.se.conv_reduce.bias:[8]***blocks.1.1.se.conv_expand.weight:[49, 8, 1, 1]***blocks.1.1.se.conv_expand.bias:[49]***blocks.1.1.conv_pwl.weight:[12, 49, 1, 1]***blocks.1.1.bn3.weight:[12]***blocks.1.1.bn3.bias:[12]***blocks.1.1.bn3.running_mean:[12]***blocks.1.1.bn3.running_var:[12]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[48, 12, 1, 1]***blocks.1.2.bn1.weight:[48]***blocks.1.2.bn1.bias:[48]***blocks.1.2.bn1.running_mean:[48]***blocks.1.2.bn1.running_var:[48]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[48, 1, 3, 3]***blocks.1.2.bn2.weight:[48]***blocks.1.2.bn2.bias:[48]***blocks.1.2.bn2.running_mean:[48]***blocks.1.2.bn2.running_var:[48]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[8, 48, 1, 1]***blocks.1.2.se.conv_reduce.bias:[8]***blocks.1.2.se.conv_expand.weight:[48, 8, 1, 1]***blocks.1.2.se.conv_expand.bias:[48]***blocks.1.2.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.2.bn3.weight:[12]***blocks.1.2.bn3.bias:[12]***blocks.1.2.bn3.running_mean:[12]***blocks.1.2.bn3.running_var:[12]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[83, 12, 1, 1]***blocks.2.0.bn1.weight:[83]***blocks.2.0.bn1.bias:[83]***blocks.2.0.bn1.running_mean:[83]***blocks.2.0.bn1.running_var:[83]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[83, 1, 5, 5]***blocks.2.0.bn2.weight:[83]***blocks.2.0.bn2.bias:[83]***blocks.2.0.bn2.running_mean:[83]***blocks.2.0.bn2.running_var:[83]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[8, 83, 1, 1]***blocks.2.0.se.conv_reduce.bias:[8]***blocks.2.0.se.conv_expand.weight:[83, 8, 1, 1]***blocks.2.0.se.conv_expand.bias:[83]***blocks.2.0.conv_pwl.weight:[40, 83, 1, 1]***blocks.2.0.bn3.weight:[40]***blocks.2.0.bn3.bias:[40]***blocks.2.0.bn3.running_mean:[40]***blocks.2.0.bn3.running_var:[40]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[90, 40, 1, 1]***blocks.2.1.bn1.weight:[90]***blocks.2.1.bn1.bias:[90]***blocks.2.1.bn1.running_mean:[90]***blocks.2.1.bn1.running_var:[90]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[90, 1, 5, 5]***blocks.2.1.bn2.weight:[90]***blocks.2.1.bn2.bias:[90]***blocks.2.1.bn2.running_mean:[90]***blocks.2.1.bn2.running_var:[90]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[12, 90, 1, 1]***blocks.2.1.se.conv_reduce.bias:[12]***blocks.2.1.se.conv_expand.weight:[90, 12, 1, 1]***blocks.2.1.se.conv_expand.bias:[90]***blocks.2.1.conv_pwl.weight:[40, 90, 1, 1]***blocks.2.1.bn3.weight:[40]***blocks.2.1.bn3.bias:[40]***blocks.2.1.bn3.running_mean:[40]***blocks.2.1.bn3.running_var:[40]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[85, 40, 1, 1]***blocks.2.2.bn1.weight:[85]***blocks.2.2.bn1.bias:[85]***blocks.2.2.bn1.running_mean:[85]***blocks.2.2.bn1.running_var:[85]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[85, 1, 5, 5]***blocks.2.2.bn2.weight:[85]***blocks.2.2.bn2.bias:[85]***blocks.2.2.bn2.running_mean:[85]***blocks.2.2.bn2.running_var:[85]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[12, 85, 1, 1]***blocks.2.2.se.conv_reduce.bias:[12]***blocks.2.2.se.conv_expand.weight:[85, 12, 1, 1]***blocks.2.2.se.conv_expand.bias:[85]***blocks.2.2.conv_pwl.weight:[40, 85, 1, 1]***blocks.2.2.bn3.weight:[40]***blocks.2.2.bn3.bias:[40]***blocks.2.2.bn3.running_mean:[40]***blocks.2.2.bn3.running_var:[40]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[215, 40, 1, 1]***blocks.3.0.bn1.weight:[215]***blocks.3.0.bn1.bias:[215]***blocks.3.0.bn1.running_mean:[215]***blocks.3.0.bn1.running_var:[215]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[215, 1, 3, 3]***blocks.3.0.bn2.weight:[215]***blocks.3.0.bn2.bias:[215]***blocks.3.0.bn2.running_mean:[215]***blocks.3.0.bn2.running_var:[215]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[12, 215, 1, 1]***blocks.3.0.se.conv_reduce.bias:[12]***blocks.3.0.se.conv_expand.weight:[215, 12, 1, 1]***blocks.3.0.se.conv_expand.bias:[215]***blocks.3.0.conv_pwl.weight:[93, 215, 1, 1]***blocks.3.0.bn3.weight:[93]***blocks.3.0.bn3.bias:[93]***blocks.3.0.bn3.running_mean:[93]***blocks.3.0.bn3.running_var:[93]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[261, 93, 1, 1]***blocks.3.1.bn1.weight:[261]***blocks.3.1.bn1.bias:[261]***blocks.3.1.bn1.running_mean:[261]***blocks.3.1.bn1.running_var:[261]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[261, 1, 3, 3]***blocks.3.1.bn2.weight:[261]***blocks.3.1.bn2.bias:[261]***blocks.3.1.bn2.running_mean:[261]***blocks.3.1.bn2.running_var:[261]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[24, 261, 1, 1]***blocks.3.1.se.conv_reduce.bias:[24]***blocks.3.1.se.conv_expand.weight:[261, 24, 1, 1]***blocks.3.1.se.conv_expand.bias:[261]***blocks.3.1.conv_pwl.weight:[93, 261, 1, 1]***blocks.3.1.bn3.weight:[93]***blocks.3.1.bn3.bias:[93]***blocks.3.1.bn3.running_mean:[93]***blocks.3.1.bn3.running_var:[93]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[219, 93, 1, 1]***blocks.3.2.bn1.weight:[219]***blocks.3.2.bn1.bias:[219]***blocks.3.2.bn1.running_mean:[219]***blocks.3.2.bn1.running_var:[219]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[219, 1, 3, 3]***blocks.3.2.bn2.weight:[219]***blocks.3.2.bn2.bias:[219]***blocks.3.2.bn2.running_mean:[219]***blocks.3.2.bn2.running_var:[219]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[24, 219, 1, 1]***blocks.3.2.se.conv_reduce.bias:[24]***blocks.3.2.se.conv_expand.weight:[219, 24, 1, 1]***blocks.3.2.se.conv_expand.bias:[219]***blocks.3.2.conv_pwl.weight:[93, 219, 1, 1]***blocks.3.2.bn3.weight:[93]***blocks.3.2.bn3.bias:[93]***blocks.3.2.bn3.running_mean:[93]***blocks.3.2.bn3.running_var:[93]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[254, 93, 1, 1]***blocks.3.3.bn1.weight:[254]***blocks.3.3.bn1.bias:[254]***blocks.3.3.bn1.running_mean:[254]***blocks.3.3.bn1.running_var:[254]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[254, 1, 3, 3]***blocks.3.3.bn2.weight:[254]***blocks.3.3.bn2.bias:[254]***blocks.3.3.bn2.running_mean:[254]***blocks.3.3.bn2.running_var:[254]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[24, 254, 1, 1]***blocks.3.3.se.conv_reduce.bias:[24]***blocks.3.3.se.conv_expand.weight:[254, 24, 1, 1]***blocks.3.3.se.conv_expand.bias:[254]***blocks.3.3.conv_pwl.weight:[93, 254, 1, 1]***blocks.3.3.bn3.weight:[93]***blocks.3.3.bn3.bias:[93]***blocks.3.3.bn3.running_mean:[93]***blocks.3.3.bn3.running_var:[93]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.3.4.conv_pw.weight:[236, 93, 1, 1]***blocks.3.4.bn1.weight:[236]***blocks.3.4.bn1.bias:[236]***blocks.3.4.bn1.running_mean:[236]***blocks.3.4.bn1.running_var:[236]***blocks.3.4.bn1.num_batches_tracked:[]***blocks.3.4.conv_dw.weight:[236, 1, 3, 3]***blocks.3.4.bn2.weight:[236]***blocks.3.4.bn2.bias:[236]***blocks.3.4.bn2.running_mean:[236]***blocks.3.4.bn2.running_var:[236]***blocks.3.4.bn2.num_batches_tracked:[]***blocks.3.4.se.conv_reduce.weight:[24, 236, 1, 1]***blocks.3.4.se.conv_reduce.bias:[24]***blocks.3.4.se.conv_expand.weight:[236, 24, 1, 1]***blocks.3.4.se.conv_expand.bias:[236]***blocks.3.4.conv_pwl.weight:[93, 236, 1, 1]***blocks.3.4.bn3.weight:[93]***blocks.3.4.bn3.bias:[93]***blocks.3.4.bn3.running_mean:[93]***blocks.3.4.bn3.running_var:[93]***blocks.3.4.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[480, 93, 1, 1]***blocks.4.0.bn1.weight:[480]***blocks.4.0.bn1.bias:[480]***blocks.4.0.bn1.running_mean:[480]***blocks.4.0.bn1.running_var:[480]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[480, 1, 5, 5]***blocks.4.0.bn2.weight:[480]***blocks.4.0.bn2.bias:[480]***blocks.4.0.bn2.running_mean:[480]***blocks.4.0.bn2.running_var:[480]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[24, 480, 1, 1]***blocks.4.0.se.conv_reduce.bias:[24]***blocks.4.0.se.conv_expand.weight:[480, 24, 1, 1]***blocks.4.0.se.conv_expand.bias:[480]***blocks.4.0.conv_pwl.weight:[120, 480, 1, 1]***blocks.4.0.bn3.weight:[120]***blocks.4.0.bn3.bias:[120]***blocks.4.0.bn3.running_mean:[120]***blocks.4.0.bn3.running_var:[120]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[235, 120, 1, 1]***blocks.4.1.bn1.weight:[235]***blocks.4.1.bn1.bias:[235]***blocks.4.1.bn1.running_mean:[235]***blocks.4.1.bn1.running_var:[235]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[235, 1, 5, 5]***blocks.4.1.bn2.weight:[235]***blocks.4.1.bn2.bias:[235]***blocks.4.1.bn2.running_mean:[235]***blocks.4.1.bn2.running_var:[235]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[34, 235, 1, 1]***blocks.4.1.se.conv_reduce.bias:[34]***blocks.4.1.se.conv_expand.weight:[235, 34, 1, 1]***blocks.4.1.se.conv_expand.bias:[235]***blocks.4.1.conv_pwl.weight:[120, 235, 1, 1]***blocks.4.1.bn3.weight:[120]***blocks.4.1.bn3.bias:[120]***blocks.4.1.bn3.running_mean:[120]***blocks.4.1.bn3.running_var:[120]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[217, 120, 1, 1]***blocks.4.2.bn1.weight:[217]***blocks.4.2.bn1.bias:[217]***blocks.4.2.bn1.running_mean:[217]***blocks.4.2.bn1.running_var:[217]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[217, 1, 5, 5]***blocks.4.2.bn2.weight:[217]***blocks.4.2.bn2.bias:[217]***blocks.4.2.bn2.running_mean:[217]***blocks.4.2.bn2.running_var:[217]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[34, 217, 1, 1]***blocks.4.2.se.conv_reduce.bias:[34]***blocks.4.2.se.conv_expand.weight:[217, 34, 1, 1]***blocks.4.2.se.conv_expand.bias:[217]***blocks.4.2.conv_pwl.weight:[120, 217, 1, 1]***blocks.4.2.bn3.weight:[120]***blocks.4.2.bn3.bias:[120]***blocks.4.2.bn3.running_mean:[120]***blocks.4.2.bn3.running_var:[120]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[226, 120, 1, 1]***blocks.4.3.bn1.weight:[226]***blocks.4.3.bn1.bias:[226]***blocks.4.3.bn1.running_mean:[226]***blocks.4.3.bn1.running_var:[226]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[226, 1, 5, 5]***blocks.4.3.bn2.weight:[226]***blocks.4.3.bn2.bias:[226]***blocks.4.3.bn2.running_mean:[226]***blocks.4.3.bn2.running_var:[226]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[33, 226, 1, 1]***blocks.4.3.se.conv_reduce.bias:[33]***blocks.4.3.se.conv_expand.weight:[226, 33, 1, 1]***blocks.4.3.se.conv_expand.bias:[226]***blocks.4.3.conv_pwl.weight:[120, 226, 1, 1]***blocks.4.3.bn3.weight:[120]***blocks.4.3.bn3.bias:[120]***blocks.4.3.bn3.running_mean:[120]***blocks.4.3.bn3.running_var:[120]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.4.4.conv_pw.weight:[340, 120, 1, 1]***blocks.4.4.bn1.weight:[340]***blocks.4.4.bn1.bias:[340]***blocks.4.4.bn1.running_mean:[340]***blocks.4.4.bn1.running_var:[340]***blocks.4.4.bn1.num_batches_tracked:[]***blocks.4.4.conv_dw.weight:[340, 1, 5, 5]***blocks.4.4.bn2.weight:[340]***blocks.4.4.bn2.bias:[340]***blocks.4.4.bn2.running_mean:[340]***blocks.4.4.bn2.running_var:[340]***blocks.4.4.bn2.num_batches_tracked:[]***blocks.4.4.se.conv_reduce.weight:[34, 340, 1, 1]***blocks.4.4.se.conv_reduce.bias:[34]***blocks.4.4.se.conv_expand.weight:[340, 34, 1, 1]***blocks.4.4.se.conv_expand.bias:[340]***blocks.4.4.conv_pwl.weight:[120, 340, 1, 1]***blocks.4.4.bn3.weight:[120]***blocks.4.4.bn3.bias:[120]***blocks.4.4.bn3.running_mean:[120]***blocks.4.4.bn3.running_var:[120]***blocks.4.4.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[802, 120, 1, 1]***blocks.5.0.bn1.weight:[802]***blocks.5.0.bn1.bias:[802]***blocks.5.0.bn1.running_mean:[802]***blocks.5.0.bn1.running_var:[802]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[802, 1, 5, 5]***blocks.5.0.bn2.weight:[802]***blocks.5.0.bn2.bias:[802]***blocks.5.0.bn2.running_mean:[802]***blocks.5.0.bn2.running_var:[802]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[34, 802, 1, 1]***blocks.5.0.se.conv_reduce.bias:[34]***blocks.5.0.se.conv_expand.weight:[802, 34, 1, 1]***blocks.5.0.se.conv_expand.bias:[802]***blocks.5.0.conv_pwl.weight:[232, 802, 1, 1]***blocks.5.0.bn3.weight:[232]***blocks.5.0.bn3.bias:[232]***blocks.5.0.bn3.running_mean:[232]***blocks.5.0.bn3.running_var:[232]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[1030, 232, 1, 1]***blocks.5.1.bn1.weight:[1030]***blocks.5.1.bn1.bias:[1030]***blocks.5.1.bn1.running_mean:[1030]***blocks.5.1.bn1.running_var:[1030]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[1030, 1, 5, 5]***blocks.5.1.bn2.weight:[1030]***blocks.5.1.bn2.bias:[1030]***blocks.5.1.bn2.running_mean:[1030]***blocks.5.1.bn2.running_var:[1030]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[58, 1030, 1, 1]***blocks.5.1.se.conv_reduce.bias:[58]***blocks.5.1.se.conv_expand.weight:[1030, 58, 1, 1]***blocks.5.1.se.conv_expand.bias:[1030]***blocks.5.1.conv_pwl.weight:[232, 1030, 1, 1]***blocks.5.1.bn3.weight:[232]***blocks.5.1.bn3.bias:[232]***blocks.5.1.bn3.running_mean:[232]***blocks.5.1.bn3.running_var:[232]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[924, 232, 1, 1]***blocks.5.2.bn1.weight:[924]***blocks.5.2.bn1.bias:[924]***blocks.5.2.bn1.running_mean:[924]***blocks.5.2.bn1.running_var:[924]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[924, 1, 5, 5]***blocks.5.2.bn2.weight:[924]***blocks.5.2.bn2.bias:[924]***blocks.5.2.bn2.running_mean:[924]***blocks.5.2.bn2.running_var:[924]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[58, 924, 1, 1]***blocks.5.2.se.conv_reduce.bias:[58]***blocks.5.2.se.conv_expand.weight:[924, 58, 1, 1]***blocks.5.2.se.conv_expand.bias:[924]***blocks.5.2.conv_pwl.weight:[232, 924, 1, 1]***blocks.5.2.bn3.weight:[232]***blocks.5.2.bn3.bias:[232]***blocks.5.2.bn3.running_mean:[232]***blocks.5.2.bn3.running_var:[232]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[1016, 232, 1, 1]***blocks.5.3.bn1.weight:[1016]***blocks.5.3.bn1.bias:[1016]***blocks.5.3.bn1.running_mean:[1016]***blocks.5.3.bn1.running_var:[1016]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[1016, 1, 5, 5]***blocks.5.3.bn2.weight:[1016]***blocks.5.3.bn2.bias:[1016]***blocks.5.3.bn2.running_mean:[1016]***blocks.5.3.bn2.running_var:[1016]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[58, 1016, 1, 1]***blocks.5.3.se.conv_reduce.bias:[58]***blocks.5.3.se.conv_expand.weight:[1016, 58, 1, 1]***blocks.5.3.se.conv_expand.bias:[1016]***blocks.5.3.conv_pwl.weight:[232, 1016, 1, 1]***blocks.5.3.bn3.weight:[232]***blocks.5.3.bn3.bias:[232]***blocks.5.3.bn3.running_mean:[232]***blocks.5.3.bn3.running_var:[232]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[1130, 232, 1, 1]***blocks.5.4.bn1.weight:[1130]***blocks.5.4.bn1.bias:[1130]***blocks.5.4.bn1.running_mean:[1130]***blocks.5.4.bn1.running_var:[1130]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[1130, 1, 5, 5]***blocks.5.4.bn2.weight:[1130]***blocks.5.4.bn2.bias:[1130]***blocks.5.4.bn2.running_mean:[1130]***blocks.5.4.bn2.running_var:[1130]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[58, 1130, 1, 1]***blocks.5.4.se.conv_reduce.bias:[58]***blocks.5.4.se.conv_expand.weight:[1130, 58, 1, 1]***blocks.5.4.se.conv_expand.bias:[1130]***blocks.5.4.conv_pwl.weight:[232, 1130, 1, 1]***blocks.5.4.bn3.weight:[232]***blocks.5.4.bn3.bias:[232]***blocks.5.4.bn3.running_mean:[232]***blocks.5.4.bn3.running_var:[232]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.5.5.conv_pw.weight:[1266, 232, 1, 1]***blocks.5.5.bn1.weight:[1266]***blocks.5.5.bn1.bias:[1266]***blocks.5.5.bn1.running_mean:[1266]***blocks.5.5.bn1.running_var:[1266]***blocks.5.5.bn1.num_batches_tracked:[]***blocks.5.5.conv_dw.weight:[1266, 1, 5, 5]***blocks.5.5.bn2.weight:[1266]***blocks.5.5.bn2.bias:[1266]***blocks.5.5.bn2.running_mean:[1266]***blocks.5.5.bn2.running_var:[1266]***blocks.5.5.bn2.num_batches_tracked:[]***blocks.5.5.se.conv_reduce.weight:[58, 1266, 1, 1]***blocks.5.5.se.conv_reduce.bias:[58]***blocks.5.5.se.conv_expand.weight:[1266, 58, 1, 1]***blocks.5.5.se.conv_expand.bias:[1266]***blocks.5.5.conv_pwl.weight:[232, 1266, 1, 1]***blocks.5.5.bn3.weight:[232]***blocks.5.5.bn3.bias:[232]***blocks.5.5.bn3.running_mean:[232]***blocks.5.5.bn3.running_var:[232]***blocks.5.5.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1392, 232, 1, 1]***blocks.6.0.bn1.weight:[1392]***blocks.6.0.bn1.bias:[1392]***blocks.6.0.bn1.running_mean:[1392]***blocks.6.0.bn1.running_var:[1392]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1392, 1, 3, 3]***blocks.6.0.bn2.weight:[1392]***blocks.6.0.bn2.bias:[1392]***blocks.6.0.bn2.running_mean:[1392]***blocks.6.0.bn2.running_var:[1392]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[58, 1392, 1, 1]***blocks.6.0.se.conv_reduce.bias:[58]***blocks.6.0.se.conv_expand.weight:[1392, 58, 1, 1]***blocks.6.0.se.conv_expand.bias:[1392]***blocks.6.0.conv_pwl.weight:[384, 1392, 1, 1]***blocks.6.0.bn3.weight:[384]***blocks.6.0.bn3.bias:[384]***blocks.6.0.bn3.running_mean:[384]***blocks.6.0.bn3.running_var:[384]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[2301, 384, 1, 1]***blocks.6.1.bn1.weight:[2301]***blocks.6.1.bn1.bias:[2301]***blocks.6.1.bn1.running_mean:[2301]***blocks.6.1.bn1.running_var:[2301]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[2301, 1, 3, 3]***blocks.6.1.bn2.weight:[2301]***blocks.6.1.bn2.bias:[2301]***blocks.6.1.bn2.running_mean:[2301]***blocks.6.1.bn2.running_var:[2301]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[96, 2301, 1, 1]***blocks.6.1.se.conv_reduce.bias:[96]***blocks.6.1.se.conv_expand.weight:[2301, 96, 1, 1]***blocks.6.1.se.conv_expand.bias:[2301]***blocks.6.1.conv_pwl.weight:[384, 2301, 1, 1]***blocks.6.1.bn3.weight:[384]***blocks.6.1.bn3.bias:[384]***blocks.6.1.bn3.running_mean:[384]***blocks.6.1.bn3.running_var:[384]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1536, 384, 1, 1]***bn2.weight:[1536]***bn2.bias:[1536]***bn2.running_mean:[1536]***bn2.running_var:[1536]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1536]***classifier.bias:[1000] \ No newline at end of file diff --git a/timm/models/registry.py b/timm/models/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..f92219b218228baf09ef7ee596c0b1f360347d47 --- /dev/null +++ b/timm/models/registry.py @@ -0,0 +1,149 @@ +""" Model Registry +Hacked together by / Copyright 2020 Ross Wightman +""" + +import sys +import re +import fnmatch +from collections import defaultdict +from copy import deepcopy + +__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', + 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained'] + +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present +_model_default_cfgs = dict() # central repo for model default_cfgs + + +def register_model(fn): + # lookup containing module + mod = sys.modules[fn.__module__] + module_name_split = fn.__module__.split('.') + module_name = module_name_split[-1] if len(module_name_split) else '' + + # add model to __all__ in module + model_name = fn.__name__ + if hasattr(mod, '__all__'): + mod.__all__.append(model_name) + else: + mod.__all__ = [model_name] + + # add entries to registry dict/sets + _model_entrypoints[model_name] = fn + _model_to_module[model_name] = module_name + _module_to_models[module_name].add(model_name) + has_pretrained = False # check if model has a pretrained url to allow filtering on this + if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: + # this will catch all models that have entrypoint matching cfg key, but miss any aliasing + # entrypoints or non-matching combos + has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] + _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name]) + if has_pretrained: + _model_has_pretrained.add(model_name) + return fn + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): + """ Return list of available model names, sorted alphabetically + + Args: + filter (str) - Wildcard filter string that works with fnmatch + module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') + pretrained (bool) - Include only models with pretrained weights if True + exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) + + Example: + model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' + model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module + """ + if module: + all_models = list(_module_to_models[module]) + else: + all_models = _model_entrypoints.keys() + if filter: + models = [] + include_filters = filter if isinstance(filter, (tuple, list)) else [filter] + for f in include_filters: + include_models = fnmatch.filter(all_models, f) # include these models + if len(include_models): + models = set(models).union(include_models) + else: + models = all_models + if exclude_filters: + if not isinstance(exclude_filters, (tuple, list)): + exclude_filters = [exclude_filters] + for xf in exclude_filters: + exclude_models = fnmatch.filter(models, xf) # exclude these models + if len(exclude_models): + models = set(models).difference(exclude_models) + if pretrained: + models = _model_has_pretrained.intersection(models) + if name_matches_cfg: + models = set(_model_default_cfgs).intersection(models) + return list(sorted(models, key=_natural_key)) + + +def is_model(model_name): + """ Check if a model name exists + """ + return model_name in _model_entrypoints + + +def model_entrypoint(model_name): + """Fetch a model entrypoint for specified model name + """ + return _model_entrypoints[model_name] + + +def list_modules(): + """ Return list of module names that contain models / model entrypoints + """ + modules = _module_to_models.keys() + return list(sorted(modules)) + + +def is_model_in_modules(model_name, module_names): + """Check if a model exists within a subset of modules + Args: + model_name (str) - name of model to check + module_names (tuple, list, set) - names of modules to search in + """ + assert isinstance(module_names, (tuple, list, set)) + return any(model_name in _module_to_models[n] for n in module_names) + + +def has_model_default_key(model_name, cfg_key): + """ Query model default_cfgs for existence of a specific key. + """ + if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]: + return True + return False + + +def is_model_default_key(model_name, cfg_key): + """ Return truthy value for specified model default_cfg key, False if does not exist. + """ + if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False): + return True + return False + + +def get_model_default_value(model_name, cfg_key): + """ Get a specific model default_cfg value by key. None if it doesn't exist. + """ + if model_name in _model_default_cfgs: + return _model_default_cfgs[model_name].get(cfg_key, None) + else: + return None + + +def is_model_pretrained(model_name): + return model_name in _model_has_pretrained diff --git a/timm/models/regnet.py b/timm/models/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6a38107467d22e195230663f5eeb03b38c82c125 --- /dev/null +++ b/timm/models/regnet.py @@ -0,0 +1,494 @@ +"""RegNet + +Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678 +Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py + +Based on original PyTorch impl linked above, but re-wrote to use my own blocks (adapted from ResNet here) +and cleaned up with more descriptive variable names. + +Weights from original impl have been modified +* first layer from BGR -> RGB as most PyTorch models are +* removed training specific dict entries from checkpoints and keep model state_dict only +* remap names to match the ones here + +Hacked together by / Copyright 2020 Ross Wightman +""" +import numpy as np +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath +from .registry import register_model + + +def _mcfg(**kwargs): + cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32) + cfg.update(**kwargs) + return cfg + + +# Model FLOPS = three trailing digits * 10^8 +model_cfgs = dict( + regnetx_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13), + regnetx_004=_mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22), + regnetx_006=_mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16), + regnetx_008=_mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16), + regnetx_016=_mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18), + regnetx_032=_mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25), + regnetx_040=_mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23), + regnetx_064=_mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17), + regnetx_080=_mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23), + regnetx_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19), + regnetx_160=_mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22), + regnetx_320=_mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23), + regnety_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25), + regnety_004=_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25), + regnety_006=_mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25), + regnety_008=_mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25), + regnety_016=_mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25), + regnety_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25), + regnety_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25), + regnety_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25), + regnety_080=_mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25), + regnety_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25), + regnety_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25), + regnety_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25), +) + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + regnetx_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'), + regnetx_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'), + regnetx_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'), + regnetx_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'), + regnetx_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'), + regnetx_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'), + regnetx_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'), + regnetx_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'), + regnetx_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'), + regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'), + regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'), + regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'), + regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'), + regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'), + regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'), + regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'), + regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'), + regnety_032=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth', + crop_pct=1.0, test_input_size=(3, 288, 288)), + regnety_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'), + regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'), + regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'), + regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'), + regnety_160=_cfg( + url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository + crop_pct=1.0, test_input_size=(3, 288, 288)), + regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), +) + + +def quantize_float(f, q): + """Converts a float to closest non-zero int divisible by q.""" + return int(round(f / q) * q) + + +def adjust_widths_groups_comp(widths, bottle_ratios, groups): + """Adjusts the compatibility of widths and groups.""" + bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)] + groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)] + bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)] + widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)] + return widths, groups + + +def generate_regnet(width_slope, width_initial, width_mult, depth, q=8): + """Generates per block widths from RegNet parameters.""" + assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % q == 0 + widths_cont = np.arange(depth) * width_slope + width_initial + width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult)) + widths = width_initial * np.power(width_mult, width_exps) + widths = np.round(np.divide(widths, q)) * q + num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1 + widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist() + return widths, num_stages, max_stage, widths_cont + + +class Bottleneck(nn.Module): + """ RegNet Bottleneck + + This is almost exactly the same as a ResNet Bottlneck. The main difference is the SE block is moved from + after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. + """ + + def __init__(self, in_chs, out_chs, stride=1, dilation=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25, + downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, + drop_block=None, drop_path=None): + super(Bottleneck, self).__init__() + bottleneck_chs = int(round(out_chs * bottleneck_ratio)) + groups = bottleneck_chs // group_width + + cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) + self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) + self.conv2 = ConvBnAct( + bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation, + groups=groups, **cargs) + if se_ratio: + se_channels = int(round(in_chs * se_ratio)) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels) + else: + self.se = None + cargs['act_layer'] = None + self.conv3 = ConvBnAct(bottleneck_chs, out_chs, kernel_size=1, **cargs) + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + if self.se is not None: + x = self.se(x) + x = self.conv3(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act3(x) + return x + + +def downsample_conv( + in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + dilation = dilation if kernel_size > 1 else 1 + return ConvBnAct( + in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None) + + +def downsample_avg( + in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None): + """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + pool = nn.Identity() + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + return nn.Sequential(*[ + pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)]) + + +class RegStage(nn.Module): + """Stage (sequence of blocks w/ the same output shape).""" + + def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width, + block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None): + super(RegStage, self).__init__() + block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args + first_dilation = 1 if dilation in (1, 2) else 2 + for i in range(depth): + block_stride = stride if i == 0 else 1 + block_in_chs = in_chs if i == 0 else out_chs + block_dilation = first_dilation if i == 0 else dilation + if drop_path_rates is not None and drop_path_rates[i] > 0.: + drop_path = DropPath(drop_path_rates[i]) + else: + drop_path = None + if (block_in_chs != out_chs) or (block_stride != 1): + proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation) + else: + proj_block = None + + name = "b{}".format(i + 1) + self.add_module( + name, block_fn( + block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio, + downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs) + ) + + def forward(self, x): + for block in self.children(): + x = block(x) + return x + + +class RegNet(nn.Module): + """RegNet model. + + Paper: https://arxiv.org/abs/2003.13678 + Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py + """ + + def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., + drop_path_rate=0., zero_init_last_bn=True): + super().__init__() + # TODO add drop block, drop path, anti-aliasing, custom bn/act args + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + + # Construct the stem + stem_width = cfg['stem_width'] + self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2) + self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] + + # Construct the stages + prev_width = stem_width + curr_stride = 2 + stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) + se_ratio = cfg['se_ratio'] + for i, stage_args in enumerate(stage_params): + stage_name = "s{}".format(i + 1) + self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio)) + prev_width = stage_args['out_chs'] + curr_stride *= stage_args['stride'] + self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] + + # Construct the head + self.num_features = prev_width + self.head = ClassifierHead( + in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + nn.init.zeros_(m.bias) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.): + # Generate RegNet ws per block + w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth'] + widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d) + + # Convert to per stage format + stage_widths, stage_depths = np.unique(widths, return_counts=True) + + # Use the same group width, bottleneck mult and stride for each stage + stage_groups = [cfg['group_w'] for _ in range(num_stages)] + stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)] + stage_strides = [] + stage_dilations = [] + net_stride = 2 + dilation = 1 + for _ in range(num_stages): + if net_stride >= output_stride: + dilation *= default_stride + stride = 1 + else: + stride = default_stride + net_stride *= stride + stage_strides.append(stride) + stage_dilations.append(dilation) + stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1])) + + # Adjust the compatibility of ws and gws + stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups) + param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rates'] + stage_params = [ + dict(zip(param_names, params)) for params in + zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups, + stage_dpr)] + return stage_params + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + for block in list(self.children())[:-1]: + x = block(x) + return x + + def forward(self, x): + for block in self.children(): + x = block(x) + return x + + +def _filter_fn(state_dict): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + if 'model' in state_dict: + # For DeiT trained regnety_160 pretraiend model + state_dict = state_dict['model'] + return state_dict + + +def _create_regnet(variant, pretrained, **kwargs): + return build_model_with_cfg( + RegNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfgs[variant], + pretrained_filter_fn=_filter_fn, + **kwargs) + + +@register_model +def regnetx_002(pretrained=False, **kwargs): + """RegNetX-200MF""" + return _create_regnet('regnetx_002', pretrained, **kwargs) + + +@register_model +def regnetx_004(pretrained=False, **kwargs): + """RegNetX-400MF""" + return _create_regnet('regnetx_004', pretrained, **kwargs) + + +@register_model +def regnetx_006(pretrained=False, **kwargs): + """RegNetX-600MF""" + return _create_regnet('regnetx_006', pretrained, **kwargs) + + +@register_model +def regnetx_008(pretrained=False, **kwargs): + """RegNetX-800MF""" + return _create_regnet('regnetx_008', pretrained, **kwargs) + + +@register_model +def regnetx_016(pretrained=False, **kwargs): + """RegNetX-1.6GF""" + return _create_regnet('regnetx_016', pretrained, **kwargs) + + +@register_model +def regnetx_032(pretrained=False, **kwargs): + """RegNetX-3.2GF""" + return _create_regnet('regnetx_032', pretrained, **kwargs) + + +@register_model +def regnetx_040(pretrained=False, **kwargs): + """RegNetX-4.0GF""" + return _create_regnet('regnetx_040', pretrained, **kwargs) + + +@register_model +def regnetx_064(pretrained=False, **kwargs): + """RegNetX-6.4GF""" + return _create_regnet('regnetx_064', pretrained, **kwargs) + + +@register_model +def regnetx_080(pretrained=False, **kwargs): + """RegNetX-8.0GF""" + return _create_regnet('regnetx_080', pretrained, **kwargs) + + +@register_model +def regnetx_120(pretrained=False, **kwargs): + """RegNetX-12GF""" + return _create_regnet('regnetx_120', pretrained, **kwargs) + + +@register_model +def regnetx_160(pretrained=False, **kwargs): + """RegNetX-16GF""" + return _create_regnet('regnetx_160', pretrained, **kwargs) + + +@register_model +def regnetx_320(pretrained=False, **kwargs): + """RegNetX-32GF""" + return _create_regnet('regnetx_320', pretrained, **kwargs) + + +@register_model +def regnety_002(pretrained=False, **kwargs): + """RegNetY-200MF""" + return _create_regnet('regnety_002', pretrained, **kwargs) + + +@register_model +def regnety_004(pretrained=False, **kwargs): + """RegNetY-400MF""" + return _create_regnet('regnety_004', pretrained, **kwargs) + + +@register_model +def regnety_006(pretrained=False, **kwargs): + """RegNetY-600MF""" + return _create_regnet('regnety_006', pretrained, **kwargs) + + +@register_model +def regnety_008(pretrained=False, **kwargs): + """RegNetY-800MF""" + return _create_regnet('regnety_008', pretrained, **kwargs) + + +@register_model +def regnety_016(pretrained=False, **kwargs): + """RegNetY-1.6GF""" + return _create_regnet('regnety_016', pretrained, **kwargs) + + +@register_model +def regnety_032(pretrained=False, **kwargs): + """RegNetY-3.2GF""" + return _create_regnet('regnety_032', pretrained, **kwargs) + + +@register_model +def regnety_040(pretrained=False, **kwargs): + """RegNetY-4.0GF""" + return _create_regnet('regnety_040', pretrained, **kwargs) + + +@register_model +def regnety_064(pretrained=False, **kwargs): + """RegNetY-6.4GF""" + return _create_regnet('regnety_064', pretrained, **kwargs) + + +@register_model +def regnety_080(pretrained=False, **kwargs): + """RegNetY-8.0GF""" + return _create_regnet('regnety_080', pretrained, **kwargs) + + +@register_model +def regnety_120(pretrained=False, **kwargs): + """RegNetY-12GF""" + return _create_regnet('regnety_120', pretrained, **kwargs) + + +@register_model +def regnety_160(pretrained=False, **kwargs): + """RegNetY-16GF""" + return _create_regnet('regnety_160', pretrained, **kwargs) + + +@register_model +def regnety_320(pretrained=False, **kwargs): + """RegNetY-32GF""" + return _create_regnet('regnety_320', pretrained, **kwargs) diff --git a/timm/models/res2net.py b/timm/models/res2net.py new file mode 100644 index 0000000000000000000000000000000000000000..282baba3b04f7805b16ffeaef55dd2b19b434f0c --- /dev/null +++ b/timm/models/res2net.py @@ -0,0 +1,216 @@ +""" Res2Net and Res2NeXt +Adapted from Official Pytorch impl at: https://github.com/gasvn/Res2Net/ +Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169 +""" +import math + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .resnet import ResNet + +__all__ = [] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'res2net50_26w_4s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'), + 'res2net50_48w_2s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'), + 'res2net50_14w_8s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth'), + 'res2net50_26w_6s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth'), + 'res2net50_26w_8s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth'), + 'res2net101_26w_4s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth'), + 'res2next50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth'), +} + + +class Bottle2neck(nn.Module): + """ Res2Net/Res2NeXT Bottleneck + Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py + """ + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None, + act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_): + super(Bottle2neck, self).__init__() + self.scale = scale + self.is_first = stride > 1 or downsample is not None + self.num_scales = max(1, scale - 1) + width = int(math.floor(planes * (base_width / 64.0))) * cardinality + self.width = width + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) + self.bn1 = norm_layer(width * scale) + + convs = [] + bns = [] + for i in range(self.num_scales): + convs.append(nn.Conv2d( + width, width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False)) + bns.append(norm_layer(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + if self.is_first: + # FIXME this should probably have count_include_pad=False, but hurts original weights + self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) + else: + self.pool = None + + self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + self.se = attn_layer(outplanes) if attn_layer is not None else None + + self.relu = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + spx = torch.split(out, self.width, 1) + spo = [] + sp = spx[0] # redundant, for torchscript + for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): + if i == 0 or self.is_first: + sp = spx[i] + else: + sp = sp + spx[i] + sp = conv(sp) + sp = bn(sp) + sp = self.relu(sp) + spo.append(sp) + if self.scale > 1: + if self.pool is not None: + # self.is_first == True, None check for torchscript + spo.append(self.pool(spx[-1])) + else: + spo.append(spx[-1]) + out = torch.cat(spo, 1) + + out = self.conv3(out) + out = self.bn3(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out += shortcut + out = self.relu(out) + + return out + + +def _create_res2net(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def res2net50_26w_4s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w4s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2net50_26w_4s', pretrained, **model_args) + + +@register_model +def res2net101_26w_4s(pretrained=False, **kwargs): + """Constructs a Res2Net-101 26w4s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2net101_26w_4s', pretrained, **model_args) + + +@register_model +def res2net50_26w_6s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w6s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs) + return _create_res2net('res2net50_26w_6s', pretrained, **model_args) + + +@register_model +def res2net50_26w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 26w8s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs) + return _create_res2net('res2net50_26w_8s', pretrained, **model_args) + + +@register_model +def res2net50_48w_2s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 48w2s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs) + return _create_res2net('res2net50_48w_2s', pretrained, **model_args) + + +@register_model +def res2net50_14w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50 14w8s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs) + return _create_res2net('res2net50_14w_8s', pretrained, **model_args) + + +@register_model +def res2next50(pretrained=False, **kwargs): + """Construct Res2NeXt-50 4s + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model_args = dict( + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs) + return _create_res2net('res2next50', pretrained, **model_args) diff --git a/timm/models/resnest.py b/timm/models/resnest.py new file mode 100644 index 0000000000000000000000000000000000000000..31eebd8092a75e949a7592833f00f05c0a5a9be7 --- /dev/null +++ b/timm/models/resnest.py @@ -0,0 +1,237 @@ +""" ResNeSt Models + +Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang + +Modified for torchscript compat, and consistency with timm by Ross Wightman +""" +import torch +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SplitAttn +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1.0', 'classifier': 'fc', + **kwargs + } + +default_cfgs = { + 'resnest14d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'), + 'resnest26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'), + 'resnest50d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth'), + 'resnest101e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'resnest200e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'), + 'resnest269e': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth', + input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'), + 'resnest50d_4s2x40d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth', + interpolation='bicubic'), + 'resnest50d_1s4x24d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth', + interpolation='bicubic') +} + + +class ResNestBottleneck(nn.Module): + """ResNet Bottleneck + """ + # pylint: disable=unused-argument + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(ResNestBottleneck, self).__init__() + assert reduce_first == 1 # not supported + assert attn_layer is None # not supported + assert aa_layer is None # TODO not yet supported + assert drop_path is None # TODO not yet supported + + group_width = int(planes * (base_width / 64.)) * cardinality + first_dilation = first_dilation or dilation + if avd and (stride > 1 or is_first): + avd_stride = stride + stride = 1 + else: + avd_stride = 0 + self.radix = radix + self.drop_block = drop_block + + self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.act1 = act_layer(inplace=True) + self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None + + if self.radix >= 1: + self.conv2 = SplitAttn( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) + self.bn2 = nn.Identity() + self.act2 = nn.Identity() + else: + self.conv2 = nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(group_width) + self.act2 = act_layer(inplace=True) + self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None + + self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes*4) + self.act3 = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + if self.drop_block is not None: + out = self.drop_block(out) + out = self.act1(out) + + if self.avd_first is not None: + out = self.avd_first(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.drop_block is not None: + out = self.drop_block(out) + out = self.act2(out) + + if self.avd_last is not None: + out = self.avd_last(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.drop_block is not None: + out = self.drop_block(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out += shortcut + out = self.act3(out) + return out + + +def _create_resnest(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def resnest14d(pretrained=False, **kwargs): + """ ResNeSt-14d model. Weights ported from GluonCV. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[1, 1, 1, 1], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest26d(pretrained=False, **kwargs): + """ ResNeSt-26d model. Weights ported from GluonCV. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[2, 2, 2, 2], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d(pretrained=False, **kwargs): + """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest101e(pretrained=False, **kwargs): + """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 23, 3], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest200e(pretrained=False, **kwargs): + """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 24, 36, 3], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest269e(pretrained=False, **kwargs): + """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 30, 48, 8], + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d_4s2x40d(pretrained=False, **kwargs): + """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2, + block_args=dict(radix=4, avd=True, avd_first=True), **kwargs) + return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs) + + +@register_model +def resnest50d_1s4x24d(pretrained=False, **kwargs): + """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + model_kwargs = dict( + block=ResNestBottleneck, layers=[3, 4, 6, 3], + stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4, + block_args=dict(radix=1, avd=True, avd_first=True), **kwargs) + return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/resnet.py b/timm/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..66baa37a90dbb2f2cdb510bc0b988cd25bd5887a --- /dev/null +++ b/timm/models/resnet.py @@ -0,0 +1,1455 @@ +"""PyTorch ResNet + +This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with +additional dropout and dynamic global avg/max pool. + +ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman +Copyright 2020 Ross Wightman +""" +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, get_attn, create_classifier +from .registry import register_model + +__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + # ResNet and Wide ResNet + 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), + 'resnet18d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), + 'resnet34d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34d_ra2-f8dcfcaf.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet26': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth', + interpolation='bicubic'), + 'resnet26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet26t': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), + 'resnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', + interpolation='bicubic'), + 'resnet50d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet50t': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'resnet101': _cfg(url='', interpolation='bicubic'), + 'resnet101d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=1.0, test_input_size=(3, 320, 320)), + 'resnet152': _cfg(url='', interpolation='bicubic'), + 'resnet152d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=1.0, test_input_size=(3, 320, 320)), + 'resnet200': _cfg(url='', interpolation='bicubic'), + 'resnet200d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=1.0, test_input_size=(3, 320, 320)), + 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), + 'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'tv_resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), + 'tv_resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), + 'wide_resnet50_2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth', + interpolation='bicubic'), + 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + + # ResNeXt + 'resnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth', + interpolation='bicubic'), + 'resnext50d_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'resnext101_32x4d': _cfg(url=''), + 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), + 'resnext101_64x4d': _cfg(url=''), + 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'), + + # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags + # from https://github.com/facebookresearch/WSL-Images + # Please note the CC-BY-NC 4.0 license on theses weights, non-commercial use only. + 'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'), + 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'), + 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'), + 'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'), + + # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models + # Please note the CC-BY-NC 4.0 license on theses weights, non-commercial use only. + 'ssl_resnet18': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth'), + 'ssl_resnet50': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth'), + 'ssl_resnext50_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth'), + 'ssl_resnext101_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth'), + 'ssl_resnext101_32x8d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth'), + 'ssl_resnext101_32x16d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth'), + + # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models + # Please note the CC-BY-NC 4.0 license on theses weights, non-commercial use only. + 'swsl_resnet18': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth'), + 'swsl_resnet50': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth'), + 'swsl_resnext50_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth'), + 'swsl_resnext101_32x4d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth'), + 'swsl_resnext101_32x8d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'), + 'swsl_resnext101_32x16d': _cfg( + url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'), + + # Squeeze-Excitation ResNets, to eventually replace the models in senet.py + 'seresnet18': _cfg( + url='', + interpolation='bicubic'), + 'seresnet34': _cfg( + url='', + interpolation='bicubic'), + 'seresnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth', + interpolation='bicubic'), + 'seresnet50t': _cfg( + url='', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnet101': _cfg( + url='', + interpolation='bicubic'), + 'seresnet152': _cfg( + url='', + interpolation='bicubic'), + 'seresnet152d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=1.0, test_input_size=(3, 320, 320) + ), + 'seresnet200d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), + 'seresnet269d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), + + + # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py + 'seresnext26d_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnext26t_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'seresnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext50_32x4d_racm-a304a460.pth', + interpolation='bicubic'), + 'seresnext101_32x4d': _cfg( + url='', + interpolation='bicubic'), + 'seresnext101_32x8d': _cfg( + url='', + interpolation='bicubic'), + 'senet154': _cfg( + url='', + interpolation='bicubic', + first_conv='conv1.0'), + + # Efficient Channel Attention ResNets + 'ecaresnet26t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnetlight': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth', + interpolation='bicubic'), + 'ecaresnet50d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnet101d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'ecaresnet101d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet200d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), + 'ecaresnet269d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), + crop_pct=1.0, test_input_size=(3, 352, 352)), + + # Efficient Channel Attention ResNeXts + 'ecaresnext26t_32x4d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'ecaresnext50t_32x4d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + + # ResNets with anti-aliasing blur pool + 'resnetblur18': _cfg( + interpolation='bicubic'), + 'resnetblur50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth', + interpolation='bicubic'), + + # ResNet-RS models + 'resnetrs50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth', + input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs101': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth', + input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs152': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs200': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200_ema-623d2f59.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs270': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth', + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs350': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth', + input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384), + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs420': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth', + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416), + interpolation='bicubic', first_conv='conv1.0'), +} + + +def get_padding(kernel_size, stride, dilation=1): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(BasicBlock, self).__init__() + + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock does not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, + dilation=first_dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None + + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) + self.bn2 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + if self.aa is not None: + x = self.aa(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act2(x) + + return x + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(Bottleneck, self).__init__() + + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=1 if use_aa else stride, + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(width) + self.act2 = act_layer(inplace=True) + self.aa = aa_layer(channels=width, stride=stride) if use_aa else None + + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + + self.se = create_attn(attn_layer, outplanes) + + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act2(x) + if self.aa is not None: + x = self.aa(x) + + x = self.conv3(x) + x = self.bn3(x) + if self.drop_block is not None: + x = self.drop_block(x) + + if self.se is not None: + x = self.se(x) + + if self.drop_path is not None: + x = self.drop_path(x) + + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act3(x) + + return x + + +def downsample_conv( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 + p = get_padding(kernel_size, stride, first_dilation) + + return nn.Sequential(*[ + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), + norm_layer(out_channels) + ]) + + +def downsample_avg( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + if stride == 1 and dilation == 1: + pool = nn.Identity() + else: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + + return nn.Sequential(*[ + pool, + nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False), + norm_layer(out_channels) + ]) + + +def drop_blocks(drop_block_rate=0.): + return [ + None, None, + DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None, + DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None] + + +def make_blocks( + block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32, + down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs): + stages = [] + feature_info = [] + net_num_blocks = sum(block_repeats) + net_block_idx = 0 + net_stride = 4 + dilation = prev_dilation = 1 + for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))): + stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it + stride = 1 if stage_idx == 0 else 2 + if net_stride >= output_stride: + dilation *= stride + stride = 1 + else: + net_stride *= stride + + downsample = None + if stride != 1 or inplanes != planes * block_fn.expansion: + down_kwargs = dict( + in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size, + stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer')) + downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs) + + block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs) + blocks = [] + for block_idx in range(num_blocks): + downsample = downsample if block_idx == 0 else None + stride = stride if block_idx == 0 else 1 + block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule + blocks.append(block_fn( + inplanes, planes, stride, downsample, first_dilation=prev_dilation, + drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs)) + prev_dilation = dilation + inplanes = planes * block_fn.expansion + net_block_idx += 1 + + stages.append((stage_name, nn.Sequential(*blocks))) + feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name)) + + return stages, feature_info + + +class ResNet(nn.Module): + """ResNet / ResNeXt / SE-ResNeXt / SE-Net + + This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that + * have > 1 stride in the 3x3 conv layer of bottleneck + * have conv-bn-act ordering + + This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s + variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the + 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default. + + ResNet variants (the same modifications can be used in SE/ResNeXt models as well): + * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b + * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64) + * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample + * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample + * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128) + * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample + * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample + + ResNeXt + * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths + * same c,d, e, s variants as ResNet can be enabled + + SE-ResNeXt + * normal - 7x7 stem, stem_width = 64 + * same c, d, e, s variants as ResNet can be enabled + + SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, + reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockGl, BottleneckGl. + layers : list of int + Numbers of layers in each block + num_classes : int, default 1000 + Number of classification classes. + in_chans : int, default 3 + Number of input (color) channels. + cardinality : int, default 1 + Number of convolution groups for 3x3 conv in Bottleneck. + base_width : int, default 64 + Factor determining bottleneck channels. `planes * base_width / 64 * cardinality` + stem_width : int, default 64 + Number of channels in stem convolutions + stem_type : str, default '' + The type of stem: + * '', default - a single 7x7 conv with a width of stem_width + * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 + * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 + block_reduce_first: int, default 1 + Reduction factor for first convolution output width of residual blocks, + 1 for all archs except senets, where 2 + down_kernel_size: int, default 1 + Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. + output_stride : int, default 32 + Set the output stride of the network, 32, 16, or 8. Typically used in segmentation. + act_layer : nn.Module, activation layer + norm_layer : nn.Module, normalization layer + aa_layer : nn.Module, anti-aliasing layer + drop_rate : float, default 0. + Dropout probability before classifier, for training + global_pool : str, default 'avg' + Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' + """ + + def __init__(self, block, layers, num_classes=1000, in_chans=3, + cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, + output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0., + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): + block_args = block_args or dict() + assert output_stride in (8, 16, 32) + self.num_classes = num_classes + self.drop_rate = drop_rate + super(ResNet, self).__init__() + + # Stem + deep_stem = 'deep' in stem_type + inplanes = stem_width * 2 if deep_stem else 64 + if deep_stem: + stem_chs = (stem_width, stem_width) + if 'tiered' in stem_type: + stem_chs = (3 * (stem_width // 4), stem_width) + self.conv1 = nn.Sequential(*[ + nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False), + norm_layer(stem_chs[0]), + act_layer(inplace=True), + nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False), + norm_layer(stem_chs[1]), + act_layer(inplace=True), + nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)]) + else: + self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(inplanes) + self.act1 = act_layer(inplace=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] + + # Stem Pooling + if replace_stem_pool: + self.maxpool = nn.Sequential(*filter(None, [ + nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False), + aa_layer(channels=inplanes, stride=2) if aa_layer else None, + norm_layer(inplanes), + act_layer(inplace=True) + ])) + else: + if aa_layer is not None: + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=inplanes, stride=2)]) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # Feature Blocks + channels = [64, 128, 256, 512] + stage_modules, stage_feature_info = make_blocks( + block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width, + output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down, + down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, + drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args) + for stage in stage_modules: + self.add_module(*stage) # layer1, layer2, etc + self.feature_info.extend(stage_feature_info) + + # Head (Pooling and Classifier) + self.num_features = 512 * block.expansion + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + self.init_weights(zero_init_last_bn=zero_init_last_bn) + + def init_weights(self, zero_init_last_bn=True): + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) + x = self.fc(x) + return x + + +def _create_resnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def resnet18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model. + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('resnet18', pretrained, **model_args) + + +@register_model +def resnet18d(pretrained=False, **kwargs): + """Constructs a ResNet-18-D model. + """ + model_args = dict( + block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet18d', pretrained, **model_args) + + +@register_model +def resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet34', pretrained, **model_args) + + +@register_model +def resnet34d(pretrained=False, **kwargs): + """Constructs a ResNet-34-D model. + """ + model_args = dict( + block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet34d', pretrained, **model_args) + + +@register_model +def resnet26(pretrained=False, **kwargs): + """Constructs a ResNet-26 model. + """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('resnet26', pretrained, **model_args) + + +@register_model +def resnet26t(pretrained=False, **kwargs): + """Constructs a ResNet-26-T model. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet26t', pretrained, **model_args) + + +@register_model +def resnet26d(pretrained=False, **kwargs): + """Constructs a ResNet-26-D model. + """ + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet26d', pretrained, **model_args) + + +@register_model +def resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('resnet50', pretrained, **model_args) + + +@register_model +def resnet50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet50d', pretrained, **model_args) + + +@register_model +def resnet50t(pretrained=False, **kwargs): + """Constructs a ResNet-50-T model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet50t', pretrained, **model_args) + + +@register_model +def resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('resnet101', pretrained, **model_args) + + +@register_model +def resnet101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet101d', pretrained, **model_args) + + +@register_model +def resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('resnet152', pretrained, **model_args) + + +@register_model +def resnet152d(pretrained=False, **kwargs): + """Constructs a ResNet-152-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet152d', pretrained, **model_args) + + +@register_model +def resnet200(pretrained=False, **kwargs): + """Constructs a ResNet-200 model. + """ + model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], **kwargs) + return _create_resnet('resnet200', pretrained, **model_args) + + +@register_model +def resnet200d(pretrained=False, **kwargs): + """Constructs a ResNet-200-D model. + """ + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnet200d', pretrained, **model_args) + + +@register_model +def tv_resnet34(pretrained=False, **kwargs): + """Constructs a ResNet-34 model with original Torchvision weights. + """ + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('tv_resnet34', pretrained, **model_args) + + +@register_model +def tv_resnet50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with original Torchvision weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('tv_resnet50', pretrained, **model_args) + + +@register_model +def tv_resnet101(pretrained=False, **kwargs): + """Constructs a ResNet-101 model w/ Torchvision pretrained weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) + return _create_resnet('tv_resnet101', pretrained, **model_args) + + +@register_model +def tv_resnet152(pretrained=False, **kwargs): + """Constructs a ResNet-152 model w/ Torchvision pretrained weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) + return _create_resnet('tv_resnet152', pretrained, **model_args) + + +@register_model +def wide_resnet50_2(pretrained=False, **kwargs): + """Constructs a Wide ResNet-50-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs) + return _create_resnet('wide_resnet50_2', pretrained, **model_args) + + +@register_model +def wide_resnet101_2(pretrained=False, **kwargs): + """Constructs a Wide ResNet-101-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs) + return _create_resnet('wide_resnet101_2', pretrained, **model_args) + + +@register_model +def resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('resnext50_32x4d', pretrained, **model_args) + + +@register_model +def resnext50d_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnext50d_32x4d', pretrained, **model_args) + + +@register_model +def resnext101_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 32x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('resnext101_32x4d', pretrained, **model_args) + + +@register_model +def resnext101_32x8d(pretrained=False, **kwargs): + """Constructs a ResNeXt-101 32x8d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('resnext101_32x8d', pretrained, **model_args) + + +@register_model +def resnext101_64x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt101-64x4d model. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs) + return _create_resnet('resnext101_64x4d', pretrained, **model_args) + + +@register_model +def tv_resnext50_32x4d(pretrained=False, **kwargs): + """Constructs a ResNeXt50-32x4d model with original Torchvision weights. + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x32d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs) + return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args) + + +@register_model +def ig_resnext101_32x48d(pretrained=True, **kwargs): + """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data + and finetuned on ImageNet from Figure 5 in + `"Exploring the Limits of Weakly Supervised Pretraining" `_ + Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs) + return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args) + + +@register_model +def ssl_resnet18(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNet-18 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('ssl_resnet18', pretrained, **model_args) + + +@register_model +def ssl_resnet50(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNet-50 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('ssl_resnet50', pretrained, **model_args) + + +@register_model +def ssl_resnext50_32x4d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-50 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x4d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x8 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def ssl_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a semi-supervised ResNeXt-101 32x16 model pre-trained on YFCC100M dataset and finetuned on ImageNet + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def swsl_resnet18(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised Resnet-18 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) + return _create_resnet('swsl_resnet18', pretrained, **model_args) + + +@register_model +def swsl_resnet50(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNet-50 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) + return _create_resnet('swsl_resnet50', pretrained, **model_args) + + +@register_model +def swsl_resnext50_32x4d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-50 32x4 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x4d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x4 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) + return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x8 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) + return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args) + + +@register_model +def swsl_resnext101_32x16d(pretrained=True, **kwargs): + """Constructs a semi-weakly supervised ResNeXt-101 32x16 model pre-trained on 1B weakly supervised + image dataset and finetuned on ImageNet. + `"Billion-scale Semi-Supervised Learning for Image Classification" `_ + Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) + return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args) + + +@register_model +def ecaresnet26t(pretrained=False, **kwargs): + """Constructs an ECA-ResNeXt-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem and ECA attn. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet26t', pretrained, **model_args) + + +@register_model +def ecaresnet50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50d', pretrained, **model_args) + + +@register_model +def resnetrs50(pretrained=False, **kwargs): + """Constructs a ResNet-RS-50 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs50', pretrained, **model_args) + + +@register_model +def resnetrs101(pretrained=False, **kwargs): + """Constructs a ResNet-RS-101 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs101', pretrained, **model_args) + + +@register_model +def resnetrs152(pretrained=False, **kwargs): + """Constructs a ResNet-RS-152 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs152', pretrained, **model_args) + + +@register_model +def resnetrs200(pretrained=False, **kwargs): + """Constructs a ResNet-RS-200 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs200', pretrained, **model_args) + + +@register_model +def resnetrs270(pretrained=False, **kwargs): + """Constructs a ResNet-RS-270 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs270', pretrained, **model_args) + + + +@register_model +def resnetrs350(pretrained=False, **kwargs): + """Constructs a ResNet-RS-350 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs350', pretrained, **model_args) + + +@register_model +def resnetrs420(pretrained=False, **kwargs): + """Constructs a ResNet-RS-420 model + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs420', pretrained, **model_args) + + +@register_model +def ecaresnet50d_pruned(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model pruned with eca. + The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args) + + +@register_model +def ecaresnet50t(pretrained=False, **kwargs): + """Constructs an ECA-ResNet-50-T model. + Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet50t', pretrained, **model_args) + + +@register_model +def ecaresnetlight(pretrained=False, **kwargs): + """Constructs a ResNet-50-D light model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnetlight', pretrained, **model_args) + + +@register_model +def ecaresnet101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model with eca. + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet101d', pretrained, **model_args) + + +@register_model +def ecaresnet101d_pruned(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model pruned with eca. + The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args) + + +@register_model +def ecaresnet200d(pretrained=False, **kwargs): + """Constructs a ResNet-200-D model with ECA. + """ + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet200d', pretrained, **model_args) + + +@register_model +def ecaresnet269d(pretrained=False, **kwargs): + """Constructs a ResNet-269-D model with ECA. + """ + model_args = dict( + block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnet269d', pretrained, **model_args) + + +@register_model +def ecaresnext26t_32x4d(pretrained=False, **kwargs): + """Constructs an ECA-ResNeXt-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. This model replaces SE module with the ECA module + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnext26t_32x4d', pretrained, **model_args) + + +@register_model +def ecaresnext50t_32x4d(pretrained=False, **kwargs): + """Constructs an ECA-ResNeXt-50-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. This model replaces SE module with the ECA module + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) + return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args) + + +@register_model +def resnetblur18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model with blur anti-aliasing + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur18', pretrained, **model_args) + + +@register_model +def resnetblur50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with blur anti-aliasing + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur50', pretrained, **model_args) + + +@register_model +def seresnet18(pretrained=False, **kwargs): + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet18', pretrained, **model_args) + + +@register_model +def seresnet34(pretrained=False, **kwargs): + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet34', pretrained, **model_args) + + +@register_model +def seresnet50(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet50', pretrained, **model_args) + + +@register_model +def seresnet50t(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet50t', pretrained, **model_args) + + +@register_model +def seresnet101(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet101', pretrained, **model_args) + + +@register_model +def seresnet152(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet152', pretrained, **model_args) + + +@register_model +def seresnet152d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet152d', pretrained, **model_args) + + +@register_model +def seresnet200d(pretrained=False, **kwargs): + """Constructs a ResNet-200-D model with SE attn. + """ + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet200d', pretrained, **model_args) + + +@register_model +def seresnet269d(pretrained=False, **kwargs): + """Constructs a ResNet-269-D model with SE attn. + """ + model_args = dict( + block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnet269d', pretrained, **model_args) + + +@register_model +def seresnext26d_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNeXt-26-D model.` + This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for + combination of deep stem and avg_pool in downsample. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26d_32x4d', pretrained, **model_args) + + +@register_model +def seresnext26t_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNet-26-T model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext26t_32x4d', pretrained, **model_args) + + +@register_model +def seresnext26tn_32x4d(pretrained=False, **kwargs): + """Constructs a SE-ResNeXt-26-T model. + NOTE I deprecated previous 't' model defs and replaced 't' with 'tn', this was the only tn model of note + so keeping this def for backwards compat with any uses out there. Old 't' model is lost. + """ + return seresnext26t_32x4d(pretrained=pretrained, **kwargs) + + +@register_model +def seresnext50_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def seresnext101_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext101_32x4d', pretrained, **model_args) + + +@register_model +def seresnext101_32x8d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext101_32x8d', pretrained, **model_args) + + +@register_model +def senet154(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', + down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('senet154', pretrained, **model_args) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff4da8c41f0de88e64a6404fd5fada8a5545957 --- /dev/null +++ b/timm/models/resnetv2.py @@ -0,0 +1,655 @@ +"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization. + +A PyTorch implementation of ResNetV2 adapted from the Google Big-Transfoer (BiT) source code +at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have +been included here as pretrained models from their original .NPZ checkpoints. + +Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transfomers (ViT) and +extra padding support to allow porting of official Hybrid ResNet pretrained weights from +https://github.com/google-research/vision_transformer + +Thanks to the Google team for the above two repositories and associated papers: +* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370 +* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929 +* Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 + +Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020. +""" +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict # pylint: disable=g-importing-member + +import torch +import torch.nn as nn +from functools import partial + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv +from .registry import register_model +from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\ + ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + # pretrained on imagenet21k, finetuned on imagenet1k + 'resnetv2_50x1_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_50x3_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_101x1_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_101x3_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_152x2_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz', + input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0), + 'resnetv2_152x4_bitm': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz', + input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0), # only one at 480x480? + + # trained on imagenet-21k + 'resnetv2_50x1_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz', + num_classes=21843), + 'resnetv2_50x3_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz', + num_classes=21843), + 'resnetv2_101x1_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz', + num_classes=21843), + 'resnetv2_101x3_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz', + num_classes=21843), + 'resnetv2_152x2_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz', + num_classes=21843), + 'resnetv2_152x4_bitm_in21k': _cfg( + url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz', + num_classes=21843), + + 'resnetv2_50x1_bit_distilled': _cfg( + url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz', + interpolation='bicubic'), + 'resnetv2_152x2_bit_teacher': _cfg( + url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz', + interpolation='bicubic'), + 'resnetv2_152x2_bit_teacher_384': _cfg( + url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'), + + 'resnetv2_50': _cfg( + interpolation='bicubic'), + 'resnetv2_50d': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50t': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_101': _cfg( + interpolation='bicubic'), + 'resnetv2_101d': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_152': _cfg( + interpolation='bicubic'), + 'resnetv2_152d': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), +} + + +def make_div(v, divisor=8): + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + + Follows the implementation of "Identity Mappings in Deep Residual Networks": + https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua + + Except it puts the stride on 3x3 conv when available. + """ + + def __init__( + self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1, + act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.): + super().__init__() + first_dilation = first_dilation or dilation + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_div(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True, + conv_layer=conv_layer, norm_layer=norm_layer) + else: + self.downsample = None + + self.norm1 = norm_layer(in_chs) + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.norm2 = norm_layer(mid_chs) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.norm3 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + def zero_init_last(self): + nn.init.zeros_(self.conv3.weight) + + def forward(self, x): + x_preact = self.norm1(x) + + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x_preact) + + # residual branch + x = self.conv1(x_preact) + x = self.conv2(self.norm2(x)) + x = self.conv3(self.norm3(x)) + x = self.drop_path(x) + return x + shortcut + + +class Bottleneck(nn.Module): + """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT. + """ + def __init__( + self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1, + act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.): + super().__init__() + first_dilation = first_dilation or dilation + act_layer = act_layer or nn.ReLU + conv_layer = conv_layer or StdConv2d + norm_layer = norm_layer or partial(GroupNormAct, num_groups=32) + out_chs = out_chs or in_chs + mid_chs = make_div(out_chs * bottle_ratio) + + if proj_layer is not None: + self.downsample = proj_layer( + in_chs, out_chs, stride=stride, dilation=dilation, preact=False, + conv_layer=conv_layer, norm_layer=norm_layer) + else: + self.downsample = None + + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.norm1 = norm_layer(mid_chs) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + self.norm2 = norm_layer(mid_chs) + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.norm3 = norm_layer(out_chs, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.act3 = act_layer(inplace=True) + + def zero_init_last(self): + nn.init.zeros_(self.norm3.weight) + + def forward(self, x): + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + # residual + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.conv3(x) + x = self.norm3(x) + x = self.drop_path(x) + x = self.act3(x + shortcut) + return x + + +class DownsampleConv(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True, + conv_layer=None, norm_layer=None): + super(DownsampleConv, self).__init__() + self.conv = conv_layer(in_chs, out_chs, 1, stride=stride) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(x)) + + +class DownsampleAvg(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, + preact=True, conv_layer=None, norm_layer=None): + """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" + super(DownsampleAvg, self).__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + else: + self.pool = nn.Identity() + self.conv = conv_layer(in_chs, out_chs, 1, stride=1) + self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) + + def forward(self, x): + return self.norm(self.conv(self.pool(x))) + + +class ResNetStage(nn.Module): + """ResNet Stage.""" + def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1, + avg_down=False, block_dpr=None, block_fn=PreActBottleneck, + act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs): + super(ResNetStage, self).__init__() + first_dilation = 1 if dilation in (1, 2) else 2 + layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer) + proj_layer = DownsampleAvg if avg_down else DownsampleConv + prev_chs = in_chs + self.blocks = nn.Sequential() + for block_idx in range(depth): + drop_path_rate = block_dpr[block_idx] if block_dpr else 0. + stride = stride if block_idx == 0 else 1 + self.blocks.add_module(str(block_idx), block_fn( + prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups, + first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate, + **layer_kwargs, **block_kwargs)) + prev_chs = out_chs + first_dilation = dilation + proj_layer = None + + def forward(self, x): + x = self.blocks(x) + return x + + +def is_stem_deep(stem_type): + return any([s in stem_type for s in ('deep', 'tiered')]) + + +def create_resnetv2_stem( + in_chs, out_chs=64, stem_type='', preact=True, + conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): + stem = OrderedDict() + assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered') + + # NOTE conv padding mode can be changed by overriding the conv_layer def + if is_stem_deep(stem_type): + # A 3 deep 3x3 conv stack as in ResNet V1D models + if 'tiered' in stem_type: + stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py + else: + stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets + stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2) + stem['norm1'] = norm_layer(stem_chs[0]) + stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1) + stem['norm2'] = norm_layer(stem_chs[1]) + stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1) + if not preact: + stem['norm3'] = norm_layer(out_chs) + else: + # The usual 7x7 stem conv + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + if not preact: + stem['norm'] = norm_layer(out_chs) + + if 'fixed' in stem_type: + # 'fixed' SAME padding approximation that is used in BiT models + stem['pad'] = nn.ConstantPad2d(1, 0.) + stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0) + elif 'same' in stem_type: + # full, input size based 'SAME' padding, used in ViT Hybrid model + stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same') + else: + # the usual PyTorch symmetric padding + stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + return nn.Sequential(stem) + + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode. + """ + + def __init__( + self, layers, channels=(256, 512, 1024, 2048), + num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, + act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), + drop_rate=0., drop_path_rate=0., zero_init_last=True): + super().__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + wf = width_factor + + self.feature_info = [] + stem_chs = make_div(stem_chs * wf) + self.stem = create_resnetv2_stem( + in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) + stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm' + self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat)) + + prev_chs = stem_chs + curr_stride = 4 + dilation = 1 + block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)] + block_fn = PreActBottleneck if preact else Bottleneck + self.stages = nn.Sequential() + for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)): + out_chs = make_div(c * wf) + stride = 1 if stage_idx == 0 else 2 + if curr_stride >= output_stride: + dilation *= stride + stride = 1 + stage = ResNetStage( + prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down, + act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn) + prev_chs = out_chs + curr_stride *= stride + self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')] + self.stages.add_module(str(stage_idx), stage) + + self.num_features = prev_chs + self.norm = norm_layer(self.num_features) if preact else nn.Identity() + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) + + self.init_weights(zero_init_last=zero_init_last) + + def init_weights(self, zero_init_last=True): + named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix='resnet/'): + _load_weights(self, checkpoint_path, prefix) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _init_weights(module: nn.Module, name: str = '', zero_init_last=True): + if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)): + nn.init.normal_(module.weight, mean=0.0, std=0.01) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif zero_init_last and hasattr(module, 'zero_init_last'): + module.zero_init_last() + + +@torch.no_grad() +def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'): + import numpy as np + + def t2p(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + weights = np.load(checkpoint_path) + stem_conv_w = adapt_input_conv( + model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) + model.stem.conv.weight.copy_(stem_conv_w) + model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma'])) + model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta'])) + if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \ + model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: + model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel'])) + model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias'])) + for i, (sname, stage) in enumerate(model.stages.named_children()): + for j, (bname, block) in enumerate(stage.blocks.named_children()): + cname = 'standardized_conv2d' + block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' + block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel'])) + block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel'])) + block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel'])) + block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma'])) + block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma'])) + block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma'])) + block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta'])) + block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta'])) + block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta'])) + if block.downsample is not None: + w = weights[f'{block_prefix}a/proj/{cname}/kernel'] + block.downsample.conv.weight.copy_(t2p(w)) + + +def _create_resnetv2(variant, pretrained=False, **kwargs): + feature_cfg = dict(flatten_sequential=True) + return build_model_with_cfg( + ResNetV2, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=feature_cfg, + pretrained_custom_load=True, + **kwargs) + + +def _create_resnetv2_bit(variant, pretrained=False, **kwargs): + return _create_resnetv2( + variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs) + + +@register_model +def resnetv2_50x1_bitm(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs) + + +@register_model +def resnetv2_50x3_bitm(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs) + + +@register_model +def resnetv2_101x1_bitm(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs) + + +@register_model +def resnetv2_101x3_bitm(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs) + + +@register_model +def resnetv2_152x2_bitm(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs) + + +@register_model +def resnetv2_152x4_bitm(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs) + + +@register_model +def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), + layers=[3, 4, 6, 3], width_factor=1, **kwargs) + + +@register_model +def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), + layers=[3, 4, 6, 3], width_factor=3, **kwargs) + + +@register_model +def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), + layers=[3, 4, 23, 3], width_factor=1, **kwargs) + + +@register_model +def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), + layers=[3, 4, 23, 3], width_factor=3, **kwargs) + + +@register_model +def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), + layers=[3, 8, 36, 3], width_factor=2, **kwargs) + + +@register_model +def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs): + return _create_resnetv2_bit( + 'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), + layers=[3, 8, 36, 3], width_factor=4, **kwargs) + + +@register_model +def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs): + """ ResNetV2-50x1-BiT Distilled + Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 + """ + return _create_resnetv2_bit( + 'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs) + + +@register_model +def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs): + """ ResNetV2-152x2-BiT Teacher + Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 + """ + return _create_resnetv2_bit( + 'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs) + + +@register_model +def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs): + """ ResNetV2-152xx-BiT Teacher @ 384x384 + Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237 + """ + return _create_resnetv2_bit( + 'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs) + + +@register_model +def resnetv2_50(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) + + +@register_model +def resnetv2_50d(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, + stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_50t(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50t', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, + stem_type='tiered', avg_down=True, **kwargs) + + +@register_model +def resnetv2_101(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101', pretrained=pretrained, + layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) + + +@register_model +def resnetv2_101d(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_101d', pretrained=pretrained, + layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, + stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_152(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152', pretrained=pretrained, + layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) + + +@register_model +def resnetv2_152d(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_152d', pretrained=pretrained, + layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, + stem_type='deep', avg_down=True, **kwargs) + + +# @register_model +# def resnetv2_50ebd(pretrained=False, **kwargs): +# # FIXME for testing w/ TPU + PyTorch XLA +# return _create_resnetv2( +# 'resnetv2_50d', pretrained=pretrained, +# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d, +# stem_type='deep', avg_down=True, **kwargs) +# +# +# @register_model +# def resnetv2_50esd(pretrained=False, **kwargs): +# # FIXME for testing w/ TPU + PyTorch XLA +# return _create_resnetv2( +# 'resnetv2_50d', pretrained=pretrained, +# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d, +# stem_type='deep', avg_down=True, **kwargs) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..279780beb6c5cf05d6d89073fbc5d99f1676eebf --- /dev/null +++ b/timm/models/rexnet.py @@ -0,0 +1,238 @@ +""" ReXNet + +A PyTorch impl of `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` - +https://arxiv.org/abs/2007.00992 + +Adapted from original impl at https://github.com/clovaai/rexnet +Copyright (c) 2020-present NAVER Corp. MIT license + +Changes for timm, feature extraction, and rounded channel variant hacked together by Ross Wightman +Copyright 2020 Ross Wightman +""" + +import torch.nn as nn +from functools import partial +from math import ceil + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule +from .registry import register_model +from .efficientnet_builder import efficientnet_init_weights + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + } + + +default_cfgs = dict( + rexnet_100=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'), + rexnet_130=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'), + rexnet_150=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'), + rexnet_200=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'), + rexnetr_100=_cfg( + url=''), + rexnetr_130=_cfg( + url=''), + rexnetr_150=_cfg( + url=''), + rexnetr_200=_cfg( + url=''), +) + +SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d) + + +class LinearBottleneck(nn.Module): + def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1, + act_layer='swish', dw_act_layer='relu6', drop_path=None): + super(LinearBottleneck, self).__init__() + self.use_shortcut = stride == 1 and in_chs <= out_chs + self.in_channels = in_chs + self.out_channels = out_chs + + if exp_ratio != 1.: + dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div) + self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer) + else: + dw_chs = in_chs + self.conv_exp = None + + self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False) + if se_ratio > 0: + self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div)) + else: + self.se = None + self.act_dw = create_act_layer(dw_act_layer) + + self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False) + self.drop_path = drop_path + + def feat_channels(self, exp=False): + return self.conv_dw.out_channels if exp else self.out_channels + + def forward(self, x): + shortcut = x + if self.conv_exp is not None: + x = self.conv_exp(x) + x = self.conv_dw(x) + if self.se is not None: + x = self.se(x) + x = self.act_dw(x) + x = self.conv_pwl(x) + if self.use_shortcut: + if self.drop_path is not None: + x = self.drop_path(x) + x[:, 0:self.in_channels] += shortcut + return x + + +def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1): + layers = [1, 2, 2, 3, 3, 5] + strides = [1, 2, 2, 2, 1, 2] + layers = [ceil(element * depth_mult) for element in layers] + strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], []) + exp_ratios = [1] * layers[0] + [6] * sum(layers[1:]) + depth = sum(layers[:]) * 3 + base_chs = initial_chs / width_mult if width_mult < 1.0 else initial_chs + + # The following channel configuration is a simple instance to make each layer become an expand layer. + out_chs_list = [] + for i in range(depth // 3): + out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div)) + base_chs += final_chs / (depth // 3 * 1.0) + + se_ratios = [0.] * (layers[0] + layers[1]) + [se_ratio] * sum(layers[2:]) + + return list(zip(out_chs_list, exp_ratios, strides, se_ratios)) + + +def _build_blocks( + block_cfg, prev_chs, width_mult, ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_path_rate=0.): + feat_chs = [prev_chs] + feature_info = [] + curr_stride = 2 + features = [] + num_blocks = len(block_cfg) + for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg): + if stride > 1: + fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}' + feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)] + curr_stride *= stride + block_dpr = drop_path_rate * block_idx / (num_blocks - 1) # stochastic depth linear decay rule + drop_path = DropPath(block_dpr) if block_dpr > 0. else None + features.append(LinearBottleneck( + in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio, + ch_div=ch_div, act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path)) + prev_chs = chs + feat_chs += [features[-1].feat_channels()] + pen_chs = make_divisible(1280 * width_mult, divisor=ch_div) + feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')] + features.append(ConvBnAct(prev_chs, pen_chs, act_layer=act_layer)) + return features, feature_info + + +class ReXNetV1(nn.Module): + def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, + initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12., + ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.): + super(ReXNetV1, self).__init__() + self.drop_rate = drop_rate + self.num_classes = num_classes + + assert output_stride == 32 # FIXME support dilation + stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32 + stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div) + self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer) + + block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div) + features, self.feature_info = _build_blocks( + block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate) + self.num_features = features[-1].out_channels + self.features = nn.Sequential(*features) + + self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate) + + efficientnet_init_weights(self) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.features(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_rexnet(variant, pretrained, **kwargs): + feature_cfg = dict(flatten_sequential=True) + return build_model_with_cfg( + ReXNetV1, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=feature_cfg, + **kwargs) + + +@register_model +def rexnet_100(pretrained=False, **kwargs): + """ReXNet V1 1.0x""" + return _create_rexnet('rexnet_100', pretrained, **kwargs) + + +@register_model +def rexnet_130(pretrained=False, **kwargs): + """ReXNet V1 1.3x""" + return _create_rexnet('rexnet_130', pretrained, width_mult=1.3, **kwargs) + + +@register_model +def rexnet_150(pretrained=False, **kwargs): + """ReXNet V1 1.5x""" + return _create_rexnet('rexnet_150', pretrained, width_mult=1.5, **kwargs) + + +@register_model +def rexnet_200(pretrained=False, **kwargs): + """ReXNet V1 2.0x""" + return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs) + + +@register_model +def rexnetr_100(pretrained=False, **kwargs): + """ReXNet V1 1.0x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_100', pretrained, ch_div=8, **kwargs) + + +@register_model +def rexnetr_130(pretrained=False, **kwargs): + """ReXNet V1 1.3x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_130', pretrained, width_mult=1.3, ch_div=8, **kwargs) + + +@register_model +def rexnetr_150(pretrained=False, **kwargs): + """ReXNet V1 1.5x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_150', pretrained, width_mult=1.5, ch_div=8, **kwargs) + + +@register_model +def rexnetr_200(pretrained=False, **kwargs): + """ReXNet V1 2.0x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3379db3da5e5303d9af2c78098b9a5424a5fde --- /dev/null +++ b/timm/models/selecsls.py @@ -0,0 +1,362 @@ +"""PyTorch SelecSLS Net example for ImageNet Classification +License: CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/legalcode) +Author: Dushyant Mehta (@mehtadushy) + +SelecSLS (core) Network Architecture as proposed in "XNect: Real-time Multi-person 3D +Human Pose Estimation with a Single RGB Camera, Mehta et al." +https://arxiv.org/abs/1907.00837 + +Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models +and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch +""" +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'selecsls42': _cfg( + url='', + interpolation='bicubic'), + 'selecsls42b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls42b-8af30141.pth', + interpolation='bicubic'), + 'selecsls60': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60-bbf87526.pth', + interpolation='bicubic'), + 'selecsls60b': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60b-94e619b5.pth', + interpolation='bicubic'), + 'selecsls84': _cfg( + url='', + interpolation='bicubic'), +} + + +class SequentialList(nn.Sequential): + + def __init__(self, *args): + super(SequentialList, self).__init__(*args) + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (List[torch.Tensor]) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (torch.Tensor) -> (List[torch.Tensor]) + pass + + def forward(self, x) -> List[torch.Tensor]: + for module in self: + x = module(x) + return x + + +class SelectSeq(nn.Module): + def __init__(self, mode='index', index=0): + super(SelectSeq, self).__init__() + self.mode = mode + self.index = index + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (List[torch.Tensor]) -> (torch.Tensor) + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, x): + # type: (Tuple[torch.Tensor]) -> (torch.Tensor) + pass + + def forward(self, x) -> torch.Tensor: + if self.mode == 'index': + return x[self.index] + else: + return torch.cat(x, dim=1) + + +def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1): + if padding is None: + padding = ((stride - 1) + dilation * (k - 1)) // 2 + return nn.Sequential( + nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, bias=False), + nn.BatchNorm2d(out_chs), + nn.ReLU(inplace=True) + ) + + +class SelecSLSBlock(nn.Module): + def __init__(self, in_chs, skip_chs, mid_chs, out_chs, is_first, stride, dilation=1): + super(SelecSLSBlock, self).__init__() + self.stride = stride + self.is_first = is_first + assert stride in [1, 2] + + # Process input with 4 conv blocks with the same number of input and output channels + self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation) + self.conv2 = conv_bn(mid_chs, mid_chs, 1) + self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3) + self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1) + self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3) + self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1) + + def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: + if not isinstance(x, list): + x = [x] + assert len(x) in [1, 2] + + d1 = self.conv1(x[0]) + d2 = self.conv3(self.conv2(d1)) + d3 = self.conv5(self.conv4(d2)) + if self.is_first: + out = self.conv6(torch.cat([d1, d2, d3], 1)) + return [out, out] + else: + return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)), x[1]] + + +class SelecSLS(nn.Module): + """SelecSLS42 / SelecSLS60 / SelecSLS84 + + Parameters + ---------- + cfg : network config dictionary specifying block type, feature, and head args + num_classes : int, default 1000 + Number of classification classes. + in_chans : int, default 3 + Number of input (color) channels. + drop_rate : float, default 0. + Dropout probability before classifier, for training + global_pool : str, default 'avg' + Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' + """ + + def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'): + self.num_classes = num_classes + self.drop_rate = drop_rate + super(SelecSLS, self).__init__() + + self.stem = conv_bn(in_chans, 32, stride=2) + self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']]) + self.from_seq = SelectSeq() # from List[tensor] -> Tensor in module compatible way + self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']]) + self.num_features = cfg['num_features'] + self.feature_info = cfg['feature_info'] + + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.stem(x) + x = self.features(x) + x = self.head(self.from_seq(x)) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +def _create_selecsls(variant, pretrained, **kwargs): + cfg = {} + feature_info = [dict(num_chs=32, reduction=2, module='stem.2')] + if variant.startswith('selecsls42'): + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 128, False, 1), + (128, 0, 144, 144, True, 2), + (144, 144, 144, 288, False, 1), + (288, 0, 304, 304, True, 2), + (304, 304, 304, 480, False, 1), + ] + feature_info.extend([ + dict(num_chs=128, reduction=4, module='features.1'), + dict(num_chs=288, reduction=8, module='features.3'), + dict(num_chs=480, reduction=16, module='features.5'), + ]) + # Head can be replaced with alternative configurations depending on the problem + feature_info.append(dict(num_chs=1024, reduction=32, module='head.1')) + if variant == 'selecsls42b': + cfg['head'] = [ + (480, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1280, 3, 2), + (1280, 1024, 1, 1), + ] + feature_info.append(dict(num_chs=1024, reduction=64, module='head.3')) + cfg['num_features'] = 1024 + else: + cfg['head'] = [ + (480, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 1, 1), + ] + feature_info.append(dict(num_chs=1280, reduction=64, module='head.3')) + cfg['num_features'] = 1280 + + elif variant.startswith('selecsls60'): + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 128, False, 1), + (128, 0, 128, 128, True, 2), + (128, 128, 128, 128, False, 1), + (128, 128, 128, 288, False, 1), + (288, 0, 288, 288, True, 2), + (288, 288, 288, 288, False, 1), + (288, 288, 288, 288, False, 1), + (288, 288, 288, 416, False, 1), + ] + feature_info.extend([ + dict(num_chs=128, reduction=4, module='features.1'), + dict(num_chs=288, reduction=8, module='features.4'), + dict(num_chs=416, reduction=16, module='features.8'), + ]) + # Head can be replaced with alternative configurations depending on the problem + feature_info.append(dict(num_chs=1024, reduction=32, module='head.1')) + if variant == 'selecsls60b': + cfg['head'] = [ + (416, 756, 3, 2), + (756, 1024, 3, 1), + (1024, 1280, 3, 2), + (1280, 1024, 1, 1), + ] + feature_info.append(dict(num_chs=1024, reduction=64, module='head.3')) + cfg['num_features'] = 1024 + else: + cfg['head'] = [ + (416, 756, 3, 2), + (756, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 1, 1), + ] + feature_info.append(dict(num_chs=1280, reduction=64, module='head.3')) + cfg['num_features'] = 1280 + + elif variant == 'selecsls84': + cfg['block'] = SelecSLSBlock + # Define configuration of the network after the initial neck + cfg['features'] = [ + # in_chs, skip_chs, mid_chs, out_chs, is_first, stride + (32, 0, 64, 64, True, 2), + (64, 64, 64, 144, False, 1), + (144, 0, 144, 144, True, 2), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 144, False, 1), + (144, 144, 144, 304, False, 1), + (304, 0, 304, 304, True, 2), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 304, False, 1), + (304, 304, 304, 512, False, 1), + ] + feature_info.extend([ + dict(num_chs=144, reduction=4, module='features.1'), + dict(num_chs=304, reduction=8, module='features.6'), + dict(num_chs=512, reduction=16, module='features.12'), + ]) + # Head can be replaced with alternative configurations depending on the problem + cfg['head'] = [ + (512, 960, 3, 2), + (960, 1024, 3, 1), + (1024, 1024, 3, 2), + (1024, 1280, 3, 1), + ] + cfg['num_features'] = 1280 + feature_info.extend([ + dict(num_chs=1024, reduction=32, module='head.1'), + dict(num_chs=1280, reduction=64, module='head.3') + ]) + else: + raise ValueError('Invalid net configuration ' + variant + ' !!!') + cfg['feature_info'] = feature_info + + # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? + return build_model_with_cfg( + SelecSLS, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=cfg, + feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), + **kwargs) + + +@register_model +def selecsls42(pretrained=False, **kwargs): + """Constructs a SelecSLS42 model. + """ + return _create_selecsls('selecsls42', pretrained, **kwargs) + + +@register_model +def selecsls42b(pretrained=False, **kwargs): + """Constructs a SelecSLS42_B model. + """ + return _create_selecsls('selecsls42b', pretrained, **kwargs) + + +@register_model +def selecsls60(pretrained=False, **kwargs): + """Constructs a SelecSLS60 model. + """ + return _create_selecsls('selecsls60', pretrained, **kwargs) + + +@register_model +def selecsls60b(pretrained=False, **kwargs): + """Constructs a SelecSLS60_B model. + """ + return _create_selecsls('selecsls60b', pretrained, **kwargs) + + +@register_model +def selecsls84(pretrained=False, **kwargs): + """Constructs a SelecSLS84 model. + """ + return _create_selecsls('selecsls84', pretrained, **kwargs) diff --git a/timm/models/senet.py b/timm/models/senet.py new file mode 100644 index 0000000000000000000000000000000000000000..3d0ba7b3ee573523523c3af574c835ccdf502a32 --- /dev/null +++ b/timm/models/senet.py @@ -0,0 +1,467 @@ +""" +SEResNet implementation from Cadene's pretrained models +https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py +Additional credit to https://github.com/creafz + +Original model: https://github.com/hujie-frank/SENet + +ResNet code gently borrowed from +https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + +FIXME I'm deprecating this model and moving them to ResNet as I don't want to maintain duplicate +support for extras like dilation, switchable BN/activations, feature extraction, etc that don't exist here. +""" +import math +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['SENet'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', + **kwargs + } + + +default_cfgs = { + 'legacy_senet154': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'), + 'legacy_seresnet18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth', + interpolation='bicubic'), + 'legacy_seresnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'), + 'legacy_seresnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'), + 'legacy_seresnet101': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'), + 'legacy_seresnet152': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'), + 'legacy_seresnext26_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth', + interpolation='bicubic'), + 'legacy_seresnext50_32x4d': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'), + 'legacy_seresnext101_32x4d': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'), +} + + +def _weight_init(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + x = x.mean((2, 3), keepdim=True) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class Bottleneck(nn.Module): + """ + Base class for bottlenecks that implements `forward()` method. + """ + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out = self.se_module(out) + shortcut + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d( + planes * 2, planes * 4, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d( + planes * 4, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=1, bias=False, stride=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Squeeze-and-Excitation module. + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None, base_width=4): + super(SEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d( + inplanes, width, kernel_size=1, bias=False, stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d( + width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): + super(SEResNetBlock, self).__init__() + self.conv1 = nn.Conv2d( + inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes, reduction=reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + shortcut = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + if self.downsample is not None: + shortcut = self.downsample(x) + + out = self.se_module(out) + shortcut + out = self.relu(out) + + return out + + +class SENet(nn.Module): + + def __init__(self, block, layers, groups, reduction, drop_rate=0.2, + in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1, + downsample_padding=0, num_classes=1000, global_pool='avg'): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. + - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. + - For all models: 1000 + """ + super(SENet, self).__init__() + self.inplanes = inplanes + self.num_classes = num_classes + self.drop_rate = drop_rate + if input_3x3: + layer0_modules = [ + ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ('conv1', nn.Conv2d( + in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`. + self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True) + self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')] + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')] + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')] + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')] + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')] + self.num_features = 512 * block.expansion + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + for m in self.modules(): + _weight_init(m) + + def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, + downsample_kernel_size=1, downsample_padding=0): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size, + stride=stride, padding=downsample_padding, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def get_classifier(self): + return self.last_linear + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.last_linear = create_classifier( + self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.layer0(x) + x = self.pool0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + return x + + def logits(self, x): + x = self.global_pool(x) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.last_linear(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.logits(x) + return x + + +def _create_senet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + SENet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def legacy_seresnet18(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet18', pretrained, **model_args) + + +@register_model +def legacy_seresnet34(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet34', pretrained, **model_args) + + +@register_model +def legacy_seresnet50(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet50', pretrained, **model_args) + + +@register_model +def legacy_seresnet101(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet101', pretrained, **model_args) + + +@register_model +def legacy_seresnet152(pretrained=False, **kwargs): + model_args = dict( + block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16, **kwargs) + return _create_senet('legacy_seresnet152', pretrained, **model_args) + + +@register_model +def legacy_senet154(pretrained=False, **kwargs): + model_args = dict( + block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16, + downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True, **kwargs) + return _create_senet('legacy_senet154', pretrained, **model_args) + + +@register_model +def legacy_seresnext26_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext26_32x4d', pretrained, **model_args) + + +@register_model +def legacy_seresnext50_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext50_32x4d', pretrained, **model_args) + + +@register_model +def legacy_seresnext101_32x4d(pretrained=False, **kwargs): + model_args = dict( + block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16, **kwargs) + return _create_senet('legacy_seresnext101_32x4d', pretrained, **model_args) diff --git a/timm/models/sknet.py b/timm/models/sknet.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc2aa534c1c9d27c7a988b72f9c4f5a1f172e95 --- /dev/null +++ b/timm/models/sknet.py @@ -0,0 +1,215 @@ +""" Selective Kernel Networks (ResNet base) + +Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) + +This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268) +and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer +to the original paper with some modifications of my own to better balance param count vs accuracy. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math + +from torch import nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import SelectiveKernel, ConvBnAct, create_attn +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'skresnet18': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'), + 'skresnet34': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'), + 'skresnet50': _cfg(), + 'skresnet50d': _cfg( + first_conv='conv1.0'), + 'skresnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'), +} + + +class SelectiveKernelBasic(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(SelectiveKernelBasic, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = SelectiveKernel( + inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = ConvBnAct( + first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) + self.se = create_attn(attn_layer, outplanes) + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act(x) + return x + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, + drop_block=None, drop_path=None): + super(SelectiveKernelBottleneck, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) + self.conv2 = SelectiveKernel( + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, + **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs) + self.se = create_attn(attn_layer, outplanes) + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + shortcut = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + shortcut = self.downsample(shortcut) + x += shortcut + x = self.act(x) + return x + + +def _create_skresnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + + +@register_model +def skresnet18(pretrained=False, **kwargs): + """Constructs a Selective Kernel ResNet-18 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True) + model_args = dict( + block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet18', pretrained, **model_args) + + +@register_model +def skresnet34(pretrained=False, **kwargs): + """Constructs a Selective Kernel ResNet-34 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True) + model_args = dict( + block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet34', pretrained, **model_args) + + +@register_model +def skresnet50(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNet-50 model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict(split_input=True) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet50', pretrained, **model_args) + + +@register_model +def skresnet50d(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNet-50-D model. + + Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this + variation splits the input channels to the selective convolutions to keep param count down. + """ + sk_kwargs = dict(split_input=True) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnet50d', pretrained, **model_args) + + +@register_model +def skresnext50_32x4d(pretrained=False, **kwargs): + """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to + the SKNet-50 model in the Select Kernel Paper + """ + sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False) + model_args = dict( + block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + return _create_skresnet('skresnext50_32x4d', pretrained, **model_args) + diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..87b5c4bc80be377ad832ce704897020ff5560a9c --- /dev/null +++ b/timm/models/swin_transformer.py @@ -0,0 +1,660 @@ +""" Swin Transformer +A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +""" +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import logging +import math +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .registry import register_model +from .vision_transformer import checkpoint_filter_fn, _init_vit_weights + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'swin_base_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_base_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth', + ), + + 'swin_large_patch4_window12_384': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'swin_large_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth', + ), + + 'swin_small_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth', + ), + + 'swin_tiny_patch4_window7_224': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth', + ), + + 'swin_base_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_base_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + num_classes=21841), + + 'swin_large_patch4_window12_384_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841), + + 'swin_large_patch4_window7_224_in22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + num_classes=21841), + +} + + +def window_partition(x, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + # H *= 2 + # W *= 2 + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + # H *= 2 + # W *= 2 + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if not torch.jit.is_scripting() and self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, weight_init='', **kwargs): + super().__init__() + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + self.patch_grid = self.patch_embed.grid_size + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + else: + self.absolute_pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + layers = [] + for i_layer in range(self.num_layers): + layers += [BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + ] + self.layers = nn.Sequential(*layers) + + self.norm = norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. + if weight_init.startswith('jax'): + for n, m in self.named_modules(): + _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + else: + self.apply(_init_vit_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.absolute_pos_embed is not None: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + l1 = self.layers[0](x) + l2 = self.layers[1](l1) + l3 = self.layers[2](l2) + l4 = self.layers[3](l3) + # x = self.norm(l4) # B L C + # x = self.avgpool(x.transpose(1, 2)) # B C 1 + # x = torch.flatten(x, 1) + return {"layer1":l1, "layer2":l2, "layer3":l3, "layer4":l4} + + def forward(self, x): + x = self.forward_features(x) + # x = self.head(x) + return x + + +def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + SwinTransformer, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + + +@register_model +def swin_base_patch4_window12_384(pretrained=False, **kwargs): + """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_base_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_large_patch4_window12_384(pretrained=False, **kwargs): + """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_large_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_small_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-S @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_tiny_patch4_window7_224(pretrained=False, **kwargs): + """ Swin-T @ 224x224, trained ImageNet-1k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs): + """ Swin-B @ 384x384, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs): + """ Swin-B @ 224x224, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs): + """ Swin-L @ 384x384, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs): + """ Swin-L @ 224x224, trained ImageNet-22k + """ + model_kwargs = dict( + patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) \ No newline at end of file diff --git a/timm/models/tnt.py b/timm/models/tnt.py new file mode 100644 index 0000000000000000000000000000000000000000..8186cc4aea0c53c5a6217e3cdd0b9193bb6d1359 --- /dev/null +++ b/timm/models/tnt.py @@ -0,0 +1,268 @@ +""" Transformer in Transformer (TNT) in PyTorch + +A PyTorch implement of TNT as described in +'Transformer in Transformer' - https://arxiv.org/abs/2103.00112 + +The official mindspore code is released and available at +https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT +""" +import math +import torch +import torch.nn as nn +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import build_model_with_cfg +from timm.models.layers import Mlp, DropPath, trunc_normal_ +from timm.models.layers.helpers import to_2tuple +from timm.models.registry import register_model +from timm.models.vision_transformer import resize_pos_embed + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'pixel_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'tnt_s_patch16_224': _cfg( + url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'tnt_b_patch16_224': _cfg( + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), +} + + +class Attention(nn.Module): + """ Multi-Head Attention + """ + def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.hidden_dim = hidden_dim + self.num_heads = num_heads + head_dim = hidden_dim // num_heads + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + + self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias) + self.v = nn.Linear(dim, dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop, inplace=True) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop, inplace=True) + + def forward(self, x): + B, N, C = x.shape + qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple) + v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + """ TNT Block + """ + def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4., + qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + # Inner transformer + self.norm_in = norm_layer(in_dim) + self.attn_in = Attention( + in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + + self.norm_mlp_in = norm_layer(in_dim) + self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4), + out_features=in_dim, act_layer=act_layer, drop=drop) + + self.norm1_proj = norm_layer(in_dim) + self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True) + # Outer transformer + self.norm_out = norm_layer(dim) + self.attn_out = Attention( + dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm_mlp = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), + out_features=dim, act_layer=act_layer, drop=drop) + + def forward(self, pixel_embed, patch_embed): + # inner + pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed))) + pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed))) + # outer + B, N, C = patch_embed.size() + patch_embed[:, 1:] = patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1)) + patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed))) + patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed))) + return pixel_embed, patch_embed + + +class PixelEmbed(nn.Module): + """ Image to Pixel Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + # grid_size property necessary for resizing positional embedding + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + num_patches = (self.grid_size[0]) * (self.grid_size[1]) + self.img_size = img_size + self.num_patches = num_patches + self.in_dim = in_dim + new_patch_size = [math.ceil(ps / stride) for ps in patch_size] + self.new_patch_size = new_patch_size + + self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) + self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size) + + def forward(self, x, pixel_pos): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + x = self.unfold(x) + x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) + x = x + pixel_pos + x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) + return x + + +class TNT(nn.Module): + """ Transformer in Transformer - https://arxiv.org/abs/2103.00112 + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12, + num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.pixel_embed = PixelEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride) + num_patches = self.pixel_embed.num_patches + self.num_patches = num_patches + new_patch_size = self.pixel_embed.new_patch_size + num_pixel = new_patch_size[0] * new_patch_size[1] + + self.norm1_proj = norm_layer(num_pixel * in_dim) + self.proj = nn.Linear(num_pixel * in_dim, embed_dim) + self.norm2_proj = norm_layer(embed_dim) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1])) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + blocks = [] + for i in range(depth): + blocks.append(Block( + dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head, + mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[i], norm_layer=norm_layer)) + self.blocks = nn.ModuleList(blocks) + self.norm = norm_layer(embed_dim) + + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.cls_token, std=.02) + trunc_normal_(self.patch_pos, std=.02) + trunc_normal_(self.pixel_pos, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'patch_pos', 'pixel_pos', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + pixel_embed = self.pixel_embed(x, self.pixel_pos) + + patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1)))) + patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1) + patch_embed = patch_embed + self.patch_pos + patch_embed = self.pos_drop(patch_embed) + + for blk in self.blocks: + pixel_embed, patch_embed = blk(pixel_embed, patch_embed) + + patch_embed = self.norm(patch_embed) + return patch_embed[:, 0] + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + if state_dict['patch_pos'].shape != model.patch_pos.shape: + state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'], + model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size) + return state_dict + + +def _create_tnt(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + TNT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def tnt_s_patch16_224(pretrained=False, **kwargs): + model_cfg = dict( + patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, + qkv_bias=False, **kwargs) + model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg) + return model + + +@register_model +def tnt_b_patch16_224(pretrained=False, **kwargs): + model_cfg = dict( + patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, + qkv_bias=False, **kwargs) + model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg) + return model diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py new file mode 100644 index 0000000000000000000000000000000000000000..372bfb7bc0ce89241121f8b85ea928f376af8bd5 --- /dev/null +++ b/timm/models/tresnet.py @@ -0,0 +1,297 @@ +""" +TResNet: High Performance GPU-Dedicated Architecture +https://arxiv.org/pdf/2003.13630.pdf + +Original model: https://github.com/mrT23/TResNet + +""" +from collections import OrderedDict + +import torch +import torch.nn as nn + +from .helpers import build_model_with_cfg +from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule +from .registry import register_model + +__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': (0, 0, 0), 'std': (1, 1, 1), + 'first_conv': 'body.conv1.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + 'tresnet_m': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/tresnet_m_1k_miil_83_1.pth'), + 'tresnet_m_miil_in21k': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/tresnet_m_miil_in21k.pth', num_classes=11221), + 'tresnet_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'), + 'tresnet_xl': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'), + 'tresnet_m_448': _cfg( + input_size=(3, 448, 448), pool_size=(14, 14), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'), + 'tresnet_l_448': _cfg( + input_size=(3, 448, 448), pool_size=(14, 14), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'), + 'tresnet_xl_448': _cfg( + input_size=(3, 448, 448), pool_size=(14, 14), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth') +} + + +def IABN2Float(module: nn.Module) -> nn.Module: + """If `module` is IABN don't use half precision.""" + if isinstance(module, InplaceAbn): + module.float() + for child in module.children(): + IABN2Float(child) + return module + + +def conv2d_iabn(ni, nf, stride, kernel_size=3, groups=1, act_layer="leaky_relu", act_param=1e-2): + return nn.Sequential( + nn.Conv2d( + ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False), + InplaceAbn(nf, act_layer=act_layer, act_param=act_param) + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None): + super(BasicBlock, self).__init__() + if stride == 1: + self.conv1 = conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3) + else: + if aa_layer is None: + self.conv1 = conv2d_iabn(inplanes, planes, stride=2, act_param=1e-3) + else: + self.conv1 = nn.Sequential( + conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3), + aa_layer(channels=planes, filt_size=3, stride=2)) + + self.conv2 = conv2d_iabn(planes, planes, stride=1, act_layer="identity") + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + rd_chs = max(planes * self.expansion // 4, 64) + self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None + + def forward(self, x): + if self.downsample is not None: + shortcut = self.downsample(x) + else: + shortcut = x + + out = self.conv1(x) + out = self.conv2(out) + + if self.se is not None: + out = self.se(out) + + out += shortcut + out = self.relu(out) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, + act_layer="leaky_relu", aa_layer=None): + super(Bottleneck, self).__init__() + self.conv1 = conv2d_iabn( + inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3) + if stride == 1: + self.conv2 = conv2d_iabn( + planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3) + else: + if aa_layer is None: + self.conv2 = conv2d_iabn( + planes, planes, kernel_size=3, stride=2, act_layer=act_layer, act_param=1e-3) + else: + self.conv2 = nn.Sequential( + conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3), + aa_layer(channels=planes, filt_size=3, stride=2)) + + reduction_chs = max(planes * self.expansion // 8, 64) + self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None + + self.conv3 = conv2d_iabn( + planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + if self.downsample is not None: + shortcut = self.downsample(x) + else: + shortcut = x + + out = self.conv1(x) + out = self.conv2(out) + if self.se is not None: + out = self.se(out) + + out = self.conv3(out) + out = out + shortcut # no inplace + out = self.relu(out) + + return out + + +class TResNet(nn.Module): + def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.): + self.num_classes = num_classes + self.drop_rate = drop_rate + super(TResNet, self).__init__() + + aa_layer = BlurPool2d + + # TResnet stages + self.inplanes = int(64 * width_factor) + self.planes = int(64 * width_factor) + conv1 = conv2d_iabn(in_chans * 16, self.planes, stride=1, kernel_size=3) + layer1 = self._make_layer( + BasicBlock, self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer) # 56x56 + layer2 = self._make_layer( + BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer) # 28x28 + layer3 = self._make_layer( + Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer) # 14x14 + layer4 = self._make_layer( + Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer) # 7x7 + + # body + self.body = nn.Sequential(OrderedDict([ + ('SpaceToDepth', SpaceToDepthModule()), + ('conv1', conv1), + ('layer1', layer1), + ('layer2', layer2), + ('layer3', layer3), + ('layer4', layer4)])) + + self.feature_info = [ + dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D? + dict(num_chs=self.planes, reduction=4, module='body.layer1'), + dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'), + dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'), + dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'), + ] + + # head + self.num_features = (self.planes * 8) * Bottleneck.expansion + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # model initilization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InplaceAbn): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # residual connections special initialization + for m in self.modules(): + if isinstance(m, BasicBlock): + m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero + if isinstance(m, Bottleneck): + m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero + if isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + + def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + layers = [] + if stride == 2: + # avg pooling before 1x1 conv + layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)) + layers += [conv2d_iabn( + self.inplanes, planes * block.expansion, kernel_size=1, stride=1, act_layer="identity")] + downsample = nn.Sequential(*layers) + + layers = [] + layers.append(block( + self.inplanes, planes, stride, downsample, use_se=use_se, aa_layer=aa_layer)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append( + block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer)) + return nn.Sequential(*layers) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='fast'): + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + return self.body(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_tresnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + TResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), + **kwargs) + + +@register_model +def tresnet_m(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_m_miil_in21k(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m_miil_in21k', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_l(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs) + return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_xl(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs) + return _create_tresnet('tresnet_xl', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_m_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs) + return _create_tresnet('tresnet_m_448', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_l_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs) + return _create_tresnet('tresnet_l_448', pretrained=pretrained, **model_kwargs) + + +@register_model +def tresnet_xl_448(pretrained=False, **kwargs): + model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs) + return _create_tresnet('tresnet_xl_448', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/twins.py b/timm/models/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..4aed09d90f4d832e798bc7dd39d3712cb20b966d --- /dev/null +++ b/timm/models/twins.py @@ -0,0 +1,422 @@ +""" Twins +A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers` + - https://arxiv.org/pdf/2104.13840.pdf + +Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below + +""" +# -------------------------------------------------------- +# Twins +# Copyright (c) 2021 Meituan +# Licensed under The Apache 2.0 License [see LICENSE for details] +# Written by Xinjie Li, Xiangxiang Chu +# -------------------------------------------------------- +import math +from copy import deepcopy +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from .registry import register_model +from .vision_transformer import Attention +from .helpers import build_model_with_cfg, overlay_external_default_cfg + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'twins_pcpvt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth', + ), + 'twins_pcpvt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth', + ), + 'twins_pcpvt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth', + ), + 'twins_svt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth', + ), + 'twins_svt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth', + ), + 'twins_svt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth', + ), +} + +Size_ = Tuple[int, int] + + +class LocallyGroupedAttn(nn.Module): + """ LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(LocallyGroupedAttn, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + qkv = self.qkv(x).reshape( + B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + # def forward_mask(self, x, size: Size_): + # B, N, C = x.shape + # H, W = size + # x = x.view(B, H, W, C) + # pad_l = pad_t = 0 + # pad_r = (self.ws - W % self.ws) % self.ws + # pad_b = (self.ws - H % self.ws) % self.ws + # x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + # _, Hp, Wp, _ = x.shape + # _h, _w = Hp // self.ws, Wp // self.ws + # mask = torch.zeros((1, Hp, Wp), device=x.device) + # mask[:, -pad_b:, :].fill_(1) + # mask[:, :, -pad_r:].fill_(1) + # + # x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) # B, _h, _w, ws, ws, C + # mask = mask.reshape(1, _h, self.ws, _w, self.ws).transpose(2, 3).reshape(1, _h * _w, self.ws * self.ws) + # attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) # 1, _h*_w, ws*ws, ws*ws + # attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill(attn_mask == 0, float(0.0)) + # qkv = self.qkv(x).reshape( + # B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + # # n_h, B, _w*_h, nhead, ws*ws, dim + # q, k, v = qkv[0], qkv[1], qkv[2] # B, _h*_w, n_head, ws*ws, dim_head + # attn = (q @ k.transpose(-2, -1)) * self.scale # B, _h*_w, n_head, ws*ws, ws*ws + # attn = attn + attn_mask.unsqueeze(2) + # attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) # attn @v -> B, _h*_w, n_head, ws*ws, dim_head + # attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + # x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + # if pad_r > 0 or pad_b > 0: + # x = x[:, :H, :W, :].contiguous() + # x = x.reshape(B, N, C) + # x = self.proj(x) + # x = self.proj_drop(x) + # return x + + +class GlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None): + super().__init__() + self.norm1 = norm_layer(dim) + if ws is None: + self.attn = Attention(dim, num_heads, False, None, attn_drop, drop) + elif ws == 1: + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, size: Size_): + x = x + self.drop_path(self.attn(self.norm1(x), size)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), ) + self.stride = stride + + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ['proj.%d.weight' % i for i in range(4)] + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + out_size = (H // self.patch_size[0], W // self.patch_size[1]) + + return x, out_size + + +class Twins(nn.Module): + """ Twins Vision Transfomer (Revisiting Spatial Attention) + + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None, + block_cls=Block): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + + img_size = to_2tuple(img_size) + prev_chs = in_chans + self.patch_embeds = nn.ModuleList() + self.pos_drops = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + + self.blocks = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList([block_cls( + dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])]) + self.blocks.append(_block) + cur += depths[k] + + self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) + + self.norm = norm_layer(self.num_features) + + # classification head + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # init weights + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()]) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x.mean(dim=1) # GAP here + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_twins(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + Twins, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + return model + + +@register_model +def twins_pcpvt_small(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_pcpvt_base(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_pcpvt_large(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], + depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_svt_small(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_svt_base(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs) + + +@register_model +def twins_svt_large(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], + depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) + return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/vgg.py b/timm/models/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..8bea03e7ce31bada1790090561c99db9faa5ca76 --- /dev/null +++ b/timm/models/vgg.py @@ -0,0 +1,261 @@ +"""VGG + +Adapted from https://github.com/pytorch/vision 'vgg.py' (BSD-3-Clause) with a few changes for +timm functionality. + +Copyright 2021 Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union, List, Dict, Any, cast + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct +from .registry import register_model + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'features.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = { + 'vgg11': _cfg(url='https://download.pytorch.org/models/vgg11-bbd30ac9.pth'), + 'vgg13': _cfg(url='https://download.pytorch.org/models/vgg13-c768596a.pth'), + 'vgg16': _cfg(url='https://download.pytorch.org/models/vgg16-397923af.pth'), + 'vgg19': _cfg(url='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'), + 'vgg11_bn': _cfg(url='https://download.pytorch.org/models/vgg11_bn-6002323d.pth'), + 'vgg13_bn': _cfg(url='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth'), + 'vgg16_bn': _cfg(url='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'), + 'vgg19_bn': _cfg(url='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth'), +} + + +cfgs: Dict[str, List[Union[str, int]]] = { + 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +class ConvMlp(nn.Module): + + def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0, + drop_rate: float = 0.2, act_layer: nn.Module = None, conv_layer: nn.Module = None): + super(ConvMlp, self).__init__() + self.input_kernel_size = kernel_size + mid_features = int(out_features * mlp_ratio) + self.fc1 = conv_layer(in_features, mid_features, kernel_size, bias=True) + self.act1 = act_layer(True) + self.drop = nn.Dropout(drop_rate) + self.fc2 = conv_layer(mid_features, out_features, 1, bias=True) + self.act2 = act_layer(True) + + def forward(self, x): + if x.shape[-2] < self.input_kernel_size or x.shape[-1] < self.input_kernel_size: + # keep the input size >= 7x7 + output_size = (max(self.input_kernel_size, x.shape[-2]), max(self.input_kernel_size, x.shape[-1])) + x = F.adaptive_avg_pool2d(x, output_size) + x = self.fc1(x) + x = self.act1(x) + x = self.drop(x) + x = self.fc2(x) + x = self.act2(x) + return x + + +class VGG(nn.Module): + + def __init__( + self, + cfg: List[Any], + num_classes: int = 1000, + in_chans: int = 3, + output_stride: int = 32, + mlp_ratio: float = 1.0, + act_layer: nn.Module = nn.ReLU, + conv_layer: nn.Module = nn.Conv2d, + norm_layer: nn.Module = None, + global_pool: str = 'avg', + drop_rate: float = 0., + ) -> None: + super(VGG, self).__init__() + assert output_stride == 32 + self.num_classes = num_classes + self.num_features = 4096 + self.drop_rate = drop_rate + self.feature_info = [] + prev_chs = in_chans + net_stride = 1 + pool_layer = nn.MaxPool2d + layers: List[nn.Module] = [] + for v in cfg: + last_idx = len(layers) - 1 + if v == 'M': + self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{last_idx}')) + layers += [pool_layer(kernel_size=2, stride=2)] + net_stride *= 2 + else: + v = cast(int, v) + conv2d = conv_layer(prev_chs, v, kernel_size=3, padding=1) + if norm_layer is not None: + layers += [conv2d, norm_layer(v), act_layer(inplace=True)] + else: + layers += [conv2d, act_layer(inplace=True)] + prev_chs = v + self.features = nn.Sequential(*layers) + self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}')) + self.pre_logits = ConvMlp( + prev_chs, self.num_features, 7, mlp_ratio=mlp_ratio, + drop_rate=drop_rate, act_layer=act_layer, conv_layer=conv_layer) + self.head = ClassifierHead( + self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + self._initialize_weights() + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.head = ClassifierHead( + self.num_features, self.num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.features(x) + x = self.pre_logits(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.head(x) + return x + + def _initialize_weights(self) -> None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def _filter_fn(state_dict): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + k_r = k + k_r = k_r.replace('classifier.0', 'pre_logits.fc1') + k_r = k_r.replace('classifier.3', 'pre_logits.fc2') + k_r = k_r.replace('classifier.6', 'head.fc') + if 'classifier.0.weight' in k: + v = v.reshape(-1, 512, 7, 7) + if 'classifier.3.weight' in k: + v = v.reshape(-1, 4096, 1, 1) + out_dict[k_r] = v + return out_dict + + +def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: + cfg = variant.split('_')[0] + # NOTE: VGG is one of the only models with stride==1 features, so indices are offset from other models + out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5)) + model = build_model_with_cfg( + VGG, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=cfgs[cfg], + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + pretrained_filter_fn=_filter_fn, + **kwargs) + return model + + +@register_model +def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 11-layer model (configuration "A") from + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(**kwargs) + return _create_vgg('vgg11', pretrained=pretrained, **model_args) + + +@register_model +def vgg11_bn(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 11-layer model (configuration "A") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs) + return _create_vgg('vgg11_bn', pretrained=pretrained, **model_args) + + +@register_model +def vgg13(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 13-layer model (configuration "B") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(**kwargs) + return _create_vgg('vgg13', pretrained=pretrained, **model_args) + + +@register_model +def vgg13_bn(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 13-layer model (configuration "B") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs) + return _create_vgg('vgg13_bn', pretrained=pretrained, **model_args) + + +@register_model +def vgg16(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 16-layer model (configuration "D") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(**kwargs) + return _create_vgg('vgg16', pretrained=pretrained, **model_args) + + +@register_model +def vgg16_bn(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 16-layer model (configuration "D") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs) + return _create_vgg('vgg16_bn', pretrained=pretrained, **model_args) + + +@register_model +def vgg19(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 19-layer model (configuration "E") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(**kwargs) + return _create_vgg('vgg19', pretrained=pretrained, **model_args) + + +@register_model +def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG: + r"""VGG 19-layer model (configuration 'E') with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ + """ + model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs) + return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args) \ No newline at end of file diff --git a/timm/models/visformer.py b/timm/models/visformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7740f38132aef6fb254aca6260881754a0212191 --- /dev/null +++ b/timm/models/visformer.py @@ -0,0 +1,409 @@ +""" Visformer + +Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533 + +From original at https://github.com/danczs/Visformer + +""" +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier +from .registry import register_model + + +__all__ = ['Visformer'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + visformer_tiny=_cfg(), + visformer_small=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth' + ), +) + + +class SpatialMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.in_features = in_features + self.out_features = out_features + self.spatial_conv = spatial_conv + if self.spatial_conv: + if group < 2: # net setting + hidden_features = in_features * 5 // 6 + else: + hidden_features = in_features * 2 + self.hidden_features = hidden_features + self.group = group + self.drop = nn.Dropout(drop) + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.act1 = act_layer() + if self.spatial_conv: + self.conv2 = nn.Conv2d( + hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False) + self.act2 = act_layer() + else: + self.conv2 = None + self.act2 = None + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.drop(x) + if self.conv2 is not None: + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = round(dim // num_heads * head_dim_ratio) + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, C, H, W = x.shape + x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) + q, k, v = x[0], x[1], x[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d, + group=8, attn_disabled=False, spatial_conv=False): + super().__init__() + self.spatial_conv = spatial_conv + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + if attn_disabled: + self.norm1 = None + self.attn = None + else: + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = SpatialMlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + group=group, spatial_conv=spatial_conv) # new setting + + def forward(self, x): + if self.attn is not None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Visformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, + depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', + vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None): + super().__init__() + img_size = to_2tuple(img_size) + self.num_classes = num_classes + self.embed_dim = embed_dim + self.init_channels = init_channels + self.img_size = img_size + self.vit_stem = vit_stem + self.conv_init = conv_init + if isinstance(depth, (list, tuple)): + self.stage_num1, self.stage_num2, self.stage_num3 = depth + depth = sum(depth) + else: + self.stage_num1 = self.stage_num3 = depth // 3 + self.stage_num2 = depth - self.stage_num1 - self.stage_num3 + self.pos_embed = pos_embed + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # stage 1 + if self.vit_stem: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size = [x // 16 for x in img_size] + else: + if self.init_channels is None: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size = [x // 8 for x in img_size] + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.init_channels), + nn.ReLU(inplace=True) + ) + img_size = [x // 2 for x in img_size] + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size = [x // 4 for x in img_size] + + if self.pos_embed: + if self.vit_stem: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + else: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size)) + self.pos_drop = nn.Dropout(p=drop_rate) + self.stage1 = nn.ModuleList([ + Block( + dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1') + ) + for i in range(self.stage_num1) + ]) + + # stage2 + if not self.vit_stem: + self.patch_embed2 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size = [x // 2 for x in img_size] + if self.pos_embed: + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) + self.stage2 = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1') + ) + for i in range(self.stage_num1, self.stage_num1+self.stage_num2) + ]) + + # stage 3 + if not self.vit_stem: + self.patch_embed3 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, + embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) + img_size = [x // 2 for x in img_size] + if self.pos_embed: + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) + self.stage3 = nn.ModuleList([ + Block( + dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1') + ) + for i in range(self.stage_num1+self.stage_num2, depth) + ]) + + # head + self.num_features = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(self.num_features) + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + # weights init + if self.pos_embed: + trunc_normal_(self.pos_embed1, std=0.02) + if not self.vit_stem: + trunc_normal_(self.pos_embed2, std=0.02) + trunc_normal_(self.pos_embed3, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + if self.conv_init: + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + if self.stem is not None: + x = self.stem(x) + + # stage 1 + x = self.patch_embed1(x) + if self.pos_embed: + x = x + self.pos_embed1 + x = self.pos_drop(x) + for b in self.stage1: + x = b(x) + + # stage 2 + if not self.vit_stem: + x = self.patch_embed2(x) + if self.pos_embed: + x = x + self.pos_embed2 + x = self.pos_drop(x) + for b in self.stage2: + x = b(x) + + # stage3 + if not self.vit_stem: + x = self.patch_embed3(x) + if self.pos_embed: + x = x + self.pos_embed3 + x = self.pos_drop(x) + for b in self.stage3: + x = b(x) + + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + x = self.head(x) + return x + + +def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg( + Visformer, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + return model + + +@register_model +def visformer_tiny(pretrained=False, **kwargs): + model_cfg = dict( + init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) + return model + + +@register_model +def visformer_small(pretrained=False, **kwargs): + model_cfg = dict( + init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) + return model + + +# @register_model +# def visformer_net1(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net2(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net3(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net4(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net5(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net6(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net7(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model + + + + diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0e4a85d80898c4b41d25f89a165d0ecc40fc85bb --- /dev/null +++ b/timm/models/vision_transformer.py @@ -0,0 +1,871 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ +from .registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (weights from official Google JAX impl) + 'vit_tiny_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_tiny_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_base_patch32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + ), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_large_patch32_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), + 'vit_huge_patch14_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', + hf_hub='timm/vit_huge_patch14_224_in21k', + num_classes=21843), + + # deit models (FB weights) + 'deit_tiny_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'deit_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), + 'deit_tiny_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_small_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, + classifier=('head', 'head_dist')), + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil_in21k': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, + ), + 'vit_base_patch16_224_miil': _cfg( + url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' + '/vit_base_patch16_224_1k_miil_84_4.pth', + mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', + ), +} + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, + act_layer=None, weight_init=''): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) + if mode.startswith('jax'): + # leave cls token as zeros to match jax impl + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) + else: + trunc_normal_(self.cls_token, std=.02) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def get_classifier(self): + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + + + x = self.norm(x) + if self.dist_token is None: + return self.pre_logits(x[:, 0]) + else: + return x[:, 0], x[:, 1] + + def forward(self, x): + x = self.forward_features(x) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.head(x) + return x + + +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + # NOTE this extra code to support handling of repr size for in21k pretrained models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + model = build_model_with_cfg( + VisionTransformer, variant, pretrained, + default_cfg=default_cfg, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + pretrained_custom_load='npz' in default_cfg['url'], + **kwargs) + return model + + +@register_model +def vit_tiny_patch16_224(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_384(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) @ 384x384. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) at 384x384. + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_384(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights + """ + model_kwargs = dict( + patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_tiny_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_base_patch16_224(pretrained=False, **kwargs): + """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_base_patch16_384(pretrained=False, **kwargs): + """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer( + 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_small_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer( + 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_base_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer( + 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_base_distilled_patch16_384(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer( + 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224_miil(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) + return model \ No newline at end of file diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f0a5377ec9492c5ed55ceb3ce5a4378cbb8e3c --- /dev/null +++ b/timm/models/vision_transformer_hybrid.py @@ -0,0 +1,363 @@ +""" Hybrid Vision Transformer (ViT) in PyTorch + +A PyTorch implement of the Hybrid Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.TODO + +NOTE These hybrid model definitions depend on code in vision_transformer.py. +They were moved here to keep file sizes sane. + +Hacked together by / Copyright 2021 Ross Wightman +""" +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import StdConv2dSame, StdConv2d, to_2tuple +from .resnet import resnet26d, resnet50d +from .resnetv2 import ResNetV2, create_resnetv2_stem +from .registry import register_model +from timm.models.vision_transformer import _create_vision_transformer + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # hybrid in-1k models (weights from official JAX impl where they exist) + 'vit_tiny_r_s16_p8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + first_conv='patch_embed.backbone.conv'), + 'vit_tiny_r_s16_p8_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_r26_s32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', + ), + 'vit_small_r26_s32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_r26_s32_224': _cfg(), + 'vit_base_r50_s16_224': _cfg(), + 'vit_base_r50_s16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_r50_s32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz' + ), + 'vit_large_r50_s32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0 + ), + + # hybrid in-21k models (weights from official Google JAX impl where they exist) + 'vit_tiny_r_s16_p8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), + 'vit_small_r26_s32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9), + 'vit_base_r50_s16_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', + num_classes=21843, crop_pct=0.9), + 'vit_large_r50_s32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9), + + # hybrid models (using timm resnet backbones) + 'vit_small_resnet26d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_small_resnet50d_s16_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet26d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet50d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), +} + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # NOTE Most reliable way of determining output dims is to run forward pass + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): + embed_layer = partial(HybridEmbed, backbone=backbone) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + return _create_vision_transformer( + variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs) + + +def _resnetv2(layers=(3, 4, 9), **kwargs): + """ ResNet-V2 backbone helper""" + padding_same = kwargs.get('padding_same', True) + stem_type = 'same' if padding_same else '' + conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8) + if len(layers): + backbone = ResNetV2( + layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), + preact=False, stem_type=stem_type, conv_layer=conv_layer) + else: + backbone = create_resnetv2_stem( + kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer) + return backbone + + +@register_model +def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_224(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_384(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r26_s32_224(pretrained=False, **kwargs): + """ R26+ViT-B/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_224(pretrained=False, **kwargs): + """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + """ + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_384(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet50_384(pretrained=False, **kwargs): + # DEPRECATED this is forwarding to model def above for backwards compatibility + return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) + + +@register_model +def vit_large_r50_s32_224(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_r50_s32_384(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. ImageNet-21k. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + backbone = _resnetv2(layers=(3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): + # DEPRECATED this is forwarding to model def above for backwards compatibility + return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) + + +@register_model +def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. ImageNet-21k. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_resnet26d_224(pretrained=False, **kwargs): + """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. + """ + backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_resnet50d_s16_224(pretrained=False, **kwargs): + """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. + """ + backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet26d_224(pretrained=False, **kwargs): + """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. + """ + backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_resnet50d_224(pretrained=False, **kwargs): + """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. + """ + backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model \ No newline at end of file diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ec5b3e81608b05c54b4e3725b1838d8395aa33ca --- /dev/null +++ b/timm/models/vovnet.py @@ -0,0 +1,406 @@ +""" VoVNet (V1 & V2) + +Papers: +* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730 +* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Looked at https://github.com/youngwanLEE/vovnet-detectron2 & +https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py +for some reference, rewrote most of the code. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model +from .helpers import build_model_with_cfg +from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath,\ + create_attn, create_norm_act, get_norm_act_layer + + +# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & +# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py +model_cfgs = dict( + vovnet39a=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=False, + depthwise=False, + attn='', + ), + vovnet57a=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 4, 3], + residual=False, + depthwise=False, + attn='', + + ), + ese_vovnet19b_slim_dw=dict( + stem_chs=[64, 64, 64], + stage_conv_chs=[64, 80, 96, 112], + stage_out_chs=[112, 256, 384, 512], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=True, + attn='ese', + + ), + ese_vovnet19b_dw=dict( + stem_chs=[64, 64, 64], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=True, + attn='ese', + ), + ese_vovnet19b_slim=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[64, 80, 96, 112], + stage_out_chs=[112, 256, 384, 512], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=False, + attn='ese', + ), + ese_vovnet19b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=3, + block_per_stage=[1, 1, 1, 1], + residual=True, + depthwise=False, + attn='ese', + + ), + ese_vovnet39b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=True, + depthwise=False, + attn='ese', + ), + ese_vovnet57b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 4, 3], + residual=True, + depthwise=False, + attn='ese', + + ), + ese_vovnet99b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 3, 9, 3], + residual=True, + depthwise=False, + attn='ese', + ), + eca_vovnet39b=dict( + stem_chs=[64, 64, 128], + stage_conv_chs=[128, 160, 192, 224], + stage_out_chs=[256, 512, 768, 1024], + layer_per_block=5, + block_per_stage=[1, 1, 2, 2], + residual=True, + depthwise=False, + attn='eca', + ), +) +model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b'] +model_cfgs['ese_vovnet99b_iabn'] = model_cfgs['ese_vovnet99b'] + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + } + + +default_cfgs = dict( + vovnet39a=_cfg(url=''), + vovnet57a=_cfg(url=''), + ese_vovnet19b_slim_dw=_cfg(url=''), + ese_vovnet19b_dw=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'), + ese_vovnet19b_slim=_cfg(url=''), + ese_vovnet39b=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'), + ese_vovnet57b=_cfg(url=''), + ese_vovnet99b=_cfg(url=''), + eca_vovnet39b=_cfg(url=''), + ese_vovnet39b_evos=_cfg(url=''), + ese_vovnet99b_iabn=_cfg(url=''), +) + + +class SequentialAppendList(nn.Sequential): + def __init__(self, *args): + super(SequentialAppendList, self).__init__(*args) + + def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor: + for i, module in enumerate(self): + if i == 0: + concat_list.append(module(x)) + else: + concat_list.append(module(concat_list[-1])) + x = torch.cat(concat_list, dim=1) + return x + + +class OsaBlock(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False, + depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None): + super(OsaBlock, self).__init__() + + self.residual = residual + self.depthwise = depthwise + conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer) + + next_in_chs = in_chs + if self.depthwise and next_in_chs != mid_chs: + assert not residual + self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, **conv_kwargs) + else: + self.conv_reduction = None + + mid_convs = [] + for i in range(layer_per_block): + if self.depthwise: + conv = SeparableConvBnAct(mid_chs, mid_chs, **conv_kwargs) + else: + conv = ConvBnAct(next_in_chs, mid_chs, 3, **conv_kwargs) + next_in_chs = mid_chs + mid_convs.append(conv) + self.conv_mid = SequentialAppendList(*mid_convs) + + # feature aggregation + next_in_chs = in_chs + layer_per_block * mid_chs + self.conv_concat = ConvBnAct(next_in_chs, out_chs, **conv_kwargs) + + if attn: + self.attn = create_attn(attn, out_chs) + else: + self.attn = None + + self.drop_path = drop_path + + def forward(self, x): + output = [x] + if self.conv_reduction is not None: + x = self.conv_reduction(x) + x = self.conv_mid(x, output) + x = self.conv_concat(x) + if self.attn is not None: + x = self.attn(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.residual: + x = x + output[0] + return x + + +class OsaStage(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True, + residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, + drop_path_rates=None): + super(OsaStage, self).__init__() + + if downsample: + self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) + else: + self.pool = None + + blocks = [] + for i in range(block_per_stage): + last_block = i == block_per_stage - 1 + if drop_path_rates is not None and drop_path_rates[i] > 0.: + drop_path = DropPath(drop_path_rates[i]) + else: + drop_path = None + blocks += [OsaBlock( + in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise, + attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path) + ] + in_chs = out_chs + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + if self.pool is not None: + x = self.pool(x) + x = self.blocks(x) + return x + + +class VovNet(nn.Module): + + def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, + output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.): + """ VovNet (v2) + """ + super(VovNet, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert stem_stride in (4, 2) + assert output_stride == 32 # FIXME support dilation + + stem_chs = cfg["stem_chs"] + stage_conv_chs = cfg["stage_conv_chs"] + stage_out_chs = cfg["stage_out_chs"] + block_per_stage = cfg["block_per_stage"] + layer_per_block = cfg["layer_per_block"] + conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer) + + # Stem module + last_stem_stride = stem_stride // 2 + conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct + self.stem = nn.Sequential(*[ + ConvBnAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs), + conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs), + conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs), + ]) + self.feature_info = [dict( + num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')] + current_stride = stem_stride + + # OSA stages + stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage) + in_ch_list = stem_chs[-1:] + stage_out_chs[:-1] + stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs) + stages = [] + for i in range(4): # num_stages + downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4 + stages += [OsaStage( + in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block, + downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args) + ] + self.num_features = stage_out_chs[i] + current_stride *= 2 if downsample else 1 + self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')] + + self.stages = nn.Sequential(*stages) + + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + for n, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + elif isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + return self.stages(x) + + def forward(self, x): + x = self.forward_features(x) + return self.head(x) + + +def _create_vovnet(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + VovNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) + + +@register_model +def vovnet39a(pretrained=False, **kwargs): + return _create_vovnet('vovnet39a', pretrained=pretrained, **kwargs) + + +@register_model +def vovnet57a(pretrained=False, **kwargs): + return _create_vovnet('vovnet57a', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_slim_dw(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_dw(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet19b_slim(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet39b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet57b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs) + + +@register_model +def ese_vovnet99b(pretrained=False, **kwargs): + return _create_vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs) + + +@register_model +def eca_vovnet39b(pretrained=False, **kwargs): + return _create_vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs) + + +# Experimental Models + +@register_model +def ese_vovnet39b_evos(pretrained=False, **kwargs): + def norm_act_fn(num_features, **nkwargs): + return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs) + return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) + + +@register_model +def ese_vovnet99b_iabn(pretrained=False, **kwargs): + norm_layer = get_norm_act_layer('iabn') + return _create_vovnet( + 'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs) diff --git a/timm/models/xception.py b/timm/models/xception.py new file mode 100644 index 0000000000000000000000000000000000000000..86f558cb5b2b890ef74d11c99eea50c33eff653e --- /dev/null +++ b/timm/models/xception.py @@ -0,0 +1,232 @@ +""" +Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) + +@author: tstandley +Adapted by cadene + +Creates an Xception Model as defined in: + +Francois Chollet +Xception: Deep Learning with Depthwise Separable Convolutions +https://arxiv.org/pdf/1610.02357.pdf + +This weights ported from the Keras implementation. Achieves the following performance on the validation set: + +Loss:0.9173 Prec@1:78.892 Prec@5:94.292 + +REMEMBER to set your image size to 3x299x299 for both test and validation + +normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + +The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 +""" + +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import build_model_with_cfg +from .layers import create_classifier +from .registry import register_model + +__all__ = ['Xception'] + +default_cfgs = { + 'xception': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth', + 'input_size': (3, 299, 299), + 'pool_size': (10, 10), + 'crop_pct': 0.8975, + 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'first_conv': 'conv1', + 'classifier': 'fc' + # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 + } +} + + +class SeparableConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1): + super(SeparableConv2d, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False) + self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.pointwise(x) + return x + + +class Block(nn.Module): + def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True): + super(Block, self).__init__() + + if out_channels != in_channels or strides != 1: + self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False) + self.skipbn = nn.BatchNorm2d(out_channels) + else: + self.skip = None + + rep = [] + for i in range(reps): + if grow_first: + inc = in_channels if i == 0 else out_channels + outc = out_channels + else: + inc = in_channels + outc = in_channels if i < (reps - 1) else out_channels + rep.append(nn.ReLU(inplace=True)) + rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1)) + rep.append(nn.BatchNorm2d(outc)) + + if not start_with_relu: + rep = rep[1:] + else: + rep[0] = nn.ReLU(inplace=False) + + if strides != 1: + rep.append(nn.MaxPool2d(3, strides, 1)) + self.rep = nn.Sequential(*rep) + + def forward(self, inp): + x = self.rep(inp) + + if self.skip is not None: + skip = self.skip(inp) + skip = self.skipbn(skip) + else: + skip = inp + + x += skip + return x + + +class Xception(nn.Module): + """ + Xception optimized for the ImageNet dataset, as specified in + https://arxiv.org/pdf/1610.02357.pdf + """ + + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): + """ Constructor + Args: + num_classes: number of classes + """ + super(Xception, self).__init__() + self.drop_rate = drop_rate + self.global_pool = global_pool + self.num_classes = num_classes + self.num_features = 2048 + + self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + self.act2 = nn.ReLU(inplace=True) + + self.block1 = Block(64, 128, 2, 2, start_with_relu=False) + self.block2 = Block(128, 256, 2, 2) + self.block3 = Block(256, 728, 2, 2) + + self.block4 = Block(728, 728, 3, 1) + self.block5 = Block(728, 728, 3, 1) + self.block6 = Block(728, 728, 3, 1) + self.block7 = Block(728, 728, 3, 1) + + self.block8 = Block(728, 728, 3, 1) + self.block9 = Block(728, 728, 3, 1) + self.block10 = Block(728, 728, 3, 1) + self.block11 = Block(728, 728, 3, 1) + + self.block12 = Block(728, 1024, 2, 2, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + self.act3 = nn.ReLU(inplace=True) + + self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(self.num_features) + self.act4 = nn.ReLU(inplace=True) + self.feature_info = [ + dict(num_chs=64, reduction=2, module='act2'), + dict(num_chs=128, reduction=4, module='block2.rep.0'), + dict(num_chs=256, reduction=8, module='block3.rep.0'), + dict(num_chs=728, reduction=16, module='block12.rep.0'), + dict(num_chs=2048, reduction=32, module='act4'), + ] + + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + # #------- init weights -------- + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.act2(x) + + x = self.block1(x) + x = self.block2(x) + x = self.block3(x) + x = self.block4(x) + x = self.block5(x) + x = self.block6(x) + x = self.block7(x) + x = self.block8(x) + x = self.block9(x) + x = self.block10(x) + x = self.block11(x) + x = self.block12(x) + + x = self.conv3(x) + x = self.bn3(x) + x = self.act3(x) + + x = self.conv4(x) + x = self.bn4(x) + x = self.act4(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + if self.drop_rate: + F.dropout(x, self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +def _xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + Xception, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), + **kwargs) + + +@register_model +def xception(pretrained=False, **kwargs): + return _xception('xception', pretrained=pretrained, **kwargs) diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7f5c05e06e0e1962074b0dffae2bdb731d2bb6 --- /dev/null +++ b/timm/models/xception_aligned.py @@ -0,0 +1,238 @@ +"""Pytorch impl of Aligned Xception 41, 65, 71 + +This is a correct, from scratch impl of Aligned Xception (Deeplab) models compatible with TF weights at +https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md + +Hacked together by / Copyright 2020 Ross Wightman +""" +from functools import partial + +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, ConvBnAct, create_conv2d +from .layers.helpers import to_3tuple +from .registry import register_model + +__all__ = ['XceptionAligned'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10), + 'crop_pct': 0.903, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + xception41=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'), + xception65=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'), + xception71=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'), +) + + +class SeparableConv2d(nn.Module): + def __init__( + self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='', + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SeparableConv2d, self).__init__() + self.kernel_size = kernel_size + self.dilation = dilation + + # depthwise convolution + self.conv_dw = create_conv2d( + inplanes, inplanes, kernel_size, stride=stride, + padding=padding, dilation=dilation, depthwise=True) + self.bn_dw = norm_layer(inplanes) + if act_layer is not None: + self.act_dw = act_layer(inplace=True) + else: + self.act_dw = None + + # pointwise convolution + self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1) + self.bn_pw = norm_layer(planes) + if act_layer is not None: + self.act_pw = act_layer(inplace=True) + else: + self.act_pw = None + + def forward(self, x): + x = self.conv_dw(x) + x = self.bn_dw(x) + if self.act_dw is not None: + x = self.act_dw(x) + x = self.conv_pw(x) + x = self.bn_pw(x) + if self.act_pw is not None: + x = self.act_pw(x) + return x + + +class XceptionModule(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, pad_type='', + start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None): + super(XceptionModule, self).__init__() + out_chs = to_3tuple(out_chs) + self.in_channels = in_chs + self.out_channels = out_chs[-1] + self.no_skip = no_skip + if not no_skip and (self.out_channels != self.in_channels or stride != 1): + self.shortcut = ConvBnAct( + in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None) + else: + self.shortcut = None + + separable_act_layer = None if start_with_relu else act_layer + self.stack = nn.Sequential() + for i in range(3): + if start_with_relu: + self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0)) + self.stack.add_module(f'conv{i + 1}', SeparableConv2d( + in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, + act_layer=separable_act_layer, norm_layer=norm_layer)) + in_chs = out_chs[i] + + def forward(self, x): + skip = x + x = self.stack(x) + if self.shortcut is not None: + skip = self.shortcut(skip) + if not self.no_skip: + x = x + skip + return x + + +class XceptionAligned(nn.Module): + """Modified Aligned Xception + """ + + def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'): + super(XceptionAligned, self).__init__() + self.num_classes = num_classes + self.drop_rate = drop_rate + assert output_stride in (8, 16, 32) + + layer_args = dict(act_layer=act_layer, norm_layer=norm_layer) + self.stem = nn.Sequential(*[ + ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), + ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args) + ]) + + curr_dilation = 1 + curr_stride = 2 + self.feature_info = [] + self.blocks = nn.Sequential() + for i, b in enumerate(block_cfg): + b['dilation'] = curr_dilation + if b['stride'] > 1: + self.feature_info += [dict( + num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')] + next_stride = curr_stride * b['stride'] + if next_stride > output_stride: + curr_dilation *= b['stride'] + b['stride'] = 1 + else: + curr_stride = next_stride + self.blocks.add_module(str(i), XceptionModule(**b, **layer_args)) + self.num_features = self.blocks[-1].out_channels + + self.feature_info += [dict( + num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] + + self.head = ClassifierHead( + in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _xception(variant, pretrained=False, **kwargs): + return build_model_with_cfg( + XceptionAligned, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), + **kwargs) + + +@register_model +def xception41(pretrained=False, **kwargs): + """ Modified Aligned Xception-41 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 8), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) + return _xception('xception41', pretrained=pretrained, **model_args) + + +@register_model +def xception65(pretrained=False, **kwargs): + """ Modified Aligned Xception-65 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) + return _xception('xception65', pretrained=pretrained, **model_args) + + +@register_model +def xception71(pretrained=False, **kwargs): + """ Modified Aligned Xception-71 + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=1), + dict(in_chs=256, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=1), + dict(in_chs=728, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 16), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False), + ] + model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) + return _xception('xception71', pretrained=pretrained, **model_args) diff --git a/timm/models/xcit.py b/timm/models/xcit.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad162ede825a947ec3e881b98b0515dcacad2a0 --- /dev/null +++ b/timm/models/xcit.py @@ -0,0 +1,814 @@ +""" Cross-Covariance Image Transformer (XCiT) in PyTorch + +Paper: + - https://arxiv.org/abs/2106.09681 + +Same as the official implementation, with some minor adaptations, original copyright below + - https://github.com/facebookresearch/xcit/blob/master/xcit.py + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +import math +from functools import partial + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .vision_transformer import _cfg, Mlp +from .registry import register_model +from .layers import DropPath, trunc_normal_, to_2tuple +from .cait import ClassAttn +from .fx_features import register_notrace_module + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj.0.0', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # Patch size 16 + 'xcit_nano_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth'), + 'xcit_nano_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pth'), + 'xcit_nano_12_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_tiny_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth'), + 'xcit_tiny_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pth'), + 'xcit_tiny_12_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_tiny_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pth'), + 'xcit_tiny_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pth'), + 'xcit_tiny_24_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_small_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth'), + 'xcit_small_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pth'), + 'xcit_small_12_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_small_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pth'), + 'xcit_small_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pth'), + 'xcit_small_24_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_medium_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pth'), + 'xcit_medium_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pth'), + 'xcit_medium_24_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_large_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pth'), + 'xcit_large_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pth'), + 'xcit_large_24_p16_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pth', input_size=(3, 384, 384)), + + # Patch size 8 + 'xcit_nano_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pth'), + 'xcit_nano_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pth'), + 'xcit_nano_12_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_tiny_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pth'), + 'xcit_tiny_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pth'), + 'xcit_tiny_12_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_tiny_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pth'), + 'xcit_tiny_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pth'), + 'xcit_tiny_24_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_small_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pth'), + 'xcit_small_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pth'), + 'xcit_small_12_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_small_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pth'), + 'xcit_small_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pth'), + 'xcit_small_24_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_medium_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pth'), + 'xcit_medium_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pth'), + 'xcit_medium_24_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pth', input_size=(3, 384, 384)), + 'xcit_large_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pth'), + 'xcit_large_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pth'), + 'xcit_large_24_p8_384_dist': _cfg( + url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth', input_size=(3, 384, 384)), +} + + +@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method +class PositionalEncodingFourier(nn.Module): + """ + Positional encoding relying on a fourier kernel matching the one used in the "Attention is all of Need" paper. + Based on the official XCiT code + - https://github.com/facebookresearch/xcit/blob/master/xcit.py + """ + + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + self.eps = 1e-6 + + def forward(self, B: int, H: int, W: int): + device = self.token_projection.weight.device + y_embed = torch.arange(1, H+1, dtype=torch.float32, device=device).unsqueeze(1).repeat(1, 1, W) + x_embed = torch.arange(1, W+1, dtype=torch.float32, device=device).repeat(1, H, 1) + y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3) + pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + return pos.repeat(B, 1, 1, 1) # (B, C, H, W) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution + batch norm""" + return torch.nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(out_planes) + ) + + +class ConvPatchEmbed(nn.Module): + """Image to Patch Embedding using multiple convolutional layers""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU): + super().__init__() + img_size = to_2tuple(img_size) + num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + if patch_size == 16: + self.proj = torch.nn.Sequential( + conv3x3(in_chans, embed_dim // 8, 2), + act_layer(), + conv3x3(embed_dim // 8, embed_dim // 4, 2), + act_layer(), + conv3x3(embed_dim // 4, embed_dim // 2, 2), + act_layer(), + conv3x3(embed_dim // 2, embed_dim, 2), + ) + elif patch_size == 8: + self.proj = torch.nn.Sequential( + conv3x3(in_chans, embed_dim // 4, 2), + act_layer(), + conv3x3(embed_dim // 4, embed_dim // 2, 2), + act_layer(), + conv3x3(embed_dim // 2, embed_dim, 2), + ) + else: + raise('For convolutional projection, patch size has to be in [8, 16]') + + def forward(self, x): + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + x = x.flatten(2).transpose(1, 2) # (B, N, C) + return x, (Hp, Wp) + + +class LPI(nn.Module): + """ + Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows to augment the + implicit communication performed by the block diagonal scatter attention. Implemented using 2 layers of separable + 3x3 convolutions with GeLU and BatchNorm2d + """ + + def __init__(self, in_features, out_features=None, act_layer=nn.GELU, kernel_size=3): + super().__init__() + out_features = out_features or in_features + + padding = kernel_size // 2 + + self.conv1 = torch.nn.Conv2d( + in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features) + self.act = act_layer() + self.bn = nn.BatchNorm2d(in_features) + self.conv2 = torch.nn.Conv2d( + in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features) + + def forward(self, x, H: int, W: int): + B, N, C = x.shape + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.conv1(x) + x = self.act(x) + x = self.bn(x) + x = self.conv2(x) + x = x.reshape(B, C, N).permute(0, 2, 1) + return x + + +class ClassAttentionBlock(nn.Module): + """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239""" + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False): + super().__init__() + self.norm1 = norm_layer(dim) + + self.attn = ClassAttn( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + + if eta is not None: # LayerScale Initialization (no layerscale when None) + self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + else: + self.gamma1, self.gamma2 = 1.0, 1.0 + + # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 + self.tokens_norm = tokens_norm + + def forward(self, x): + x_norm1 = self.norm1(x) + x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1) + x = x + self.drop_path(self.gamma1 * x_attn) + if self.tokens_norm: + x = self.norm2(x) + else: + x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1) + x_res = x + cls_token = x[:, 0:1] + cls_token = self.gamma2 * self.mlp(cls_token) + x = torch.cat([cls_token, x[:, 1:]], dim=1) + x = x_res + self.drop_path(x) + return x + + +class XCA(nn.Module): + """ Cross-Covariance Attention (XCA) + Operation where the channels are updated using a weighted sum. The weights are obtained from the (softmax + normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h) + """ + + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # Paper section 3.2 l2-Normalization and temperature scaling + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + # (B, H, C', N), permute -> (B, N, H, C') + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature'} + + +class XCABlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm3 = norm_layer(dim) + self.local_mp = LPI(in_features=dim, act_layer=act_layer) + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + + self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + self.gamma3 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True) + + def forward(self, x, H: int, W: int): + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) + # NOTE official code has 3 then 2, so keeping it the same to be consistent with loaded weights + # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 + x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + return x + + +class XCiT(nn.Module): + """ + Based on timm and DeiT code bases + https://github.com/rwightman/pytorch-image-models/tree/master/timm + https://github.com/facebookresearch/deit/ + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False): + """ + Args: + img_size (int, tuple): input image size + patch_size (int): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate (constant across all layers) + norm_layer: (nn.Module): normalization layer + cls_attn_layers: (int) Depth of Class attention layers + use_pos_embed: (bool) whether to use positional encoding + eta: (float) layerscale initialization value + tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA + + Notes: + - Although `layer_norm` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch + interaction (class LPI) and the patch embedding (class ConvPatchEmbed) + """ + super().__init__() + img_size = to_2tuple(img_size) + assert (img_size[0] % patch_size == 0) and (img_size[0] % patch_size == 0), \ + '`patch_size` should divide image dimensions evenly' + + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = ConvPatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.use_pos_embed = use_pos_embed + if use_pos_embed: + self.pos_embed = PositionalEncodingFourier(dim=embed_dim) + self.pos_drop = nn.Dropout(p=drop_rate) + + self.blocks = nn.ModuleList([ + XCABlock( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta) + for _ in range(depth)]) + + self.cls_attn_blocks = nn.ModuleList([ + ClassAttentionBlock( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm) + for _ in range(cls_attn_layers)]) + + # Classifier head + self.norm = norm_layer(embed_dim) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # Init weights + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + B = x.shape[0] + # x is (B, N, C). (Hp, Hw) is (height in units of patches, width in units of patches) + x, (Hp, Wp) = self.patch_embed(x) + + if self.use_pos_embed: + # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C) + pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x, Hp, Wp) + + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + for blk in self.cls_attn_blocks: + x = blk(x) + + x = self.norm(x)[:, 0] + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict: + state_dict = state_dict['model'] + # For consistency with timm's transformer models while being compatible with official weights source we rename + # pos_embeder to pos_embed. Also account for use_pos_embed == False + use_pos_embed = getattr(model, 'pos_embed', None) is not None + pos_embed_keys = [k for k in state_dict if k.startswith('pos_embed')] + for k in pos_embed_keys: + if use_pos_embed: + state_dict[k.replace('pos_embeder.', 'pos_embed.')] = state_dict.pop(k) + else: + del state_dict[k] + # timm's implementation of class attention in CaiT is slightly more efficient as it does not compute query vectors + # for all tokens, just the class token. To use official weights source we must split qkv into q, k, v + if 'cls_attn_blocks.0.attn.qkv.weight' in state_dict and 'cls_attn_blocks.0.attn.q.weight' in model.state_dict(): + num_ca_blocks = len(model.cls_attn_blocks) + for i in range(num_ca_blocks): + qkv_weight = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.weight') + qkv_weight = qkv_weight.reshape(3, -1, qkv_weight.shape[-1]) + for j, subscript in enumerate('qkv'): + state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.weight'] = qkv_weight[j] + qkv_bias = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.bias', None) + if qkv_bias is not None: + qkv_bias = qkv_bias.reshape(3, -1) + for j, subscript in enumerate('qkv'): + state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.bias'] = qkv_bias[j] + return state_dict + + +def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs): + default_cfg = default_cfg or default_cfgs[variant] + model = build_model_with_cfg( + XCiT, variant, pretrained, default_cfg=default_cfg, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) + return model + + +@register_model +def xcit_nano_12_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_nano_12_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_nano_12_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384, **kwargs) + model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p16_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p16_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +# Patch size 8x8 models +@register_model +def xcit_nano_12_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_nano_12_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_nano_12_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs) + model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_12_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_12_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_tiny_24_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_small_24_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_medium_24_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p8_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p8_224_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def xcit_large_24_p8_384_dist(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs) + model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e2bf46aa4110dd67ac5bc5b6bf2f452a3afb6eee --- /dev/null +++ b/timm/optim/__init__.py @@ -0,0 +1,13 @@ +from .adamp import AdamP +from .adamw import AdamW +from .adafactor import Adafactor +from .adahessian import Adahessian +from .lookahead import Lookahead +from .nadam import Nadam +from .novograd import NovoGrad +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP +from .adabelief import AdaBelief +from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs diff --git a/timm/optim/__pycache__/__init__.cpython-38.pyc b/timm/optim/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aefb1ffa6826eb0bc294ff7ac1065e18837ca371 Binary files /dev/null and b/timm/optim/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/adabelief.cpython-38.pyc b/timm/optim/__pycache__/adabelief.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62abdc4467debda90767a8f52ef51a02e6618a44 Binary files /dev/null and b/timm/optim/__pycache__/adabelief.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/adafactor.cpython-38.pyc b/timm/optim/__pycache__/adafactor.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7bd107fe7053bb686562cf9a1b14b599207703f Binary files /dev/null and b/timm/optim/__pycache__/adafactor.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/adahessian.cpython-38.pyc b/timm/optim/__pycache__/adahessian.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4ac64c43049805b9d3f0e6e4b71d2b49e205d68 Binary files /dev/null and b/timm/optim/__pycache__/adahessian.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/adamp.cpython-38.pyc b/timm/optim/__pycache__/adamp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de20264ea836ce8813aabb928bc3e978c9a0e0fb Binary files /dev/null and b/timm/optim/__pycache__/adamp.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/adamw.cpython-38.pyc b/timm/optim/__pycache__/adamw.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fc84d77fb03940970e463e865085faa1b82d8cd Binary files /dev/null and b/timm/optim/__pycache__/adamw.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/lamb.cpython-38.pyc b/timm/optim/__pycache__/lamb.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cd853ef5d8a5ff5100110fe869a6a78716cf784 Binary files /dev/null and b/timm/optim/__pycache__/lamb.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/lars.cpython-38.pyc b/timm/optim/__pycache__/lars.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15ad6622fbcfba0910df3d8f8ebf5f3762435dcf Binary files /dev/null and b/timm/optim/__pycache__/lars.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/lookahead.cpython-38.pyc b/timm/optim/__pycache__/lookahead.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98c1e210aaab4df3c4f69472f70ec7eb2cb3aaf6 Binary files /dev/null and b/timm/optim/__pycache__/lookahead.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/madgrad.cpython-38.pyc b/timm/optim/__pycache__/madgrad.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5497d3110788a9c4364d8dc44c6e6af6fe0ab326 Binary files /dev/null and b/timm/optim/__pycache__/madgrad.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/nadam.cpython-38.pyc b/timm/optim/__pycache__/nadam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc50778ad52a2ead69d4f7a9d079ea5341bf33ea Binary files /dev/null and b/timm/optim/__pycache__/nadam.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/novograd.cpython-38.pyc b/timm/optim/__pycache__/novograd.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41e0f344980caf9ea8b8d714954877eda1347573 Binary files /dev/null and b/timm/optim/__pycache__/novograd.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/nvnovograd.cpython-38.pyc b/timm/optim/__pycache__/nvnovograd.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..742f8fe5a7391374af14f8192232f71fb26f8827 Binary files /dev/null and b/timm/optim/__pycache__/nvnovograd.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/optim_factory.cpython-38.pyc b/timm/optim/__pycache__/optim_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51d44a2d79fbd20680efd4a8fa07856287361f64 Binary files /dev/null and b/timm/optim/__pycache__/optim_factory.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/radam.cpython-38.pyc b/timm/optim/__pycache__/radam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bffd4a62ec2eac80c9a006500f5a3606f46dfb03 Binary files /dev/null and b/timm/optim/__pycache__/radam.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/rmsprop_tf.cpython-38.pyc b/timm/optim/__pycache__/rmsprop_tf.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a87a15792f2885aaad55fc7e32805263190e7885 Binary files /dev/null and b/timm/optim/__pycache__/rmsprop_tf.cpython-38.pyc differ diff --git a/timm/optim/__pycache__/sgdp.cpython-38.pyc b/timm/optim/__pycache__/sgdp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d60ebff45015de7d7be2dfa721dd1174a6f8787a Binary files /dev/null and b/timm/optim/__pycache__/sgdp.cpython-38.pyc differ diff --git a/timm/optim/adabelief.py b/timm/optim/adabelief.py new file mode 100644 index 0000000000000000000000000000000000000000..a26d7b27ac85ce65a02bc2e938058b685d914a65 --- /dev/null +++ b/timm/optim/adabelief.py @@ -0,0 +1,205 @@ +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdaBelief(Optimizer): + r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-16) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + weight_decouple (boolean, optional): ( default: True) If set as True, then + the optimizer uses decoupled weight decay as in AdamW + fixed_decay (boolean, optional): (default: False) This is used when weight_decouple + is set as True. + When fixed_decay == True, the weight decay is performed as + $W_{new} = W_{old} - W_{old} \times decay$. + When fixed_decay == False, the weight decay is performed as + $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the + weight decay ratio decreases with learning rate (lr). + rectify (boolean, optional): (default: True) If set as True, then perform the rectified + update similar to RAdam + degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update + when variance of gradient is high + reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020 + + For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer' + For example train/args for EfficientNet see these gists + - link to train_scipt: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037 + - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, + weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True, + degenerated_to_sgd=True): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): + param['buffer'] = [[None, None, None] for _ in range(10)] + + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, buffer=[[None, None, None] for _ in range(10)]) + super(AdaBelief, self).__init__(params, defaults) + + self.degenerated_to_sgd = degenerated_to_sgd + self.weight_decouple = weight_decouple + self.rectify = rectify + self.fixed_decay = fixed_decay + + def __setstate__(self, state): + super(AdaBelief, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def reset(self): + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + amsgrad = group['amsgrad'] + + # State initialization + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_var'] = torch.zeros_like(p.data) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # cast data type + half_precision = False + if p.data.dtype == torch.float16: + half_precision = True + p.data = p.data.float() + p.grad = p.grad.float() + + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + 'AdaBelief does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + beta1, beta2 = group['betas'] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_var'] = torch.zeros_like(p.data) + + # perform weight decay, check if decoupled weight decay + if self.weight_decouple: + if not self.fixed_decay: + p.data.mul_(1.0 - group['lr'] * group['weight_decay']) + else: + p.data.mul_(1.0 - group['weight_decay']) + else: + if group['weight_decay'] != 0: + grad.add_(p.data, alpha=group['weight_decay']) + + # get current state variable + exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Update first and second moment running average + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + grad_residual = grad - exp_avg + exp_avg_var.mul_(beta2).addcmul_( grad_residual, grad_residual, value=1 - beta2) + + if amsgrad: + max_exp_avg_var = state['max_exp_avg_var'] + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var) + + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + # update + if not self.rectify: + # Default update + step_size = group['lr'] / bias_correction1 + p.data.addcdiv_( exp_avg, denom, value=-step_size) + + else: # Rectified update, forked from RAdam + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + if N_sma >= 5: + denom = exp_avg_var.sqrt().add_(group['eps']) + p.data.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + elif step_size > 0: + p.data.add_( exp_avg, alpha=-step_size * group['lr']) + + if half_precision: + p.data = p.data.half() + p.grad = p.grad.half() + + return loss diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py new file mode 100644 index 0000000000000000000000000000000000000000..088ce3acd82e2be1b393afafa05f48435e538a1a --- /dev/null +++ b/timm/optim/adafactor.py @@ -0,0 +1,174 @@ +""" Adafactor Optimizer + +Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py + +Original header/copyright below. + +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +import math + + +class Adafactor(torch.optim.Optimizer): + """Implements Adafactor algorithm. + This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost` + (see https://arxiv.org/abs/1804.04235) + + Note that this optimizer internally adjusts the learning rate depending on the + *scale_parameter*, *relative_step* and *warmup_init* options. + + To use a manual (external) learning rate schedule you should set `scale_parameter=False` and + `relative_step=False`. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): external learning rate (default: None) + eps (tuple[float, float]): regularization constants for square gradient + and parameter scale respectively (default: (1e-30, 1e-3)) + clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) + decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) + beta1 (float): coefficient used for computing running averages of gradient (default: None) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) + relative_step (bool): if True, time-dependent learning rate is computed + instead of external learning rate (default: True) + warmup_init (bool): time-dependent learning rate computation depends on + whether warm-up initialization is being used (default: False) + """ + + def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0, + decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False): + relative_step = lr is None + if warmup_init and not relative_step: + raise ValueError('warmup_init requires relative_step=True') + + beta1 = None if betas is None else betas[0] # make it compat with standard betas arg + defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate, + beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter, + relative_step=relative_step, warmup_init=warmup_init) + super(Adafactor, self).__init__(params, defaults) + + @staticmethod + def _get_lr(param_group, param_state): + if param_group['relative_step']: + min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2 + lr_t = min(min_step, 1.0 / math.sqrt(param_state['step'])) + param_scale = 1.0 + if param_group['scale_parameter']: + param_scale = max(param_group['eps_scale'], param_state['RMS']) + param_group['lr'] = lr_t * param_scale + return param_group['lr'] + + @staticmethod + def _get_options(param_group, param_shape): + factored = len(param_shape) >= 2 + use_first_moment = param_group['beta1'] is not None + return factored, use_first_moment + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel() ** 0.5) + + def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): + r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1) + c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt() + return torch.mul(r_factor, c_factor) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError('Adafactor does not support sparse gradients.') + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = self._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state['step'] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(grad) + if factored: + state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad) + state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['RMS'] = 0 + else: + if use_first_moment: + state['exp_avg'] = state['exp_avg'].to(grad) + if factored: + state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad) + state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad) + else: + state['exp_avg_sq'] = state['exp_avg_sq'].to(grad) + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state['step'] += 1 + state['RMS'] = self._rms(p_data_fp32) + lr_t = self._get_lr(group, state) + + beta2t = 1.0 - math.pow(state['step'], group['decay_rate']) + update = grad ** 2 + group['eps'] + if factored: + exp_avg_sq_row = state['exp_avg_sq_row'] + exp_avg_sq_col = state['exp_avg_sq_col'] + + exp_avg_sq_row.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-1)) + exp_avg_sq_col.mul_(beta2t).add_(1.0 - beta2t, update.mean(dim=-2)) + #exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t) # pytorch 1.6+ + #exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t) + + # Approximation of exponential moving average of square of gradient + update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2t).add_(1.0 - beta2t, update) + #exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t) # pytorch 1.6+ + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0)) + update.mul_(lr_t) + + if use_first_moment: + exp_avg = state['exp_avg'] + exp_avg.mul_(group["beta1"]).add_(1 - group["beta1"], update) + #exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) # pytorch 1.6+ + update = exp_avg + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group["weight_decay"] * lr_t, p_data_fp32) + #p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr_t) # pytorch 1.6+ + + p_data_fp32.add_(-update) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss \ No newline at end of file diff --git a/timm/optim/adahessian.py b/timm/optim/adahessian.py new file mode 100644 index 0000000000000000000000000000000000000000..985c67ca686a65f61f5c5b1a7db3e5bba815a19b --- /dev/null +++ b/timm/optim/adahessian.py @@ -0,0 +1,156 @@ +""" AdaHessian Optimizer + +Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py +Originally licensed MIT, Copyright 2020, David Samuel +""" +import torch + + +class Adahessian(torch.optim.Optimizer): + """ + Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second OrderOptimizer for Machine Learning" + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + lr (float, optional): learning rate (default: 0.1) + betas ((float, float), optional): coefficients used for computing running averages of gradient and the + squared hessian trace (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0) + hessian_power (float, optional): exponent of the hessian trace (default: 1.0) + update_each (int, optional): compute the hessian trace approximation only after *this* number of steps + (to save time) (default: 1) + n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1) + """ + + def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, + hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= hessian_power <= 1.0: + raise ValueError(f"Invalid Hessian power value: {hessian_power}") + + self.n_samples = n_samples + self.update_each = update_each + self.avg_conv_kernel = avg_conv_kernel + + # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training + self.seed = 2147483647 + self.generator = torch.Generator().manual_seed(self.seed) + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power) + super(Adahessian, self).__init__(params, defaults) + + for p in self.get_params(): + p.hess = 0.0 + self.state[p]["hessian step"] = 0 + + @property + def is_second_order(self): + return True + + def get_params(self): + """ + Gets all parameters in all param_groups with gradients + """ + + return (p for group in self.param_groups for p in group['params'] if p.requires_grad) + + def zero_hessian(self): + """ + Zeros out the accumalated hessian traces. + """ + + for p in self.get_params(): + if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0: + p.hess.zero_() + + @torch.no_grad() + def set_hessian(self): + """ + Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter. + """ + + params = [] + for p in filter(lambda p: p.grad is not None, self.get_params()): + if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step + params.append(p) + self.state[p]["hessian step"] += 1 + + if len(params) == 0: + return + + if self.generator.device != params[0].device: # hackish way of casting the generator to the right device + self.generator = torch.Generator(params[0].device).manual_seed(self.seed) + + grads = [p.grad for p in params] + + for i in range(self.n_samples): + # Rademacher distribution {-1.0, 1.0} + zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params] + h_zs = torch.autograd.grad( + grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1) + for h_z, z, p in zip(h_zs, zs, params): + p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z) + + @torch.no_grad() + def step(self, closure=None): + """ + Performs a single optimization step. + Arguments: + closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None) + """ + + loss = None + if closure is not None: + loss = closure() + + self.zero_hessian() + self.set_hessian() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None or p.hess is None: + continue + + if self.avg_conv_kernel and p.dim() == 4: + p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone() + + # Perform correct stepweight decay as in AdamW + p.mul_(1 - group['lr'] * group['weight_decay']) + + state = self.state[p] + + # State initialization + if len(state) == 1: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of Hessian diagonal square values + state['exp_hessian_diag_sq'] = torch.zeros_like(p) + + exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq'] + beta1, beta2 = group['betas'] + state['step'] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1) + exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + k = group['hessian_power'] + denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps']) + + # make update + step_size = group['lr'] / bias_correction1 + p.addcdiv_(exp_avg, denom, value=-step_size) + + return loss diff --git a/timm/optim/adamp.py b/timm/optim/adamp.py new file mode 100644 index 0000000000000000000000000000000000000000..468c3e865e0ceb6fb2bf22f9388237a783314f07 --- /dev/null +++ b/timm/optim/adamp.py @@ -0,0 +1,107 @@ +""" +AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. +MIT license +""" + +import torch +import torch.nn as nn +from torch.optim.optimizer import Optimizer, required +import math + +class AdamP(Optimizer): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, + delta=delta, wd_ratio=wd_ratio, nesterov=nesterov) + super(AdamP, self).__init__(params, defaults) + + def _channel_view(self, x): + return x.view(x.size(0), -1) + + def _layer_view(self, x): + return x.view(1, -1) + + def _cosine_similarity(self, x, y, eps, view_func): + x = view_func(x) + y = view_func(y) + + x_norm = x.norm(dim=1).add_(eps) + y_norm = y.norm(dim=1).add_(eps) + dot = (x * y).sum(dim=1) + + return dot.abs() / x_norm / y_norm + + def _projection(self, p, grad, perturb, delta, wd_ratio, eps): + wd = 1 + expand_size = [-1] + [1] * (len(p.shape) - 1) + for view_func in [self._channel_view, self._layer_view]: + + cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) + + if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): + p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) + perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) + wd = wd_ratio + + return perturb, wd + + return perturb, wd + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + grad = p.grad.data + beta1, beta2 = group['betas'] + nesterov = group['nesterov'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p.data) + state['exp_avg_sq'] = torch.zeros_like(p.data) + + # Adam + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + step_size = group['lr'] / bias_correction1 + + if nesterov: + perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom + else: + perturb = exp_avg / denom + + # Projection + wd_ratio = 1 + if len(p.shape) > 1: + perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if group['weight_decay'] > 0: + p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio) + + # Step + p.data.add_(-step_size, perturb) + + return loss diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..66f9a959de586356a29ace2f9c57d3fee8d1057a --- /dev/null +++ b/timm/optim/adamw.py @@ -0,0 +1,117 @@ +""" AdamW Optimizer +Impl copied from PyTorch master +""" +import math +import torch +from torch.optim.optimizer import Optimizer + + +class AdamW(Optimizer): + r"""Implements AdamW algorithm. + + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + step_size = group['lr'] / bias_correction1 + + p.data.addcdiv_(-step_size, exp_avg, denom) + + return loss diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py new file mode 100644 index 0000000000000000000000000000000000000000..12c7c49b8a01ef793c97654ac938259ca6508449 --- /dev/null +++ b/timm/optim/lamb.py @@ -0,0 +1,192 @@ +""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb + +This optimizer code was adapted from the following (starting with latest) +* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py +* https://github.com/cybertronai/pytorch-lamb + +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. + +In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. + +Original copyrights for above sources are below. + +Modifications Copyright 2021 Ross Wightman +""" +# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. + +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import math + +import torch +from torch.optim import Optimizer + + +class Lamb(Optimizer): + """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB + reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py + + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6, + weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False): + defaults = dict( + lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay, + grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, + trust_clip=trust_clip, always_adapt=always_adapt) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 + max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) + clip_global_grad_norm = torch.where( + global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor) + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1 ** group['step'] + bias_correction2 = 1 - beta2 ** group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t + + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are + # excluded from weight decay, unless always_adapt == True, then always enabled. + w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group['trust_clip']: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group['lr']) + + return loss diff --git a/timm/optim/lars.py b/timm/optim/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..98198e675c91e867d36865c67d360f65698218e7 --- /dev/null +++ b/timm/optim/lars.py @@ -0,0 +1,135 @@ +""" PyTorch LARS / LARC Optimizer + +An implementation of LARS (SGD) + LARC in PyTorch + +Based on: + * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + +Additional cleanup and modifications to properly support PyTorch XLA. + +Copyright 2021 Ross Wightman +""" +import torch +from torch.optim.optimizer import Optimizer + + +class Lars(Optimizer): + """ LARS for PyTorch + + Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf + + Args: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. + lr (float, optional): learning rate (default: 1.0). + momentum (float, optional): momentum factor (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001) + eps (float): eps for division denominator (default: 1e-8) + trust_clip (bool): enable LARC trust ratio clipping (default: False) + always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False) + """ + + def __init__( + self, + params, + lr=1.0, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + trust_coeff=0.001, + eps=1e-8, + trust_clip=False, + always_adapt=False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + trust_coeff=trust_coeff, + eps=eps, + trust_clip=trust_clip, + always_adapt=always_adapt, + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + trust_coeff = group['trust_coeff'] + eps = group['eps'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + + # apply LARS LR adaptation, LARC clipping, weight decay + # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py + if weight_decay != 0 or group['always_adapt']: + w_norm = p.norm(2.0) + g_norm = grad.norm(2.0) + trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps) + # FIXME nested where required since logical and/or not working in PT XLA + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, trust_ratio, one_tensor), + one_tensor, + ) + if group['trust_clip']: + trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor) + grad.add(p, alpha=weight_decay) + grad.mul_(trust_ratio) + + # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100 + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone(grad).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(grad, alpha=1. - dampening) + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf + + p.add_(grad, alpha=-group['lr']) + + return loss \ No newline at end of file diff --git a/timm/optim/lookahead.py b/timm/optim/lookahead.py new file mode 100644 index 0000000000000000000000000000000000000000..6b5b7f38ec8cb6594e3986b66223fa2881daeca3 --- /dev/null +++ b/timm/optim/lookahead.py @@ -0,0 +1,92 @@ +""" Lookahead Optimizer Wrapper. +Implementation modified from: https://github.com/alphadl/lookahead.pytorch +Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch.optim.optimizer import Optimizer +from collections import defaultdict + + +class Lookahead(Optimizer): + def __init__(self, base_optimizer, alpha=0.5, k=6): + if not 0.0 <= alpha <= 1.0: + raise ValueError(f'Invalid slow update rate: {alpha}') + if not 1 <= k: + raise ValueError(f'Invalid lookahead steps: {k}') + defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) + self.base_optimizer = base_optimizer + self.param_groups = self.base_optimizer.param_groups + self.defaults = base_optimizer.defaults + self.defaults.update(defaults) + self.state = defaultdict(dict) + # manually add our defaults to the param groups + for name, default in defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) + + def update_slow(self, group): + for fast_p in group["params"]: + if fast_p.grad is None: + continue + param_state = self.state[fast_p] + if 'slow_buffer' not in param_state: + param_state['slow_buffer'] = torch.empty_like(fast_p.data) + param_state['slow_buffer'].copy_(fast_p.data) + slow = param_state['slow_buffer'] + slow.add_(group['lookahead_alpha'], fast_p.data - slow) + fast_p.data.copy_(slow) + + def sync_lookahead(self): + for group in self.param_groups: + self.update_slow(group) + + def step(self, closure=None): + #assert id(self.param_groups) == id(self.base_optimizer.param_groups) + loss = self.base_optimizer.step(closure) + for group in self.param_groups: + group['lookahead_step'] += 1 + if group['lookahead_step'] % group['lookahead_k'] == 0: + self.update_slow(group) + return loss + + def state_dict(self): + fast_state_dict = self.base_optimizer.state_dict() + slow_state = { + (id(k) if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + fast_state = fast_state_dict['state'] + param_groups = fast_state_dict['param_groups'] + return { + 'state': fast_state, + 'slow_state': slow_state, + 'param_groups': param_groups, + } + + def load_state_dict(self, state_dict): + fast_state_dict = { + 'state': state_dict['state'], + 'param_groups': state_dict['param_groups'], + } + self.base_optimizer.load_state_dict(fast_state_dict) + + # We want to restore the slow state, but share param_groups reference + # with base_optimizer. This is a bit redundant but least code + slow_state_new = False + if 'slow_state' not in state_dict: + print('Loading state_dict from optimizer without Lookahead applied.') + state_dict['slow_state'] = defaultdict(dict) + slow_state_new = True + slow_state_dict = { + 'state': state_dict['slow_state'], + 'param_groups': state_dict['param_groups'], # this is pointless but saves code + } + super(Lookahead, self).load_state_dict(slow_state_dict) + self.param_groups = self.base_optimizer.param_groups # make both ref same container + if slow_state_new: + # reapply defaults to catch missing lookahead specific ones + for name, default in self.defaults.items(): + for group in self.param_groups: + group.setdefault(name, default) diff --git a/timm/optim/madgrad.py b/timm/optim/madgrad.py new file mode 100644 index 0000000000000000000000000000000000000000..a76713bf27ed1daf0ce598ac5f25c6238c7fdb57 --- /dev/null +++ b/timm/optim/madgrad.py @@ -0,0 +1,184 @@ +""" PyTorch MADGRAD optimizer + +MADGRAD: https://arxiv.org/abs/2101.11075 + +Code from: https://github.com/facebookresearch/madgrad +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import TYPE_CHECKING, Any, Callable, Optional + +import torch +import torch.optim + +if TYPE_CHECKING: + from torch.optim.optimizer import _params_t +else: + _params_t = Any + + +class MADGRAD(torch.optim.Optimizer): + """ + MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic + Optimization. + + .. _MADGRAD: https://arxiv.org/abs/2101.11075 + + MADGRAD is a general purpose optimizer that can be used in place of SGD or + Adam may converge faster and generalize better. Currently GPU-only. + Typically, the same learning rate schedule that is used for SGD or Adam may + be used. The overall learning rate is not comparable to either method and + should be determined by a hyper-parameter sweep. + + MADGRAD requires less weight decay than other methods, often as little as + zero. Momentum values used for SGD or Adam's beta1 should work here also. + + On sparse problems both weight_decay and momentum should be set to 0. + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate (default: 1e-2). + momentum (float): + Momentum value in the range [0,1) (default: 0.9). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0). + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6). + """ + + def __init__( + self, + params: _params_t, + lr: float = 1e-2, + momentum: float = 0.9, + weight_decay: float = 0, + eps: float = 1e-6, + decoupled_decay: bool = False, + ): + if momentum < 0 or momentum >= 1: + raise ValueError(f"Momentum {momentum} must be in the range [0,1]") + if lr <= 0: + raise ValueError(f"Learning rate {lr} must be positive") + if weight_decay < 0: + raise ValueError(f"Weight decay {weight_decay} must be non-negative") + if eps < 0: + raise ValueError(f"Eps must be non-negative") + + defaults = dict( + lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay) + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self) -> bool: + return False + + @property + def supports_flat_params(self) -> bool: + return True + + @torch.no_grad() + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + eps = group['eps'] + lr = group['lr'] + eps + weight_decay = group['weight_decay'] + momentum = group['momentum'] + ck = 1 - momentum + + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad + if momentum != 0.0 and grad.is_sparse: + raise RuntimeError("momentum != 0 is not compatible with sparse gradients") + + state = self.state[p] + if len(state) == 0: + state['step'] = 0 + state['grad_sum_sq'] = torch.zeros_like(p) + state['s'] = torch.zeros_like(p) + if momentum != 0: + state['x0'] = torch.clone(p).detach() + + state['step'] += 1 + grad_sum_sq = state['grad_sum_sq'] + s = state['s'] + lamb = lr * math.sqrt(state['step']) + + # Apply weight decay + if weight_decay != 0: + if group['decoupled_decay']: + p.mul_(1.0 - group['lr'] * weight_decay) + else: + if grad.is_sparse: + raise RuntimeError("weight_decay option is not compatible with sparse gradients") + grad.add_(p, alpha=weight_decay) + + if grad.is_sparse: + grad = grad.coalesce() + grad_val = grad._values() + + p_masked = p.sparse_mask(grad) + grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad) + s_masked = s.sparse_mask(grad) + + # Compute x_0 from other known quantities + rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps) + x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1) + + # Dense + sparse op + grad_sq = grad * grad + grad_sum_sq.add_(grad_sq, alpha=lamb) + grad_sum_sq_masked.add_(grad_sq, alpha=lamb) + + rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps) + + s.add_(grad, alpha=lamb) + s_masked._values().add_(grad_val, alpha=lamb) + + # update masked copy of p + p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1) + # Copy updated masked p to dense p using an add operation + p_masked._values().add_(p_kp1_masked_vals, alpha=-1) + p.add_(p_masked, alpha=-1) + else: + if momentum == 0: + # Compute x_0 from other known quantities + rms = grad_sum_sq.pow(1 / 3).add_(eps) + x0 = p.addcdiv(s, rms, value=1) + else: + x0 = state['x0'] + + # Accumulate second moments + grad_sum_sq.addcmul_(grad, grad, value=lamb) + rms = grad_sum_sq.pow(1 / 3).add_(eps) + + # Update s + s.add_(grad, alpha=lamb) + + # Step + if momentum == 0: + p.copy_(x0.addcdiv(s, rms, value=-1)) + else: + z = x0.addcdiv(s, rms, value=-1) + + # p is a moving average of z + p.mul_(1 - ck).add_(z, alpha=ck) + + return loss diff --git a/timm/optim/nadam.py b/timm/optim/nadam.py new file mode 100644 index 0000000000000000000000000000000000000000..d994d1b83485c9b068de73f5f3cf2efb1e5bec39 --- /dev/null +++ b/timm/optim/nadam.py @@ -0,0 +1,88 @@ +import torch +from torch.optim import Optimizer + + +class Nadam(Optimizer): + """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). + + It has been proposed in `Incorporating Nesterov Momentum into Adam`__. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 2e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + schedule_decay (float, optional): momentum schedule decay (default: 4e-3) + + __ http://cs229.stanford.edu/proj2015/054_report.pdf + __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf + + Originally taken from: https://github.com/pytorch/pytorch/pull/1408 + NOTE: Has potential issues but does work well on some problems. + """ + + def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, schedule_decay=4e-3): + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, schedule_decay=schedule_decay) + super(Nadam, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['m_schedule'] = 1. + state['exp_avg'] = grad.new().resize_as_(grad).zero_() + state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() + + # Warming momentum schedule + m_schedule = state['m_schedule'] + schedule_decay = group['schedule_decay'] + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + eps = group['eps'] + state['step'] += 1 + t = state['step'] + + if group['weight_decay'] != 0: + grad = grad.add(group['weight_decay'], p.data) + + momentum_cache_t = beta1 * \ + (1. - 0.5 * (0.96 ** (t * schedule_decay))) + momentum_cache_t_1 = beta1 * \ + (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) + m_schedule_new = m_schedule * momentum_cache_t + m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 + state['m_schedule'] = m_schedule_new + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1. - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) + exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) + denom = exp_avg_sq_prime.sqrt_().add_(eps) + + p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) + p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom) + + return loss diff --git a/timm/optim/novograd.py b/timm/optim/novograd.py new file mode 100644 index 0000000000000000000000000000000000000000..4137c6aa9406360d29f5f7234ebbdef294404d0e --- /dev/null +++ b/timm/optim/novograd.py @@ -0,0 +1,77 @@ +"""NovoGrad Optimizer. +Original impl by Masashi Kimura (Convergence Lab): https://github.com/convergence-lab/novograd +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NovoGrad(Optimizer): + def __init__(self, params, grad_averaging=False, lr=0.1, betas=(0.95, 0.98), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super(NovoGrad, self).__init__(params, defaults) + self._lr = lr + self._beta1 = betas[0] + self._beta2 = betas[1] + self._eps = eps + self._wd = weight_decay + self._grad_averaging = grad_averaging + + self._momentum_initialized = False + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + if not self._momentum_initialized: + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('NovoGrad does not support sparse gradients') + + v = torch.norm(grad)**2 + m = grad/(torch.sqrt(v) + self._eps) + self._wd * p.data + state['step'] = 0 + state['v'] = v + state['m'] = m + state['grad_ema'] = None + self._momentum_initialized = True + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + state = self.state[p] + state['step'] += 1 + + step, v, m = state['step'], state['v'], state['m'] + grad_ema = state['grad_ema'] + + grad = p.grad.data + g2 = torch.norm(grad)**2 + grad_ema = g2 if grad_ema is None else grad_ema * \ + self._beta2 + g2 * (1. - self._beta2) + grad *= 1.0 / (torch.sqrt(grad_ema) + self._eps) + + if self._grad_averaging: + grad *= (1. - self._beta1) + + g2 = torch.norm(grad)**2 + v = self._beta2*v + (1. - self._beta2)*g2 + m = self._beta1*m + (grad / (torch.sqrt(v) + self._eps) + self._wd * p.data) + bias_correction1 = 1 - self._beta1 ** step + bias_correction2 = 1 - self._beta2 ** step + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + state['v'], state['m'] = v, m + state['grad_ema'] = grad_ema + p.data.add_(-step_size, m) + return loss diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py new file mode 100644 index 0000000000000000000000000000000000000000..323312d2fc36d028124f7a7ec604d248e71503cd --- /dev/null +++ b/timm/optim/nvnovograd.py @@ -0,0 +1,118 @@ +""" Nvidia NovoGrad Optimizer. +Original impl by Nvidia from Jasper example: + - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper +Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks` + - https://arxiv.org/abs/1905.11286 +""" + +import torch +from torch.optim.optimizer import Optimizer +import math + + +class NvNovoGrad(Optimizer): + """ + Implements Novograd algorithm. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.95, 0.98)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging: gradient averaging + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + """ + + def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8, + weight_decay=0, grad_averaging=False, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, + grad_averaging=grad_averaging, + amsgrad=amsgrad) + + super(NvNovoGrad, self).__init__(params, defaults) + + def __setstate__(self, state): + super(NvNovoGrad, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Sparse gradients are not supported.') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + norm = torch.sum(torch.pow(grad, 2)) + + if exp_avg_sq == 0: + exp_avg_sq.copy_(norm) + else: + exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) + + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + grad.div_(denom) + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + if group['grad_averaging']: + grad.mul_(1 - beta1) + exp_avg.mul_(beta1).add_(grad) + + p.data.add_(-group['lr'], exp_avg) + + return loss diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..2017d21f31e7223406b59e4193ce8696514a5383 --- /dev/null +++ b/timm/optim/optim_factory.py @@ -0,0 +1,174 @@ +""" Optimizer Factory w/ Custom Weight Decay +Hacked together by / Copyright 2020 Ross Wightman +""" +from typing import Optional + +import torch +import torch.nn as nn +import torch.optim as optim + +from .adafactor import Adafactor +from .adahessian import Adahessian +from .adamp import AdamP +from .lookahead import Lookahead +from .nadam import Nadam +from .novograd import NovoGrad +from .nvnovograd import NvNovoGrad +from .radam import RAdam +from .rmsprop_tf import RMSpropTF +from .sgdp import SGDP +from .adabelief import AdaBelief + +try: + from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD + has_apex = True +except ImportError: + has_apex = False + + +def add_weight_decay(model, weight_decay=1e-5, skip_list=()): + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: + no_decay.append(param) + else: + decay.append(param) + return [ + {'params': no_decay, 'weight_decay': 0.}, + {'params': decay, 'weight_decay': weight_decay}] + + +def optimizer_kwargs(cfg): + """ cfg/argparse to kwargs helper + Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn. + """ + kwargs = dict( + optimizer_name=cfg.opt, + learning_rate=cfg.lr, + weight_decay=cfg.weight_decay, + momentum=cfg.momentum) + if getattr(cfg, 'opt_eps', None) is not None: + kwargs['eps'] = cfg.opt_eps + if getattr(cfg, 'opt_betas', None) is not None: + kwargs['betas'] = cfg.opt_betas + if getattr(cfg, 'opt_args', None) is not None: + kwargs.update(cfg.opt_args) + return kwargs + + +def create_optimizer(args, model, filter_bias_and_bn=True): + """ Legacy optimizer factory for backwards compatibility. + NOTE: Use create_optimizer_v2 for new code. + """ + return create_optimizer_v2( + model, + **optimizer_kwargs(cfg=args), + filter_bias_and_bn=filter_bias_and_bn, + ) + + +def create_optimizer_v2( + model: nn.Module, + optimizer_name: str = 'sgd', + learning_rate: Optional[float] = None, + weight_decay: float = 0., + momentum: float = 0.9, + filter_bias_and_bn: bool = True, + **kwargs): + """ Create an optimizer. + + TODO currently the model is passed in and all parameters are selected for optimization. + For more general use an interface that allows selection of parameters to optimize and lr groups, one of: + * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion + * expose the parameters interface and leave it up to caller + + Args: + model (nn.Module): model containing parameters to optimize + optimizer_name: name of optimizer to create + learning_rate: initial learning rate + weight_decay: weight decay to apply in optimizer + momentum: momentum for momentum based optimizers (others may use betas via kwargs) + filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay + **kwargs: extra optimizer specific kwargs to pass through + + Returns: + Optimizer + """ + opt_lower = optimizer_name.lower() + if weight_decay and filter_bias_and_bn: + skip = {} + if hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay() + parameters = add_weight_decay(model, weight_decay, skip) + weight_decay = 0. + else: + parameters = model.parameters() + if 'fused' in opt_lower: + assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' + + opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs) + opt_split = opt_lower.split('_') + opt_lower = opt_split[-1] + if opt_lower == 'sgd' or opt_lower == 'nesterov': + opt_args.pop('eps', None) + optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == 'momentum': + opt_args.pop('eps', None) + optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args) + elif opt_lower == 'adam': + optimizer = optim.Adam(parameters, **opt_args) + elif opt_lower == 'adabelief': + optimizer = AdaBelief(parameters, rectify=False, **opt_args) + elif opt_lower == 'adamw': + optimizer = optim.AdamW(parameters, **opt_args) + elif opt_lower == 'nadam': + optimizer = Nadam(parameters, **opt_args) + elif opt_lower == 'radam': + optimizer = RAdam(parameters, **opt_args) + elif opt_lower == 'adamp': + optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) + elif opt_lower == 'sgdp': + optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == 'adadelta': + optimizer = optim.Adadelta(parameters, **opt_args) + elif opt_lower == 'adafactor': + if not learning_rate: + opt_args['lr'] = None + optimizer = Adafactor(parameters, **opt_args) + elif opt_lower == 'adahessian': + optimizer = Adahessian(parameters, **opt_args) + elif opt_lower == 'rmsprop': + optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args) + elif opt_lower == 'rmsproptf': + optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args) + elif opt_lower == 'novograd': + optimizer = NovoGrad(parameters, **opt_args) + elif opt_lower == 'nvnovograd': + optimizer = NvNovoGrad(parameters, **opt_args) + elif opt_lower == 'fusedsgd': + opt_args.pop('eps', None) + optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args) + elif opt_lower == 'fusedmomentum': + opt_args.pop('eps', None) + optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args) + elif opt_lower == 'fusedadam': + optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) + elif opt_lower == 'fusedadamw': + optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) + elif opt_lower == 'fusedlamb': + optimizer = FusedLAMB(parameters, **opt_args) + elif opt_lower == 'fusednovograd': + opt_args.setdefault('betas', (0.95, 0.98)) + optimizer = FusedNovoGrad(parameters, **opt_args) + else: + assert False and "Invalid optimizer" + raise ValueError + + if len(opt_split) > 1: + if opt_split[0] == 'lookahead': + optimizer = Lookahead(optimizer) + + return optimizer diff --git a/timm/optim/radam.py b/timm/optim/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..9987a334460286b1a6c8ec6d57ee023596a74219 --- /dev/null +++ b/timm/optim/radam.py @@ -0,0 +1,152 @@ +"""RAdam Optimizer. +Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam +Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265 +""" +import math +import torch +from torch.optim.optimizer import Optimizer, required + + +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = self.buffer[int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss + + +class PlainRAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + + super(PlainRAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(PlainRAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = group['lr'] * math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( + N_sma_max - 2)) / (1 - beta1 ** state['step']) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size, exp_avg, denom) + else: + step_size = group['lr'] / (1 - beta1 ** state['step']) + p_data_fp32.add_(-step_size, exp_avg) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/timm/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..5115555cd26040e3af297a6e79e7bd5e4d202623 --- /dev/null +++ b/timm/optim/rmsprop_tf.py @@ -0,0 +1,136 @@ +""" RMSProp modified to behave like Tensorflow impl + +Originally cut & paste from PyTorch RMSProp +https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py +Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE + +Modifications Copyright 2020 Ross Wightman +""" + +import torch +from torch.optim import Optimizer + + +class RMSpropTF(Optimizer): + """Implements RMSprop algorithm (TensorFlow style epsilon) + + NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt + and a few other modifications to closer match Tensorflow for matching hyper-params. + + Noteworthy changes include: + 1. Epsilon applied inside square-root + 2. square_avg initialized to ones + 3. LR scaling of update accumulated in momentum buffer + + Proposed by G. Hinton in his + `course `_. + + The centered version first appears in `Generating Sequences + With Recurrent Neural Networks `_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-2) + momentum (float, optional): momentum factor (default: 0) + alpha (float, optional): smoothing (decay) constant (default: 0.9) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-10) + centered (bool, optional) : if ``True``, compute the centered RMSProp, + the gradient is normalized by an estimation of its variance + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 + lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer + update as per defaults in Tensorflow + + """ + + def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False, + decoupled_decay=False, lr_in_momentum=True): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= momentum: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= alpha: + raise ValueError("Invalid alpha value: {}".format(alpha)) + + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay, + decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum) + super(RMSpropTF, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RMSpropTF, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('momentum', 0) + group.setdefault('centered', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('RMSprop does not support sparse gradients') + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + state['square_avg'] = torch.ones_like(p.data) # PyTorch inits to zero + if group['momentum'] > 0: + state['momentum_buffer'] = torch.zeros_like(p.data) + if group['centered']: + state['grad_avg'] = torch.zeros_like(p.data) + + square_avg = state['square_avg'] + one_minus_alpha = 1. - group['alpha'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + if 'decoupled_decay' in group and group['decoupled_decay']: + p.data.add_(-group['weight_decay'], p.data) + else: + grad = grad.add(group['weight_decay'], p.data) + + # Tensorflow order of ops for updating squared avg + square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg) + # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original + + if group['centered']: + grad_avg = state['grad_avg'] + grad_avg.add_(one_minus_alpha, grad - grad_avg) + # grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original + avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group['eps']).sqrt_() # eps moved in sqrt + else: + avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt + + if group['momentum'] > 0: + buf = state['momentum_buffer'] + # Tensorflow accumulates the LR scaling in the momentum buffer + if 'lr_in_momentum' in group and group['lr_in_momentum']: + buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) + p.data.add_(-buf) + else: + # PyTorch scales the param update by LR + buf.mul_(group['momentum']).addcdiv_(grad, avg) + p.data.add_(-group['lr'], buf) + else: + p.data.addcdiv_(-group['lr'], grad, avg) + + return loss diff --git a/timm/optim/sgdp.py b/timm/optim/sgdp.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a94aa332d7030a70e888342eb6cc4623d69836 --- /dev/null +++ b/timm/optim/sgdp.py @@ -0,0 +1,96 @@ +""" +SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py + +Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217 +Code: https://github.com/clovaai/AdamP + +Copyright (c) 2020-present NAVER Corp. +MIT license +""" + +import torch +import torch.nn as nn +from torch.optim.optimizer import Optimizer, required +import math + +class SGDP(Optimizer): + def __init__(self, params, lr=required, momentum=0, dampening=0, + weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1): + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, + nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio) + super(SGDP, self).__init__(params, defaults) + + def _channel_view(self, x): + return x.view(x.size(0), -1) + + def _layer_view(self, x): + return x.view(1, -1) + + def _cosine_similarity(self, x, y, eps, view_func): + x = view_func(x) + y = view_func(y) + + x_norm = x.norm(dim=1).add_(eps) + y_norm = y.norm(dim=1).add_(eps) + dot = (x * y).sum(dim=1) + + return dot.abs() / x_norm / y_norm + + def _projection(self, p, grad, perturb, delta, wd_ratio, eps): + wd = 1 + expand_size = [-1] + [1] * (len(p.shape) - 1) + for view_func in [self._channel_view, self._layer_view]: + + cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) + + if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): + p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps) + perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size) + wd = wd_ratio + + return perturb, wd + + return perturb, wd + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + + # State initialization + if len(state) == 0: + state['momentum'] = torch.zeros_like(p.data) + + # SGD + buf = state['momentum'] + buf.mul_(momentum).add_(1 - dampening, grad) + if nesterov: + d_p = grad + momentum * buf + else: + d_p = buf + + # Projection + wd_ratio = 1 + if len(p.shape) > 1: + d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps']) + + # Weight decay + if weight_decay != 0: + p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum)) + + # Step + p.data.add_(-group['lr'], d_p) + + return loss diff --git a/timm/scheduler/__init__.py b/timm/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a7789826229f66e1220cb6149902ba9c411b537 --- /dev/null +++ b/timm/scheduler/__init__.py @@ -0,0 +1,5 @@ +from .cosine_lr import CosineLRScheduler +from .plateau_lr import PlateauLRScheduler +from .step_lr import StepLRScheduler +from .tanh_lr import TanhLRScheduler +from .scheduler_factory import create_scheduler diff --git a/timm/scheduler/__pycache__/__init__.cpython-38.pyc b/timm/scheduler/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..874bc854fa599b23861050305e46f24902b08bd7 Binary files /dev/null and b/timm/scheduler/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/scheduler/__pycache__/cosine_lr.cpython-38.pyc b/timm/scheduler/__pycache__/cosine_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c1a672509934ca288eca01fe8d36810c6b76026 Binary files /dev/null and b/timm/scheduler/__pycache__/cosine_lr.cpython-38.pyc differ diff --git a/timm/scheduler/__pycache__/multistep_lr.cpython-38.pyc b/timm/scheduler/__pycache__/multistep_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f168ebc7491cb213d22df482ec0ef40f8a9d4be0 Binary files /dev/null and b/timm/scheduler/__pycache__/multistep_lr.cpython-38.pyc differ diff --git a/timm/scheduler/__pycache__/plateau_lr.cpython-38.pyc b/timm/scheduler/__pycache__/plateau_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d65749187b960416c3a4a83893bcf16730e49161 Binary files /dev/null and b/timm/scheduler/__pycache__/plateau_lr.cpython-38.pyc differ diff --git a/timm/scheduler/__pycache__/poly_lr.cpython-38.pyc b/timm/scheduler/__pycache__/poly_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e18884742ee6eaef1d7bc897954a05eac804fd6 Binary files /dev/null and b/timm/scheduler/__pycache__/poly_lr.cpython-38.pyc differ diff --git a/timm/scheduler/__pycache__/scheduler.cpython-38.pyc b/timm/scheduler/__pycache__/scheduler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4129c2040bd15f58d0357fcf2d5e612723cb0a9d Binary files /dev/null and b/timm/scheduler/__pycache__/scheduler.cpython-38.pyc differ diff --git a/timm/scheduler/__pycache__/scheduler_factory.cpython-38.pyc b/timm/scheduler/__pycache__/scheduler_factory.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76cb61d804ed68f9978abd96096cdf77dda1c916 Binary files /dev/null and b/timm/scheduler/__pycache__/scheduler_factory.cpython-38.pyc differ diff --git a/timm/scheduler/__pycache__/step_lr.cpython-38.pyc b/timm/scheduler/__pycache__/step_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..812407c6c72d6b00a7a31420669c4d5865e436e1 Binary files /dev/null and b/timm/scheduler/__pycache__/step_lr.cpython-38.pyc differ diff --git a/timm/scheduler/__pycache__/tanh_lr.cpython-38.pyc b/timm/scheduler/__pycache__/tanh_lr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d0df8e0c03399d618b6852b11f758f77d0b66da Binary files /dev/null and b/timm/scheduler/__pycache__/tanh_lr.cpython-38.pyc differ diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..1532f092b5cc8c0af5125967cfb84b32ce03ca4a --- /dev/null +++ b/timm/scheduler/cosine_lr.py @@ -0,0 +1,116 @@ +""" Cosine Scheduler + +Cosine LR schedule with warmup, cycle/restarts, noise. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class CosineLRScheduler(Scheduler): + """ + Cosine decay with restarts. + This is described in the paper https://arxiv.org/abs/1608.03983. + + Inspiration from + https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + t_mul: float = 1., + lr_min: float = 0., + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + cycle_limit=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and t_mul == 1 and decay_rate == 1: + _logger.warning("Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") + self.t_initial = t_initial + self.t_mul = t_mul + self.lr_min = lr_min + self.decay_rate = decay_rate + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.t_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) + t_i = self.t_mul ** i * self.t_initial + t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = self.decay_rate ** i + lr_min = self.lr_min * gamma + lr_max_values = [v * gamma for v in self.base_values] + + if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): + lrs = [ + lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + if not cycles: + cycles = self.cycle_limit + cycles = max(1, cycles) + if self.t_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) diff --git a/timm/scheduler/multistep_lr.py b/timm/scheduler/multistep_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d5fe1980f20d0a1251d7771f475110805f8862 --- /dev/null +++ b/timm/scheduler/multistep_lr.py @@ -0,0 +1,65 @@ +""" MultiStep LR Scheduler + +Basic multi step LR schedule with warmup, noise. +""" +import torch +import bisect +from timm.scheduler.scheduler import Scheduler +from typing import List + +class MultiStepLRScheduler(Scheduler): + """ + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + decay_t: List[int], + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def get_curr_decay_steps(self, t): + # find where in the array t goes, + # assumes self.decay_t is sorted + return bisect.bisect_right(self.decay_t, t+1) + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2cacb65a1bf23d10aa6fd296f74579571043cf --- /dev/null +++ b/timm/scheduler/plateau_lr.py @@ -0,0 +1,113 @@ +""" Plateau Scheduler + +Adapts PyTorch plateau scheduler and allows application of noise, warmup. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +from .scheduler import Scheduler + + +class PlateauLRScheduler(Scheduler): + """Decay the LR by a factor every time the validation loss plateaus.""" + + def __init__(self, + optimizer, + decay_rate=0.1, + patience_t=10, + verbose=True, + threshold=1e-4, + cooldown_t=0, + warmup_t=0, + warmup_lr_init=0, + lr_min=0, + mode='max', + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize=True, + ): + super().__init__(optimizer, 'lr', initialize=initialize) + + self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, + patience=patience_t, + factor=decay_rate, + verbose=verbose, + threshold=threshold, + cooldown=cooldown_t, + mode=mode, + min_lr=lr_min + ) + + self.noise_range = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + self.restore_lr = None + + def state_dict(self): + return { + 'best': self.lr_scheduler.best, + 'last_epoch': self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + self.lr_scheduler.best = state_dict['best'] + if 'last_epoch' in state_dict: + self.lr_scheduler.last_epoch = state_dict['last_epoch'] + + # override the base class step fn completely + def step(self, epoch, metric=None): + if epoch <= self.warmup_t: + lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] + super().update_groups(lrs) + else: + if self.restore_lr is not None: + # restore actual LR from before our last noise perturbation before stepping base + for i, param_group in enumerate(self.optimizer.param_groups): + param_group['lr'] = self.restore_lr[i] + self.restore_lr = None + + self.lr_scheduler.step(metric, epoch) # step the base scheduler + + if self.noise_range is not None: + if isinstance(self.noise_range, (list, tuple)): + apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] + else: + apply_noise = epoch >= self.noise_range + if apply_noise: + self._apply_noise(epoch) + + def _apply_noise(self, epoch): + g = torch.Generator() + g.manual_seed(self.noise_seed + epoch) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + break + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + + # apply the noise on top of previous LR, cache the old value so we can restore for normal + # stepping of base scheduler + restore_lr = [] + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group['lr']) + restore_lr.append(old_lr) + new_lr = old_lr + old_lr * noise + param_group['lr'] = new_lr + self.restore_lr = restore_lr diff --git a/timm/scheduler/poly_lr.py b/timm/scheduler/poly_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..9c351be6ed56f8fe130cd391df0a7a7f89c7a96c --- /dev/null +++ b/timm/scheduler/poly_lr.py @@ -0,0 +1,116 @@ +""" Polynomial Scheduler + +Polynomial LR schedule with warmup, noise. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +import logging + +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class PolyLRScheduler(Scheduler): + """ Polynomial LR Scheduler w/ warmup, noise, and k-decay + + k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + power: float = 0.5, + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + k_decay=1.0, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1: + _logger.warning("Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") + self.t_initial = t_initial + self.power = power + self.lr_min = lr_min + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + self.k_decay = k_decay + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = self.cycle_decay ** i + lr_max_values = [v * gamma for v in self.base_values] + k = self.k_decay + + if i < self.cycle_limit: + lrs = [ + self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..21d51509c87a0783c6b61986c574a3ed5366e165 --- /dev/null +++ b/timm/scheduler/scheduler.py @@ -0,0 +1,105 @@ +from typing import Dict, Any + +import torch + + +class Scheduler: + """ Parameter Scheduler Base Class + A scheduler base class that can be used to schedule any optimizer parameter groups. + + Unlike the builtin PyTorch schedulers, this is intended to be consistently called + * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value + * At the END of each optimizer update, after incrementing the update count, to calculate next update's value + + The schedulers built on this should try to remain as stateless as possible (for simplicity). + + This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' + and -1 values for special behaviour. All epoch and update counts must be tracked in the training + code and explicitly passed in to the schedulers on the corresponding step or step_update call. + + Based on ideas from: + * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler + * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + param_group_field: str, + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, + initialize: bool = True) -> None: + self.optimizer = optimizer + self.param_group_field = param_group_field + self._initial_param_group_field = f"initial_{param_group_field}" + if initialize: + for i, group in enumerate(self.optimizer.param_groups): + if param_group_field not in group: + raise KeyError(f"{param_group_field} missing from param_groups[{i}]") + group.setdefault(self._initial_param_group_field, group[param_group_field]) + else: + for i, group in enumerate(self.optimizer.param_groups): + if self._initial_param_group_field not in group: + raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") + self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] + self.metric = None # any point to having this for all? + self.noise_range_t = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 + self.update_groups(self.base_values) + + def state_dict(self) -> Dict[str, Any]: + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.__dict__.update(state_dict) + + def get_epoch_values(self, epoch: int): + return None + + def get_update_values(self, num_updates: int): + return None + + def step(self, epoch: int, metric: float = None) -> None: + self.metric = metric + values = self.get_epoch_values(epoch) + if values is not None: + values = self._add_noise(values, epoch) + self.update_groups(values) + + def step_update(self, num_updates: int, metric: float = None): + self.metric = metric + values = self.get_update_values(num_updates) + if values is not None: + values = self._add_noise(values, num_updates) + self.update_groups(values) + + def update_groups(self, values): + if not isinstance(values, (list, tuple)): + values = [values] * len(self.optimizer.param_groups) + for param_group, value in zip(self.optimizer.param_groups, values): + param_group[self.param_group_field] = value + + def _add_noise(self, lrs, t): + if self.noise_range_t is not None: + if isinstance(self.noise_range_t, (list, tuple)): + apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] + else: + apply_noise = t >= self.noise_range_t + if apply_noise: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + break + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + lrs = [v + v * noise for v in lrs] + return lrs diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7748f42280b846ab159fb18d7cda09d1890123 --- /dev/null +++ b/timm/scheduler/scheduler_factory.py @@ -0,0 +1,87 @@ +""" Scheduler Factory +Hacked together by / Copyright 2020 Ross Wightman +""" +from .cosine_lr import CosineLRScheduler +from .tanh_lr import TanhLRScheduler +from .step_lr import StepLRScheduler +from .plateau_lr import PlateauLRScheduler + + +def create_scheduler(args, optimizer): + num_epochs = args.epochs + + if getattr(args, 'lr_noise', None) is not None: + lr_noise = getattr(args, 'lr_noise') + if isinstance(lr_noise, (list, tuple)): + noise_range = [n * num_epochs for n in lr_noise] + if len(noise_range) == 1: + noise_range = noise_range[0] + else: + noise_range = lr_noise * num_epochs + else: + noise_range = None + + lr_scheduler = None + if args.sched == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=getattr(args, 'lr_cycle_mul', 1.), + lr_min=args.min_lr, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cycle_limit=getattr(args, 'lr_cycle_limit', 1), + t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs + elif args.sched == 'tanh': + lr_scheduler = TanhLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=getattr(args, 'lr_cycle_mul', 1.), + lr_min=args.min_lr, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cycle_limit=getattr(args, 'lr_cycle_limit', 1), + t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs + elif args.sched == 'step': + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=args.decay_epochs, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + elif args.sched == 'plateau': + mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' + lr_scheduler = PlateauLRScheduler( + optimizer, + decay_rate=args.decay_rate, + patience_t=args.patience_epochs, + lr_min=args.min_lr, + mode=mode, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cooldown_t=0, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) + + return lr_scheduler, num_epochs diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..f797e1a8cf35999531dd5f1ccbbe09a9d0cf30a9 --- /dev/null +++ b/timm/scheduler/step_lr.py @@ -0,0 +1,63 @@ +""" Step Scheduler + +Basic step LR schedule with warmup, noise. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import math +import torch + +from .scheduler import Scheduler + + +class StepLRScheduler(Scheduler): + """ + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + decay_t: float, + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..8cc338bb1df7a564d9207b32ab0f59cdf1ef4c59 --- /dev/null +++ b/timm/scheduler/tanh_lr.py @@ -0,0 +1,120 @@ +""" TanH Scheduler + +TanH schedule with warmup, cycle/restarts, noise. + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import math +import numpy as np +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class TanhLRScheduler(Scheduler): + """ + Hyberbolic-Tangent decay with restarts. + This is described in the paper https://arxiv.org/abs/1806.01593 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lb: float = -6., + ub: float = 4., + t_mul: float = 1., + lr_min: float = 0., + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + cycle_limit=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + assert lb < ub + assert cycle_limit >= 0 + assert warmup_t >= 0 + assert warmup_lr_init >= 0 + self.lb = lb + self.ub = ub + self.t_initial = t_initial + self.t_mul = t_mul + self.lr_min = lr_min + self.decay_rate = decay_rate + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + if self.warmup_t: + t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t) + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.t_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) + t_i = self.t_mul ** i * self.t_initial + t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): + gamma = self.decay_rate ** i + lr_min = self.lr_min * gamma + lr_max_values = [v * gamma for v in self.base_values] + + tr = t_curr / t_i + lrs = [ + lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr)) + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + if not cycles: + cycles = self.cycle_limit + cycles = max(1, cycles) + if self.t_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d02e62d2d0ce62e594393014208e28c3ace5318b --- /dev/null +++ b/timm/utils/__init__.py @@ -0,0 +1,13 @@ +from .agc import adaptive_clip_grad +from .checkpoint_saver import CheckpointSaver +from .clip_grad import dispatch_clip_grad +from .cuda import ApexScaler, NativeScaler +from .distributed import distribute_bn, reduce_tensor +from .jit import set_jit_legacy +from .log import setup_default_logging, FormatterNoInfo +from .metrics import AverageMeter, accuracy +from .misc import natural_key, add_bool_arg +from .model import unwrap_model, get_state_dict +from .model_ema import ModelEma, ModelEmaV2 +from .random import random_seed +from .summary import update_summary, get_outdir diff --git a/timm/utils/__pycache__/__init__.cpython-310.pyc b/timm/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e0145fd84b11a412bf1c238f978d91ce37e4545 Binary files /dev/null and b/timm/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/__init__.cpython-311.pyc b/timm/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a934f41c48315ee934e35cf692a105079eb7f50 Binary files /dev/null and b/timm/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/__init__.cpython-313.pyc b/timm/utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..127eef1ba482b03f5776fb55035b600a0903caa8 Binary files /dev/null and b/timm/utils/__pycache__/__init__.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/__init__.cpython-38.pyc b/timm/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd673cb91f058aec2ba0126fc49d7157622a0416 Binary files /dev/null and b/timm/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/agc.cpython-310.pyc b/timm/utils/__pycache__/agc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7fccb257c26bb7a0ec5b6d56be6fec3ccb5a7bc Binary files /dev/null and b/timm/utils/__pycache__/agc.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/agc.cpython-311.pyc b/timm/utils/__pycache__/agc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50675ed84b48ca3ef6785ce2b9b2d786cedd2282 Binary files /dev/null and b/timm/utils/__pycache__/agc.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/agc.cpython-313.pyc b/timm/utils/__pycache__/agc.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bb4b25bfce4251264d47e5f4691ec13b16b8c56 Binary files /dev/null and b/timm/utils/__pycache__/agc.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/agc.cpython-38.pyc b/timm/utils/__pycache__/agc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43cec7997e8c761ba296a648feaf60c64b1fb3b3 Binary files /dev/null and b/timm/utils/__pycache__/agc.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/checkpoint_saver.cpython-310.pyc b/timm/utils/__pycache__/checkpoint_saver.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5d54cd1bbb4413ee0acecd7cc77cdc152d8ab06 Binary files /dev/null and b/timm/utils/__pycache__/checkpoint_saver.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/checkpoint_saver.cpython-311.pyc b/timm/utils/__pycache__/checkpoint_saver.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54cff3933e021b78e45b7874771d87715d042cdc Binary files /dev/null and b/timm/utils/__pycache__/checkpoint_saver.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/checkpoint_saver.cpython-313.pyc b/timm/utils/__pycache__/checkpoint_saver.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f2b1abec581bd6163bcedefc1b29ede34e59786 Binary files /dev/null and b/timm/utils/__pycache__/checkpoint_saver.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/checkpoint_saver.cpython-38.pyc b/timm/utils/__pycache__/checkpoint_saver.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c90bfb8ef80c5e94b61f2058d8e439d5b55252ca Binary files /dev/null and b/timm/utils/__pycache__/checkpoint_saver.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/clip_grad.cpython-310.pyc b/timm/utils/__pycache__/clip_grad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1bfd80b67179363fbeaba71d12d3903e54411c6 Binary files /dev/null and b/timm/utils/__pycache__/clip_grad.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/clip_grad.cpython-311.pyc b/timm/utils/__pycache__/clip_grad.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46d90fbb6b797e610355ca1cf842031d7647b873 Binary files /dev/null and b/timm/utils/__pycache__/clip_grad.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/clip_grad.cpython-313.pyc b/timm/utils/__pycache__/clip_grad.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccb1aa1287be64c3a9c4e38c59db8cd40c9490b0 Binary files /dev/null and b/timm/utils/__pycache__/clip_grad.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/clip_grad.cpython-38.pyc b/timm/utils/__pycache__/clip_grad.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e6dee41e5b5a06555d7544e27123a1a7e432d5 Binary files /dev/null and b/timm/utils/__pycache__/clip_grad.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/cuda.cpython-310.pyc b/timm/utils/__pycache__/cuda.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cb33758a914d1424c1350e9372cb15662344def Binary files /dev/null and b/timm/utils/__pycache__/cuda.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/cuda.cpython-311.pyc b/timm/utils/__pycache__/cuda.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e5fd7961f6805f443e879c06ba61fd2da5dd751 Binary files /dev/null and b/timm/utils/__pycache__/cuda.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/cuda.cpython-313.pyc b/timm/utils/__pycache__/cuda.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1e0b89c78b41d29eb7d96f9ec94d7caff26fcdc Binary files /dev/null and b/timm/utils/__pycache__/cuda.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/cuda.cpython-38.pyc b/timm/utils/__pycache__/cuda.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1994cb2edf62bf24bfb430e4d432fc86ad80bc23 Binary files /dev/null and b/timm/utils/__pycache__/cuda.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/distributed.cpython-310.pyc b/timm/utils/__pycache__/distributed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98ef8a11b139cb9d414e6d8b31ffe566fb18de1a Binary files /dev/null and b/timm/utils/__pycache__/distributed.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/distributed.cpython-311.pyc b/timm/utils/__pycache__/distributed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e1bf306e2b9aa1ef66a5767d543cd155d0f4755 Binary files /dev/null and b/timm/utils/__pycache__/distributed.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/distributed.cpython-313.pyc b/timm/utils/__pycache__/distributed.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f506d16fd257b44d4d12ef752f5b583eddf5ac6c Binary files /dev/null and b/timm/utils/__pycache__/distributed.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/distributed.cpython-38.pyc b/timm/utils/__pycache__/distributed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f88df389b5d09152db2c0334b6ab53eb469a776 Binary files /dev/null and b/timm/utils/__pycache__/distributed.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/jit.cpython-310.pyc b/timm/utils/__pycache__/jit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11402524ec101bb823ea0e86a32655ae8ea3bd1a Binary files /dev/null and b/timm/utils/__pycache__/jit.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/jit.cpython-311.pyc b/timm/utils/__pycache__/jit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf2eafea5b0069489e2b8db8d754f466452c5cb5 Binary files /dev/null and b/timm/utils/__pycache__/jit.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/jit.cpython-313.pyc b/timm/utils/__pycache__/jit.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8575278e2e03c568ddde1231af3aaa00ffb4d23f Binary files /dev/null and b/timm/utils/__pycache__/jit.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/jit.cpython-38.pyc b/timm/utils/__pycache__/jit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d533d41f8aa1b2bea9ccf90c2e3d8b460b9f77da Binary files /dev/null and b/timm/utils/__pycache__/jit.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/log.cpython-310.pyc b/timm/utils/__pycache__/log.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5bf82dd35e97a310324d89ba694e91636897855 Binary files /dev/null and b/timm/utils/__pycache__/log.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/log.cpython-311.pyc b/timm/utils/__pycache__/log.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..831b1a42d728aef7e3d29a4b9509ed869591d4b9 Binary files /dev/null and b/timm/utils/__pycache__/log.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/log.cpython-313.pyc b/timm/utils/__pycache__/log.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84c5b1d603d4b8bfab771d47468ef200c0b7396c Binary files /dev/null and b/timm/utils/__pycache__/log.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/log.cpython-38.pyc b/timm/utils/__pycache__/log.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5792224ef1c9613cbfacf180103103ad5697e2ca Binary files /dev/null and b/timm/utils/__pycache__/log.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/metrics.cpython-310.pyc b/timm/utils/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e009a5d993a2ab7c9eabf4446dd6dc8b6d59518b Binary files /dev/null and b/timm/utils/__pycache__/metrics.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/metrics.cpython-311.pyc b/timm/utils/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..793b7612eae258dfc653e8e07816d8dafeb00854 Binary files /dev/null and b/timm/utils/__pycache__/metrics.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/metrics.cpython-313.pyc b/timm/utils/__pycache__/metrics.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6050d72d8ce4f4f7d8107c4d1796abf403baa805 Binary files /dev/null and b/timm/utils/__pycache__/metrics.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/metrics.cpython-38.pyc b/timm/utils/__pycache__/metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1536484e5a097a85c81b34297951b12aca19d727 Binary files /dev/null and b/timm/utils/__pycache__/metrics.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/misc.cpython-310.pyc b/timm/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c95fa04a7e006ae03671ffccc449a33ecbb4c5bb Binary files /dev/null and b/timm/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/misc.cpython-311.pyc b/timm/utils/__pycache__/misc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7754943857c5d2002addaca53abd6251bcaa7a70 Binary files /dev/null and b/timm/utils/__pycache__/misc.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/misc.cpython-313.pyc b/timm/utils/__pycache__/misc.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9205eec85176e3234a1e282b9946fead9fae3be Binary files /dev/null and b/timm/utils/__pycache__/misc.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/misc.cpython-38.pyc b/timm/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1dd376f603f19909b8325385d014ffc7a5ab51e Binary files /dev/null and b/timm/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/model.cpython-310.pyc b/timm/utils/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36c161fa8df1f1da4856d0b7a2023cd3e75355b5 Binary files /dev/null and b/timm/utils/__pycache__/model.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/model.cpython-311.pyc b/timm/utils/__pycache__/model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..931a5b2d3d48aac5e20ea7a050a11848f8041ab8 Binary files /dev/null and b/timm/utils/__pycache__/model.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/model.cpython-313.pyc b/timm/utils/__pycache__/model.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba52cdeefb01b93c9bfabca9752f3dd44268de3a Binary files /dev/null and b/timm/utils/__pycache__/model.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/model.cpython-38.pyc b/timm/utils/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e3184b9d3f8a77a4d6812be33c34305c251d364 Binary files /dev/null and b/timm/utils/__pycache__/model.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/model_ema.cpython-310.pyc b/timm/utils/__pycache__/model_ema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f19542e4dfe703e74bce870bbb08d6094a14376 Binary files /dev/null and b/timm/utils/__pycache__/model_ema.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/model_ema.cpython-311.pyc b/timm/utils/__pycache__/model_ema.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ec7b8ab63ce5b38d0b26244675fa2777815c46c Binary files /dev/null and b/timm/utils/__pycache__/model_ema.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/model_ema.cpython-313.pyc b/timm/utils/__pycache__/model_ema.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07b3f8f88b31d91a096496c83547dd0cd8c8081e Binary files /dev/null and b/timm/utils/__pycache__/model_ema.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/model_ema.cpython-38.pyc b/timm/utils/__pycache__/model_ema.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efe8e97ea6b3f10a2623bf9279763fc1ae559f0c Binary files /dev/null and b/timm/utils/__pycache__/model_ema.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/random.cpython-310.pyc b/timm/utils/__pycache__/random.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4aa228bc23f22e22a712830020b95e92b712207d Binary files /dev/null and b/timm/utils/__pycache__/random.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/random.cpython-311.pyc b/timm/utils/__pycache__/random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61f5aaadc1a72a1e64d4ef5633a2eda0698bc3ae Binary files /dev/null and b/timm/utils/__pycache__/random.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/random.cpython-313.pyc b/timm/utils/__pycache__/random.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd14e3fb933c490d3f10754e2320bfda64075410 Binary files /dev/null and b/timm/utils/__pycache__/random.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/random.cpython-38.pyc b/timm/utils/__pycache__/random.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ef02cfddc5121dae6f2f9858512f8ab629dd647 Binary files /dev/null and b/timm/utils/__pycache__/random.cpython-38.pyc differ diff --git a/timm/utils/__pycache__/summary.cpython-310.pyc b/timm/utils/__pycache__/summary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4381b7a49d3990fad92bd1e74df59fdb0f9cc095 Binary files /dev/null and b/timm/utils/__pycache__/summary.cpython-310.pyc differ diff --git a/timm/utils/__pycache__/summary.cpython-311.pyc b/timm/utils/__pycache__/summary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba320da4310b152e81c1b02fa92521ab739c2ffa Binary files /dev/null and b/timm/utils/__pycache__/summary.cpython-311.pyc differ diff --git a/timm/utils/__pycache__/summary.cpython-313.pyc b/timm/utils/__pycache__/summary.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b3bce8276f80464e2febed7205c122d48847762 Binary files /dev/null and b/timm/utils/__pycache__/summary.cpython-313.pyc differ diff --git a/timm/utils/__pycache__/summary.cpython-38.pyc b/timm/utils/__pycache__/summary.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9b5c13338bdf46dc391a4e6cd853fed3db31386 Binary files /dev/null and b/timm/utils/__pycache__/summary.cpython-38.pyc differ diff --git a/timm/utils/agc.py b/timm/utils/agc.py new file mode 100644 index 0000000000000000000000000000000000000000..f51401726ff6810d97d0fa567f4e31b474325a59 --- /dev/null +++ b/timm/utils/agc.py @@ -0,0 +1,42 @@ +""" Adaptive Gradient Clipping + +An impl of AGC, as per (https://arxiv.org/abs/2102.06171): + +@article{brock2021high, + author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan}, + title={High-Performance Large-Scale Image Recognition Without Normalization}, + journal={arXiv preprint arXiv:}, + year={2021} +} + +Code references: + * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets + * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch + + +def unitwise_norm(x, norm_type=2.0): + if x.ndim <= 1: + return x.norm(norm_type) + else: + # works for nn.ConvNd and nn,Linear where output dim is first in the kernel/weight tensor + # might need special cases for other weights (possibly MHA) where this may not be true + return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True) + + +def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + for p in parameters: + if p.grad is None: + continue + p_data = p.detach() + g_data = p.grad.detach() + max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor) + grad_norm = unitwise_norm(g_data, norm_type=norm_type) + clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6)) + new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad) + p.grad.detach().copy_(new_grads) diff --git a/timm/utils/checkpoint_saver.py b/timm/utils/checkpoint_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..6aad74ee52655f68220f799efaffcbccdd0748ad --- /dev/null +++ b/timm/utils/checkpoint_saver.py @@ -0,0 +1,150 @@ +""" Checkpoint Saver + +Track top-n training checkpoints and maintain recovery checkpoints on specified intervals. + +Hacked together by / Copyright 2020 Ross Wightman +""" + +import glob +import operator +import os +import logging + +import torch + +from .model import unwrap_model, get_state_dict + + +_logger = logging.getLogger(__name__) + + +class CheckpointSaver: + def __init__( + self, + model, + optimizer, + args=None, + model_ema=None, + amp_scaler=None, + checkpoint_prefix='checkpoint', + recovery_prefix='recovery', + checkpoint_dir='', + recovery_dir='', + decreasing=False, + max_history=10, + unwrap_fn=unwrap_model): + + # objects to save state_dicts of + self.model = model + self.optimizer = optimizer + self.args = args + self.model_ema = model_ema + self.amp_scaler = amp_scaler + + # state + self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness + self.best_epoch = None + self.best_metric = None + self.curr_recovery_file = '' + self.last_recovery_file = '' + + # config + self.checkpoint_dir = checkpoint_dir + self.recovery_dir = recovery_dir + self.save_prefix = checkpoint_prefix + self.recovery_prefix = recovery_prefix + self.extension = '.pth.tar' + self.decreasing = decreasing # a lower metric is better if True + self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs + self.max_history = max_history + self.unwrap_fn = unwrap_fn + assert self.max_history >= 1 + + def save_checkpoint(self, epoch, metric=None): + assert epoch >= 0 + tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) + last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) + self._save(tmp_save_path, epoch, metric) + if os.path.exists(last_save_path): + os.unlink(last_save_path) # required for Windows support. + os.rename(tmp_save_path, last_save_path) + worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None + if (len(self.checkpoint_files) < self.max_history + or metric is None or self.cmp(metric, worst_file[1])): + if len(self.checkpoint_files) >= self.max_history: + self._cleanup_checkpoints(1) + filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension + save_path = os.path.join(self.checkpoint_dir, filename) + os.link(last_save_path, save_path) + self.checkpoint_files.append((save_path, metric)) + self.checkpoint_files = sorted( + self.checkpoint_files, key=lambda x: x[1], + reverse=not self.decreasing) # sort in descending order if a lower metric is not better + + checkpoints_str = "Current checkpoints:\n" + for c in self.checkpoint_files: + checkpoints_str += ' {}\n'.format(c) + _logger.info(checkpoints_str) + + if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): + self.best_epoch = epoch + self.best_metric = metric + best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) + if os.path.exists(best_save_path): + os.unlink(best_save_path) + os.link(last_save_path, best_save_path) + + return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) + + def _save(self, save_path, epoch, metric=None): + save_state = { + 'epoch': epoch, + 'arch': type(self.model).__name__.lower(), + 'state_dict': get_state_dict(self.model, self.unwrap_fn), + 'optimizer': self.optimizer.state_dict(), + 'version': 2, # version < 2 increments epoch before save + } + if self.args is not None: + save_state['arch'] = self.args.model + save_state['args'] = self.args + if self.amp_scaler is not None: + save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() + if self.model_ema is not None: + save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) + if metric is not None: + save_state['metric'] = metric + torch.save(save_state, save_path) + + def _cleanup_checkpoints(self, trim=0): + trim = min(len(self.checkpoint_files), trim) + delete_index = self.max_history - trim + if delete_index < 0 or len(self.checkpoint_files) <= delete_index: + return + to_delete = self.checkpoint_files[delete_index:] + for d in to_delete: + try: + _logger.debug("Cleaning checkpoint: {}".format(d)) + os.remove(d[0]) + except Exception as e: + _logger.error("Exception '{}' while deleting checkpoint".format(e)) + self.checkpoint_files = self.checkpoint_files[:delete_index] + + def save_recovery(self, epoch, batch_idx=0): + assert epoch >= 0 + filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension + save_path = os.path.join(self.recovery_dir, filename) + self._save(save_path, epoch) + if os.path.exists(self.last_recovery_file): + try: + _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) + os.remove(self.last_recovery_file) + except Exception as e: + _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) + self.last_recovery_file = self.curr_recovery_file + self.curr_recovery_file = save_path + + def find_recovery(self): + recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) + files = glob.glob(recovery_path + '*' + self.extension) + files = sorted(files) + return files[0] if len(files) else '' diff --git a/timm/utils/clip_grad.py b/timm/utils/clip_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb40697a221edd6d8e622ff3306dad5e58afd94 --- /dev/null +++ b/timm/utils/clip_grad.py @@ -0,0 +1,23 @@ +import torch + +from timm.utils.agc import adaptive_clip_grad + + +def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0): + """ Dispatch to gradient clipping method + + Args: + parameters (Iterable): model parameters to clip + value (float): clipping value/factor/norm, mode dependant + mode (str): clipping mode, one of 'norm', 'value', 'agc' + norm_type (float): p-norm, default 2.0 + """ + if mode == 'norm': + torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type) + elif mode == 'value': + torch.nn.utils.clip_grad_value_(parameters, value) + elif mode == 'agc': + adaptive_clip_grad(parameters, value, norm_type=norm_type) + else: + assert False, f"Unknown clip mode ({mode})." + diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7bddf30463a7be7186c7def47c4e4dfb9993aa --- /dev/null +++ b/timm/utils/cuda.py @@ -0,0 +1,55 @@ +""" CUDA / AMP utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +try: + from apex import amp + has_apex = True +except ImportError: + amp = None + has_apex = False + +from .clip_grad import dispatch_clip_grad + + +class ApexScaler: + state_dict_key = "amp" + + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward(create_graph=create_graph) + if clip_grad is not None: + dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) + optimizer.step() + + def state_dict(self): + if 'state_dict' in amp.__dict__: + return amp.state_dict() + + def load_state_dict(self, state_dict): + if 'load_state_dict' in amp.__dict__: + amp.load_state_dict(state_dict) + + +class NativeScaler: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + self._scaler.scale(loss).backward(create_graph=create_graph) + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) + self._scaler.step(optimizer) + self._scaler.update() + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5dba8c1de5a6ff53638207521377fdfbc4f239 --- /dev/null +++ b/timm/utils/distributed.py @@ -0,0 +1,28 @@ +""" Distributed training/validation utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch +from torch import distributed as dist + +from .model import unwrap_model + + +def reduce_tensor(tensor, n): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= n + return rt + + +def distribute_bn(model, world_size, reduce=False): + # ensure every node has the same running bn stats + for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True): + if ('running_mean' in bn_name) or ('running_var' in bn_name): + if reduce: + # average bn stats across whole group + torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM) + bn_buf /= float(world_size) + else: + # broadcast bn stats from rank 0 to whole group + torch.distributed.broadcast(bn_buf, 0) diff --git a/timm/utils/jit.py b/timm/utils/jit.py new file mode 100644 index 0000000000000000000000000000000000000000..185ab7a0d852b9a1c469cfbfff108dbafbb02466 --- /dev/null +++ b/timm/utils/jit.py @@ -0,0 +1,18 @@ +""" JIT scripting/tracing utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + + +def set_jit_legacy(): + """ Set JIT executor to legacy w/ support for op fusion + This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes + in the JIT exectutor. These API are not supported so could change. + """ + # + assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!" + torch._C._jit_set_profiling_executor(False) + torch._C._jit_set_profiling_mode(False) + torch._C._jit_override_can_fuse_on_gpu(True) + #torch._C._jit_set_texpr_fuser_enabled(True) diff --git a/timm/utils/log.py b/timm/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..c99469e0884f3e45905ef7c7f0d1e491092697ad --- /dev/null +++ b/timm/utils/log.py @@ -0,0 +1,28 @@ +""" Logging helpers + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +import logging.handlers + + +class FormatterNoInfo(logging.Formatter): + def __init__(self, fmt='%(levelname)s: %(message)s'): + logging.Formatter.__init__(self, fmt) + + def format(self, record): + if record.levelno == logging.INFO: + return str(record.getMessage()) + return logging.Formatter.format(self, record) + + +def setup_default_logging(default_level=logging.INFO, log_path=''): + console_handler = logging.StreamHandler() + console_handler.setFormatter(FormatterNoInfo()) + logging.root.addHandler(console_handler) + logging.root.setLevel(default_level) + if log_path: + file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3) + file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") + file_handler.setFormatter(file_formatter) + logging.root.addHandler(file_handler) diff --git a/timm/utils/metrics.py b/timm/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..8e0b1f9989a9dc95708a0dbb42e747f9a8565378 --- /dev/null +++ b/timm/utils/metrics.py @@ -0,0 +1,32 @@ +""" Eval metrics and related + +Hacked together by / Copyright 2020 Ross Wightman +""" + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] diff --git a/timm/utils/misc.py b/timm/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..39c0097c60ed602547f832f1f8dafbe37f156064 --- /dev/null +++ b/timm/utils/misc.py @@ -0,0 +1,18 @@ +""" Misc utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import re + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def add_bool_arg(parser, name, default=False, help=''): + dest_name = name.replace('-', '_') + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) + group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) + parser.set_defaults(**{dest_name: default}) diff --git a/timm/utils/model.py b/timm/utils/model.py new file mode 100644 index 0000000000000000000000000000000000000000..bd46e2f49c6d5ee2de304bfda1456bd1716c6886 --- /dev/null +++ b/timm/utils/model.py @@ -0,0 +1,92 @@ +""" Model / state_dict utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +from .model_ema import ModelEma +import torch +import fnmatch + +def unwrap_model(model): + if isinstance(model, ModelEma): + return unwrap_model(model.ema) + else: + return model.module if hasattr(model, 'module') else model + + +def get_state_dict(model, unwrap_fn=unwrap_model): + return unwrap_fn(model).state_dict() + + +def avg_sq_ch_mean(model, input, output): + "calculate average channel square mean of output activations" + return torch.mean(output.mean(axis=[0,2,3])**2).item() + + +def avg_ch_var(model, input, output): + "calculate average channel variance of output activations" + return torch.mean(output.var(axis=[0,2,3])).item()\ + + +def avg_ch_var_residual(model, input, output): + "calculate average channel variance of output activations" + return torch.mean(output.var(axis=[0,2,3])).item() + + +class ActivationStatsHook: + """Iterates through each of `model`'s modules and matches modules using unix pattern + matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is + a match. + + Arguments: + model (nn.Module): model from which we will extract the activation stats + hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string + matching with the name of model's modules. + hook_fns (List[Callable]): List of hook functions to be registered at every + module in `layer_names`. + + Inspiration from https://docs.fast.ai/callback.hook.html. + + Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example + on how to plot Signal Propogation Plots using `ActivationStatsHook`. + """ + + def __init__(self, model, hook_fn_locs, hook_fns): + self.model = model + self.hook_fn_locs = hook_fn_locs + self.hook_fns = hook_fns + if len(hook_fn_locs) != len(hook_fns): + raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \ + their lengths are different.") + self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns) + for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns): + self.register_hook(hook_fn_loc, hook_fn) + + def _create_hook(self, hook_fn): + def append_activation_stats(module, input, output): + out = hook_fn(module, input, output) + self.stats[hook_fn.__name__].append(out) + return append_activation_stats + + def register_hook(self, hook_fn_loc, hook_fn): + for name, module in self.model.named_modules(): + if not fnmatch.fnmatch(name, hook_fn_loc): + continue + module.register_forward_hook(self._create_hook(hook_fn)) + + +def extract_spp_stats(model, + hook_fn_locs, + hook_fns, + input_shape=[8, 3, 224, 224]): + """Extract average square channel mean and variance of activations during + forward pass to plot Signal Propogation Plots (SPP). + + Paper: https://arxiv.org/abs/2101.08692 + + Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 + """ + x = torch.normal(0., 1., input_shape) + hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) + _ = model(x) + return hook.stats + \ No newline at end of file diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py new file mode 100644 index 0000000000000000000000000000000000000000..073d5c5ea1a4afc5aa3817b6354b2566f8cc2cf5 --- /dev/null +++ b/timm/utils/model_ema.py @@ -0,0 +1,126 @@ +""" Exponential Moving Average (EMA) of model updates + +Hacked together by / Copyright 2020 Ross Wightman +""" +import logging +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn + +_logger = logging.getLogger(__name__) + + +class ModelEma: + """ Model Exponential Moving Average (DEPRECATED) + + Keep a moving average of everything in the model state_dict (parameters and buffers). + This version is deprecated, it does not work with scripted models. Will be removed eventually. + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device='', resume=''): + # make a copy of the model for accumulating moving average of weights + self.ema = deepcopy(model) + self.ema.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if device: + self.ema.to(device=device) + self.ema_has_module = hasattr(self.ema, 'module') + if resume: + self._load_checkpoint(resume) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def _load_checkpoint(self, checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + assert isinstance(checkpoint, dict) + if 'state_dict_ema' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict_ema'].items(): + # ema model may have been wrapped by DataParallel, and need module prefix + if self.ema_has_module: + name = 'module.' + k if not k.startswith('module') else k + else: + name = k + new_state_dict[name] = v + self.ema.load_state_dict(new_state_dict) + _logger.info("Loaded state_dict_ema") + else: + _logger.warning("Failed to find state_dict_ema, starting from loaded model weights") + + def update(self, model): + # correct a mismatch in state dict keys + needs_module = hasattr(model, 'module') and not self.ema_has_module + with torch.no_grad(): + msd = model.state_dict() + for k, ema_v in self.ema.state_dict().items(): + if needs_module: + k = 'module.' + k + model_v = msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + + +class ModelEmaV2(nn.Module): + """ Model Exponential Moving Average V2 + + Keep a moving average of everything in the model state_dict (parameters and buffers). + V2 of this module is simpler, it does not match params/buffers based on name but simply + iterates in order. It works with torchscript (JIT of full model). + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + """ + def __init__(self, model, decay=0.9999, device=None): + super(ModelEmaV2, self).__init__() + # make a copy of the model for accumulating moving average of weights + self.module = deepcopy(model) + self.module.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if self.device is not None: + self.module.to(device=device) + + def _update(self, model, update_fn): + with torch.no_grad(): + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + if self.device is not None: + model_v = model_v.to(device=self.device) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m) diff --git a/timm/utils/random.py b/timm/utils/random.py new file mode 100644 index 0000000000000000000000000000000000000000..a9679983e96a9a6634c0b77aaf7b996e70eff50b --- /dev/null +++ b/timm/utils/random.py @@ -0,0 +1,9 @@ +import random +import numpy as np +import torch + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) diff --git a/timm/utils/summary.py b/timm/utils/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5af9a08598556c3fed136f258f88bd578c1e1c --- /dev/null +++ b/timm/utils/summary.py @@ -0,0 +1,39 @@ +""" Summary utilities + +Hacked together by / Copyright 2020 Ross Wightman +""" +import csv +import os +from collections import OrderedDict +try: + import wandb +except ImportError: + pass + +def get_outdir(path, *paths, inc=False): + outdir = os.path.join(path, *paths) + if not os.path.exists(outdir): + os.makedirs(outdir) + elif inc: + count = 1 + outdir_inc = outdir + '-' + str(count) + while os.path.exists(outdir_inc): + count = count + 1 + outdir_inc = outdir + '-' + str(count) + assert count < 100 + outdir = outdir_inc + os.makedirs(outdir) + return outdir + + +def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False): + rowd = OrderedDict(epoch=epoch) + rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) + rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) + if log_wandb: + wandb.log(rowd) + with open(filename, mode='a') as cf: + dw = csv.DictWriter(cf, fieldnames=rowd.keys()) + if write_header: # first iteration (epoch == 1 can't be used) + dw.writeheader() + dw.writerow(rowd) diff --git a/timm/version.py b/timm/version.py new file mode 100644 index 0000000000000000000000000000000000000000..94c481976ff788a58b06ce8f450984989eadbb7f --- /dev/null +++ b/timm/version.py @@ -0,0 +1 @@ +__version__ = '0.4.12'