""" This is just a wrapper around the various baselines implemented in the Chefer et. al. Transformer Explainability repository. Implements - CheferLRPSegmentationModel - CheferRolloutSegmentationModel - CheferLastLayerAttentionSegmentationModel - CheferAttentionGradCAMSegmentationModel - CheferTransformerAttributionSegmentationModel - CheferFullLRPSegmentationModel - CheferLastLayerLRPSegmentationModel """ # # segmentation test for the rollout baseline # if args.method == 'rollout': # Res = baselines.generate_rollout(image.cuda(), start_layer=1).reshape(batch_size, 1, 14, 14) # # segmentation test for the LRP baseline (this is full LRP, not partial) # elif args.method == 'full_lrp': # Res = orig_lrp.generate_LRP(image.cuda(), method="full").reshape(batch_size, 1, 224, 224) # # segmentation test for our method # elif args.method == 'transformer_attribution': # Res = lrp.generate_LRP(image.cuda(), start_layer=1, method="transformer_attribution").reshape(batch_size, 1, 14, 14) # # segmentation test for the partial LRP baseline (last attn layer) # elif args.method == 'lrp_last_layer': # Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer", is_ablation=args.is_ablation)\ # .reshape(batch_size, 1, 14, 14) # # segmentation test for the raw attention baseline (last attn layer) # elif args.method == 'attn_last_layer': # Res = orig_lrp.generate_LRP(image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation)\ # .reshape(batch_size, 1, 14, 14) # # segmentation test for the GradCam baseline (last attn layer) # elif args.method == 'attn_gradcam': # Res = baselines.generate_cam_attn(image.cuda()).reshape(batch_size, 1, 14, 14) # if args.method != 'full_lrp': # # interpolate to full image size (224,224) # Res = torch.nn.functional.interpolate(Res, scale_factor=16, mode='bilinear').cuda() import torch import PIL from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import LRP from concept_attention.segmentation import SegmentationAbstractClass from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_explanation_generator import Baselines, LRP from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_new import vit_base_patch16_224 from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_LRP import vit_base_patch16_224 as vit_LRP from concept_attention.binary_segmentation_baselines.chefer_vit_explainability.ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP # # Model # model = vit_base_patch16_224(pretrained=True).cuda() # baselines = Baselines(model) # # LRP # model_LRP = vit_LRP(pretrained=True).cuda() # model_LRP.eval() # lrp = LRP(model_LRP) # # orig LRP # model_orig_LRP = vit_orig_LRP(pretrained=True).cuda() # model_orig_LRP.eval() # orig_lrp = LRP(model_orig_LRP) # model.eval() class CheferLRPSegmentationModel(SegmentationAbstractClass): def __init__( self, device: str = "cuda", width: int = 224, height: int = 224, ): """ Initialize the segmentation model. """ super(CheferLRPSegmentationModel, self).__init__() self.width = width self.height = height self.device = device # Load the LRP model model_orig_LRP = vit_orig_LRP(pretrained=True).to(self.device) model_orig_LRP.eval() self.orig_lrp = LRP(model_orig_LRP) def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs): """ Takes a real image and generates a concept segmentation map it by adding noise and running the DiT on it. 
""" if len(image.shape) == 3: image = image.unsqueeze(0) prediction_map = self.orig_lrp.generate_LRP( image.to(self.device), method="full" ) prediction_map = prediction_map.unsqueeze(0) # Rescale the prediction map to 64x64 prediction_map = torch.nn.functional.interpolate( prediction_map, size=(self.width, self.height), mode="nearest" ).reshape(1, self.width, self.height) return prediction_map, None class CheferRolloutSegmentationModel(SegmentationAbstractClass): def __init__(self, device: str = "cuda", width: int = 224, height: int = 224): super(CheferRolloutSegmentationModel, self).__init__() self.width = width self.height = height self.device = device model = vit_base_patch16_224(pretrained=True).to(device) self.baselines = Baselines(model) def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs): if len(image.shape) == 3: image = image.unsqueeze(0) prediction_map = self.baselines.generate_rollout( image.to(self.device), start_layer=1 ).reshape(1, 1, 14, 14) # Rescale the prediction map to 64x64 prediction_map = torch.nn.functional.interpolate( prediction_map, size=(self.width, self.height), mode="nearest" ).reshape(1, self.width, self.height) return prediction_map, None class CheferLastLayerAttentionSegmentationModel(SegmentationAbstractClass): def __init__(self, device: str = "cuda", width: int = 224, height: int = 224): super(CheferLastLayerAttentionSegmentationModel, self).__init__() self.width = width self.height = height self.device = device model_orig_LRP = vit_orig_LRP(pretrained=True).to(device) model_orig_LRP.eval() self.orig_lrp = LRP(model_orig_LRP) def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs): if len(image.shape) == 3: image = image.unsqueeze(0) prediction_map = self.orig_lrp.generate_LRP( image.to(self.device), method="last_layer_attn" ).reshape(1, 1, 14, 14) # Rescale the prediction map to 64x64 prediction_map = torch.nn.functional.interpolate( prediction_map, size=(self.width, self.height), mode="nearest" ).reshape(1, self.width, self.height) return prediction_map, None class CheferAttentionGradCAMSegmentationModel(SegmentationAbstractClass): def __init__(self, device: str = "cuda", width: int = 224, height: int = 224): super(CheferAttentionGradCAMSegmentationModel, self).__init__() self.width = width self.height = height self.device = device model = vit_base_patch16_224(pretrained=True).to(device) self.baselines = Baselines(model) def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs): if len(image.shape) == 3: image = image.unsqueeze(0) prediction_map = self.baselines.generate_cam_attn( image.to(self.device) ).reshape(1, 1, 14, 14) # Rescale the prediction map to 64x64 prediction_map = torch.nn.functional.interpolate( prediction_map, size=(self.width, self.height), mode="nearest" ).reshape(1, self.width, self.height) return prediction_map, None class CheferTransformerAttributionSegmentationModel(SegmentationAbstractClass): def __init__(self, device: str = "cuda", width: int = 224, height: int = 224): super(CheferTransformerAttributionSegmentationModel, self).__init__() self.width = width self.height = height self.device = device model_LRP = vit_LRP(pretrained=True).to(device) model_LRP.eval() self.lrp = LRP(model_LRP) def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs): if len(image.shape) == 3: image = image.unsqueeze(0) prediction_map = self.lrp.generate_LRP( 

class CheferTransformerAttributionSegmentationModel(SegmentationAbstractClass):
    """Transformer attribution method of Chefer et al."""

    def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
        super(CheferTransformerAttributionSegmentationModel, self).__init__()
        self.width = width
        self.height = height
        self.device = device
        model_LRP = vit_LRP(pretrained=True).to(device)
        model_LRP.eval()
        self.lrp = LRP(model_LRP)

    def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
        prediction_map = self.lrp.generate_LRP(
            image.to(self.device),
            start_layer=1,
            method="transformer_attribution"
        ).reshape(1, 1, 14, 14)
        # Rescale the prediction map to the target (width, height)
        prediction_map = torch.nn.functional.interpolate(
            prediction_map,
            size=(self.width, self.height),
            mode="nearest"
        ).reshape(1, self.width, self.height)

        return prediction_map, None


class CheferFullLRPSegmentationModel(SegmentationAbstractClass):
    """Full LRP baseline, propagated down to the input pixels."""

    def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
        super(CheferFullLRPSegmentationModel, self).__init__()
        self.width = width
        self.height = height
        self.device = device
        model_LRP = vit_LRP(pretrained=True).to(device)
        model_LRP.eval()
        self.lrp = LRP(model_LRP)

    def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
        prediction_map = self.lrp.generate_LRP(
            image.to(self.device),
            method="full"
        ).reshape(1, 1, 224, 224)
        # Rescale the prediction map to the target (width, height)
        prediction_map = torch.nn.functional.interpolate(
            prediction_map,
            size=(self.width, self.height),
            mode="nearest"
        ).reshape(1, self.width, self.height)

        return prediction_map, None


class CheferLastLayerLRPSegmentationModel(SegmentationAbstractClass):
    """Partial LRP baseline (last attention layer)."""

    def __init__(self, device: str = "cuda", width: int = 224, height: int = 224):
        super(CheferLastLayerLRPSegmentationModel, self).__init__()
        self.width = width
        self.height = height
        self.device = device
        model_LRP = vit_LRP(pretrained=True).to(device)
        model_LRP.eval()
        self.lrp = LRP(model_LRP)

    def segment_individual_image(self, image: torch.Tensor, concepts: list[str], caption: str, **kwargs):
        if len(image.shape) == 3:
            image = image.unsqueeze(0)
        prediction_map = self.lrp.generate_LRP(
            image.to(self.device),
            method="last_layer"
        ).reshape(1, 1, 14, 14)
        # Rescale the prediction map to the target (width, height)
        prediction_map = torch.nn.functional.interpolate(
            prediction_map,
            size=(self.width, self.height),
            mode="nearest"
        ).reshape(1, self.width, self.height)

        return prediction_map, None
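

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# It loads an image with PIL, preprocesses it with the example_preprocess
# helper defined near the top of this file, and runs one of the wrappers.
# The image path and the choice of wrapper are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Placeholder path; replace with a real image.
    pil_image = Image.open("example.jpg").convert("RGB")
    image = example_preprocess(pil_image)

    # Any of the Chefer wrapper classes above can be used interchangeably here.
    model = CheferRolloutSegmentationModel(device=device)
    prediction_map, _ = model.segment_individual_image(
        image, concepts=[], caption=""
    )
    print(prediction_map.shape)  # expected: (1, 224, 224)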