File size: 8,633 Bytes
c9b5796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
import torch.nn as nn

from .base import BaseModel
from .feature_extractor import FeatureExtractor
import numpy as np
# from  .embeddings import AttentionWeightedEmbedding,

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.nn.functional as F

class ImprovedAttentionEmbedding(nn.Module):
    """Embedding whose vectors are rescaled by learnable per-class gates.

    Every class index owns ``weight_dim`` learnable gate logits. The embedding
    vector (size ``embedding_dim``) is split into ``weight_dim`` equally sized
    channel groups, and each group is multiplied by its gate:

    * ``weight_dim == 1``: one sigmoid gate scales the whole vector.
    * ``weight_dim > 1``: the logits are layer-normalized, then softmax-normalized
      across the ``weight_dim`` groups.

    Args:
        num_embeddings: Number of class indices.
        embedding_dim: Size of each embedding vector; must be divisible by
            ``weight_dim``.
        weight_dim: Number of gates per class index.
        dropout: Dropout probability applied to the gated embeddings.
        weight_init: ``'normal'`` (standard normal), ``'uniform'`` ([0, 1)), or any
            other value for all-ones initialization of the gate logits.

    Raises:
        ValueError: If ``weight_dim`` is not a positive divisor of ``embedding_dim``.
    """

    def __init__(self, num_embeddings, embedding_dim, weight_dim=1, dropout=0.1, weight_init='normal'):
        super(ImprovedAttentionEmbedding, self).__init__()
        # Fail at construction instead of with a broadcasting error in forward().
        if weight_dim < 1 or embedding_dim % weight_dim != 0:
            raise ValueError(
                f"weight_dim ({weight_dim}) must be a positive divisor of "
                f"embedding_dim ({embedding_dim})"
            )
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.weight_dim = weight_dim

        # Learnable gate logits, one row per class index: [num_embeddings, weight_dim]
        if weight_init == 'normal':
            self.weights = nn.Parameter(torch.randn(num_embeddings, weight_dim))
        elif weight_init == 'uniform':
            self.weights = nn.Parameter(torch.rand(num_embeddings, weight_dim))
        else:
            self.weights = nn.Parameter(torch.ones(num_embeddings, weight_dim))

        # Only used when weight_dim > 1: LayerNorm over a single feature always
        # returns its bias. Still registered so state_dict keys are unchanged.
        self.weight_norm = nn.LayerNorm(weight_dim)
        self.dropout = nn.Dropout(dropout)

        # Coefficient of the L2 penalty returned by get_l2_reg().
        self.l2_reg = 1e-5

    def forward(self, input):
        """Embed ``input`` indices and scale each vector by its class's gates.

        Returns a tensor of shape ``input.shape + (embedding_dim,)``.
        """
        embedded = self.embedding(input)  # [..., embedding_dim]
        weight = self.weights[input]  # [..., weight_dim]

        if self.weight_dim == 1:
            # LayerNorm + softmax over a singleton axis is identically 1, which
            # left self.weights without gradient; a sigmoid gate stays trainable.
            weighted_embedded = embedded * torch.sigmoid(weight)
        else:
            weight = F.softmax(self.weight_norm(weight), dim=-1)
            # View the vector as [..., weight_dim, embedding_dim // weight_dim]
            # and scale each channel group by its gate.
            grouped = embedded.reshape(*embedded.shape[:-1], self.weight_dim, -1)
            weighted_embedded = (grouped * weight.unsqueeze(-1)).reshape(embedded.shape)

        return self.dropout(weighted_embedded)

    def get_l2_reg(self):
        """Return ``l2_reg`` times the sum of squared gate logits."""
        return self.l2_reg * (self.weights ** 2).sum()


class AttentionWeightedEmbedding(nn.Module):
    """Embedding rescaled by softmax attention against one learned query vector.

    Each looked-up vector is scored by its dot product with ``self.query``. The
    scores are softmax-normalized along the last axis of ``input``, and every
    vector is multiplied by its normalized score.

    NOTE(review): for a 3-D index tensor such as a ``[batch, H, W]`` raster, the
    softmax runs over the last (W) axis only. Confirm that this is intended.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # Trainable query vector that the embeddings are scored against.
        self.query = nn.Parameter(torch.randn(embedding_dim))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input):
        """Return ``embedding(input)`` scaled by its attention weights (same shape)."""
        vectors = self.embedding(input)  # input.shape + (embedding_dim,)
        # The dot product with the 1-D query drops the embedding axis.
        scores = vectors @ self.query
        attention = self.softmax(scores)[..., None]
        return vectors * attention


class WeightedEmbedding(nn.Module):
    """Embedding whose vectors are scaled by one learnable scalar per class index.

    The scalars live in ``self.weights`` (shape ``[num_embeddings, 1]``) and start
    at 1, so an untrained layer behaves like a plain ``nn.Embedding``.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.weights = nn.Parameter(torch.ones(num_embeddings, 1))

    def forward(self, input):
        """Return ``embedding(input)`` times each index's scalar weight."""
        # Trailing size-1 axis broadcasts the scalar over the embedding dimension.
        scale = self.weights[input]
        return self.embedding(input) * scale


