def forward_features(self, x):
    x = self.patch_embed(x)  # first obtain the patch embeddings
    if self.ape:  # whether to add the absolute position embedding
        x = x + self.absolute_pos_embed
    x = self.pos_drop(x)  # dropout layer

    for layer in self.layers:
        x = layer(x)

    x = self.norm(x)  # B L C
    x = self.avgpool(x.transpose(1, 2))  # B C 1
    x = torch.flatten(x, 1)  # flatten to B C
    return x

def forward(self, x):
    x = self.forward_features(x)
    x = self.head(x)  # linear classification head
    return x
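To make the pooling at the end of forward_features concrete, here is a small self-contained sketch of just those last three lines; the sizes (49 tokens, 768 channels, 1000 classes) are example values in the Swin-T range, not taken from the snippet above.

```python
import torch
import torch.nn as nn

# Example sizes only: 49 tokens with 768 channels each, as at the last Swin-T stage.
B, L, C = 2, 49, 768
tokens = torch.randn(B, L, C)          # B L C, output of the last stage + norm

avgpool = nn.AdaptiveAvgPool1d(1)      # pools the token dimension down to 1
x = avgpool(tokens.transpose(1, 2))    # B C L -> B C 1
x = torch.flatten(x, 1)                # B C, one global feature vector per image
print(x.shape)                         # torch.Size([2, 768])

head = nn.Linear(C, 1000)              # classification head, e.g. 1000 ImageNet classes
print(head(x).shape)                   # torch.Size([2, 1000])
```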
defforward(self, x): """ x: B, H*W, C """ H, W = self.input_resolution B, L, C = x.shape # L就等于num_patches assert L == H * W, "input feature has wrong size" assert H % 2 == 0and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C) # 将特征图再分成按patch为单位,长为W,高为H的形状排列 # 对原图进行取样,降低image_size x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C size是原来的二分之一,通道数是原来的2倍
x = self.norm(x) x = self.reduction(x)
return x
Below is an illustration (input tensor N=1, H=W=8, C=1; the final linear reduction layer is not included).

Personally, this feels like the inverse of the PixelShuffle operation.
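A small numeric sketch of the same 2x2 gathering, using the setting of the diagram above (N=1, H=W=8, C=1, no final linear reduction). The PixelUnshuffle call at the end is only there to support the comparison and is not part of PatchMerging.

```python
import torch

B, H, W, C = 1, 8, 8, 1
x = torch.arange(B * H * W * C).float().view(B, H, W, C)

x0 = x[:, 0::2, 0::2, :]   # even rows,  even cols
x1 = x[:, 1::2, 0::2, :]   # odd rows,   even cols
x2 = x[:, 0::2, 1::2, :]   # even rows,  odd cols
x3 = x[:, 1::2, 1::2, :]   # odd rows,   odd cols
merged = torch.cat([x0, x1, x2, x3], -1)        # 1, 4, 4, 4
print(merged.shape)                             # torch.Size([1, 4, 4, 4])
print(merged.view(B, -1, 4 * C).shape)          # torch.Size([1, 16, 4])

# PixelUnshuffle (the inverse of PixelShuffle) does the same space-to-depth
# regrouping, just with a different channel ordering within each 2x2 block.
unshuffled = torch.nn.functional.pixel_unshuffle(x.permute(0, 3, 1, 2), 2)  # NCHW
print(unshuffled.shape)                         # torch.Size([1, 4, 4, 4])
```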
Window Partition/Reverse
The window_partition function splits a tensor into windows of a given size: it reshapes a tensor of shape (B, H, W, C) into (num_windows*B, window_size, window_size, C), where num_windows = H*W / (window_size*window_size) is the number of windows per image, and each window is composed of patches. The window_reverse function is the corresponding inverse operation. Both functions are used later in Window Attention.
def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows  # (num_windows*B, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
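A quick round-trip check (not from the original code) that window_reverse really inverts window_partition, using the two functions above:

```python
import torch

B, H, W, C, window_size = 2, 8, 8, 3, 4
x = torch.randn(B, H, W, C)

windows = window_partition(x, window_size)       # 4 windows per image -> (B*4, 4, 4, C)
print(windows.shape)                             # torch.Size([8, 4, 4, 3])

x_back = window_reverse(windows, window_size, H, W)
print(torch.equal(x, x_back))                    # True
```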
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window (int, int).
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
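The relative position bias mentioned in the docstring is looked up from a learnable table using a precomputed relative_position_index. The standalone sketch below reconstructs that index from memory of the official implementation, so treat the details as illustrative rather than authoritative:

```python
import torch

def build_relative_position_index(window_size=(7, 7)):
    # Index of shape (Wh*Ww, Wh*Ww) into the (2*Wh-1)*(2*Ww-1) bias table.
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"))    # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)                                  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()            # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += window_size[0] - 1   # shift offsets to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    return relative_coords.sum(-1)                   # Wh*Ww, Wh*Ww

idx = build_relative_position_index((7, 7))
print(idx.shape, idx.max().item())  # torch.Size([49, 49]) 168, i.e. a 169-entry table
```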
def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B_, N, C = x.shape
    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
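To see what the reshape and permute do to the shapes, here is a shape walk-through with made-up sizes; a plain nn.Linear stands in for self.qkv, and the head_dim ** -0.5 scaling mirrors the qk_scale default from the docstring.

```python
import torch
import torch.nn as nn

# 8 windows of 7x7 = 49 tokens, dim = 96, 3 heads (head_dim = 32) -- example values only.
B_, N, C, num_heads = 8, 49, 96, 3
x = torch.randn(B_, N, C)
qkv_proj = nn.Linear(C, C * 3)                  # stands in for self.qkv

qkv = qkv_proj(x).reshape(B_, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
print(q.shape)                                  # torch.Size([8, 3, 49, 32]) = B_, heads, N, head_dim

attn = (q * (C // num_heads) ** -0.5) @ k.transpose(-2, -1)
print(attn.shape)                               # torch.Size([8, 3, 49, 49]) attention per window
```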
def forward(self, x):
    H, W = self.input_resolution
    B, L, C = x.shape
    assert L == H * W, "input feature has wrong size"

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)

    # cyclic shift
    if self.shift_size > 0:
        shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else:
        shifted_x = x

    # partition windows
    x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

    # W-MSA/SW-MSA
    attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
    shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

    # reverse cyclic shift
    if self.shift_size > 0:
        x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    else:
        x = shifted_x
    x = x.view(B, H * W, C)

    # FFN
    x = shortcut + self.drop_path(x)
    x = x + self.drop_path(self.mlp(self.norm2(x)))
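The self.attn_mask passed to the attention call above is precomputed for the shifted (SW-MSA) case. The standalone sketch below is reconstructed from memory of the official implementation, so treat it as approximate: label the regions created by the cyclic shift, partition the labels with window_partition, and assign -100 (effectively -inf before the softmax) to token pairs coming from different regions.

```python
import torch

def build_attn_mask(H, W, window_size, shift_size):
    # Label each position with the region it belongs to after the cyclic shift.
    img_mask = torch.zeros((1, H, W, 1))
    h_slices = (slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None))
    w_slices = h_slices  # same split along the width
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    # Assumes the window_partition function defined earlier in this post is in scope.
    mask_windows = window_partition(img_mask, window_size)            # nW, ws, ws, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)   # nW, ws*ws

    # Token pairs from different regions get -100, pairs from the same region get 0.
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    return attn_mask                                                   # nW, ws*ws, ws*ws

mask = build_attn_mask(H=8, W=8, window_size=4, shift_size=2)
print(mask.shape)   # torch.Size([4, 16, 16])
```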