class MapEncoderSingle(BaseModel):
    """Encode a multi-channel map raster through one shared ("all") embedding.

    ``_forward`` merges the first three channels of ``data["map"]`` into a single
    class-index raster (``batch_process``), looks the indices up in
    ``self.embeddings["all"]``, and runs the result through ``self.encoder``.

    Config keys (see ``default_conf``):
        embedding_dim: Size of each embedding vector.
        output_dim: Output feature channels. When ``None``, it is read from
            ``backbone.output_dim``, which requires a FeatureExtractor backbone config.
        num_classes: Mapping from embedding name to class count. ``_forward`` only
            uses the ``"all"`` entry, and each embedding gets ``n + 1`` rows.
        backbone: ``None`` (1x1 conv), ``"simple"`` (three 3x3 convs), or a config
            forwarded to ``FeatureExtractor``.
        unary_prior: If true, one extra channel is predicted and returned
            separately as ``pred["log_prior"]``.
        weighted_embedding: Falsy for a plain ``nn.Embedding``, or one of
            ``"AttentionWeightedEmbedding"``, ``"WeightedEmbedding"``, or
            ``"ImprovedAttentionEmbedding"``.
    """

    default_conf = {
        "embedding_dim": "???",
        "output_dim": None,
        "num_classes": "???",
        "backbone": "???",
        "unary_prior": False,
        "weighted_embedding": False
    }

    def _init(self, conf):
        """Build the embedding tables and the encoder from ``conf``.

        Raises:
            ValueError: If ``conf.weighted_embedding`` names an unknown embedding
                type, or if ``conf.output_dim`` is None while the backbone is
                ``None`` or ``"simple"`` (there is then nowhere to read it from).
        """
        if not conf.weighted_embedding:
            embedding_cls = torch.nn.Embedding
        else:
            embedding_types = {
                "AttentionWeightedEmbedding": AttentionWeightedEmbedding,
                "WeightedEmbedding": WeightedEmbedding,
                "ImprovedAttentionEmbedding": ImprovedAttentionEmbedding,
            }
            if conf.weighted_embedding not in embedding_types:
                # Previously ignored, which left self.embeddings undefined until
                # _forward failed with an AttributeError.
                raise ValueError(
                    f"Unknown weighted_embedding {conf.weighted_embedding!r}; "
                    f"expected one of {sorted(embedding_types)} or a falsy value"
                )
            embedding_cls = embedding_types[conf.weighted_embedding]

        # +1 row per table so that index n itself is valid.
        self.embeddings = torch.nn.ModuleDict(
            {
                k: embedding_cls(n + 1, conf.embedding_dim)
                for k, n in conf.num_classes.items()
            }
        )

        # The per-layer class counts appear to be areas: 7, ways: 10, nodes: 33.
        # batch_process folds them into one index space, so "all" presumably
        # needs >= 50 classes (TODO confirm against the config).
        input_dim = len(conf.num_classes) * conf.embedding_dim
        output_dim = conf.output_dim
        if output_dim is None:
            if conf.backbone is None or conf.backbone == "simple":
                raise ValueError(
                    "conf.output_dim must be set when backbone is None or 'simple'"
                )
            output_dim = conf.backbone.output_dim
        if conf.unary_prior:
            # Extra channel carries the unary log-prior (split off in _forward).
            output_dim += 1
        if conf.backbone is None:
            self.encoder = nn.Conv2d(input_dim, output_dim, 1)
        elif conf.backbone == "simple":
            self.encoder = nn.Sequential(
                nn.Conv2d(input_dim, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, output_dim, 3, padding=1),
            )
        else:
            self.encoder = FeatureExtractor(
                {
                    **conf.backbone,
                    "input_dim": input_dim,
                    "output_dim": output_dim,
                }
            )

    def batch_process(self, input_tensor):
        """Collapse channels 0-2 of the map into one class-index raster.

        Non-zero values of channel 1 are shifted by 33 and non-zero values of
        channel 0 by 43, so that the channels occupy separate index ranges. Zero
        stays zero. Each output pixel then takes channel 2 if non-zero, else the
        shifted channel 1 if non-zero, else the shifted channel 0.

        Args:
            input_tensor: Integer tensor indexed ``[batch, channel, ...]`` with at
                least 3 channels. Channels past index 2 do not affect the result.
                The tensor is NOT modified.

        Returns:
            Tensor shaped like ``input_tensor[:, 0]``, with the same dtype and device.
        """
        # Offsets appear to equal the node count (33) and node + way count
        # (33 + 10 = 43) from the counts noted in _init.
        ch1_offset, ch0_offset = 33, 43

        ch0, ch1, ch2 = input_tensor[:, 0], input_tensor[:, 1], input_tensor[:, 2]
        background = torch.zeros_like(ch2)
        # Shift out of place: the former `+=` rewrote the caller's data["map"]
        # on every forward pass, compounding the offsets.
        ch0 = torch.where(ch0 != 0, ch0 + ch0_offset, background)
        ch1 = torch.where(ch1 != 0, ch1 + ch1_offset, background)
        # Priority: channel 2 > channel 1 > channel 0.
        return torch.where(ch2 != 0, ch2, torch.where(ch1 != 0, ch1, ch0))

    def _forward(self, data):
        """Embed ``data["map"]`` and encode it into feature maps.

        Returns:
            dict with ``"map_features"`` (a list of encoder feature tensors) and,
            when ``conf.unary_prior`` is set, ``"log_prior"`` (the last channel
            of each feature tensor, which is removed from ``"map_features"``).
        """
        class_map = self.batch_process(data["map"])
        # Assumes a [batch, H, W] index raster -> [batch, H, W, embedding_dim].
        embeddings = self.embeddings["all"](class_map)
        embeddings = embeddings.permute(0, 3, 1, 2)  # channels-first for the convs

        if isinstance(self.encoder, BaseModel):
            features = self.encoder({"image": embeddings})["feature_maps"]
        else:
            features = [self.encoder(embeddings)]

        pred = {}
        if self.conf.unary_prior:
            pred["log_prior"] = [f[:, -1] for f in features]
            features = [f[:, :-1] for f in features]
        pred["map_features"] = features
        return pred