# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.

import math
import random
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.mask import (
    exchange_token,
    exchange_patch,
    get_mask_box,
    jigsaw_token,
    cutout_patch,
    erase_patch,
    mixup_patch,
    jigsaw_patch,
)

def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class TransformerHead(nn.Module):
    """Transformer head over box features, with optional feature masking/augmentation."""

    def __init__(
        self,
        cfg,
        trans_names,
        kernel_size,
        use_feature_mask,
    ):
        super(TransformerHead, self).__init__()
        d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL

        # Mask parameters
        self.use_feature_mask = use_feature_mask
        mask_shape = cfg.MODEL.MASK_SHAPE
        mask_size = cfg.MODEL.MASK_SIZE
        mask_mode = cfg.MODEL.MASK_MODE

        self.bypass_mask = exchange_patch(mask_shape, mask_size, mask_mode)
        self.get_mask_box = get_mask_box(mask_shape, mask_size, mask_mode)

        self.transformer_encoder = Transformers(
            cfg=cfg,
            trans_names=trans_names,
            kernel_size=kernel_size,
            use_feature_mask=use_feature_mask,
        )
        self.conv0 = conv1x1(1024, 1024)
        self.conv1 = conv1x1(1024, d_model)
        self.conv2 = conv1x1(d_model, 2048)
    def forward(self, box_features):
        mask_box = self.get_mask_box(box_features)

        if self.use_feature_mask:
            skip_features = self.conv0(box_features)
            if self.training:
                skip_features = self.bypass_mask(skip_features)
        else:
            skip_features = box_features

        # Pool the (possibly masked) skip features and the transformer output
        # into two 1x1 feature vectors.
        trans_features = {}
        trans_features["before_trans"] = F.adaptive_max_pool2d(skip_features, 1)
        box_features = self.conv1(box_features)
        box_features = self.transformer_encoder((box_features, mask_box))
        box_features = self.conv2(box_features)
        trans_features["after_trans"] = F.adaptive_max_pool2d(box_features, 1)

        return trans_features

class Transformers(nn.Module):
    def __init__(
        self,
        cfg,
        trans_names,
        kernel_size,
        use_feature_mask,
    ):
        super(Transformers, self).__init__()
        d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL
        self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE
        self.use_feature_mask = use_feature_mask

        # If there is no conv before the transformer, we do not use scales.
        if not cfg.MODEL.TRANSFORMER.USE_PATCH2VEC:
            trans_names = ['scale1']
            kernel_size = [(1, 1)]

        self.trans_names = trans_names
        self.scale_size = len(self.trans_names)
        hidden = d_model // (2 * self.scale_size)

        # Maps kernel_size -> (padding, stride)
        kernels = {
            (1, 1): [(0, 0), (1, 1)],
            (3, 3): [(1, 1), (1, 1)],
        }

        padding = []
        stride = []
        for ksize in kernel_size:
            if ksize not in [(1, 1), (3, 3)]:
                raise ValueError('Undefined kernel size.')
            padding.append(kernels[ksize][0])
            stride.append(kernels[ksize][1])

        self.use_output_layer = cfg.MODEL.TRANSFORMER.USE_OUTPUT_LAYER
        self.use_global_shortcut = cfg.MODEL.TRANSFORMER.USE_GLOBAL_SHORTCUT

        # One transformer branch per scale; each branch sees an equal slice of the channels.
        self.blocks = nn.ModuleDict()
        for tname, ksize, psize, ssize in zip(self.trans_names, kernel_size, padding, stride):
            transblock = Transformer(
                cfg, d_model // self.scale_size, ksize, psize, ssize, hidden, use_feature_mask
            )
            self.blocks[tname] = nn.Sequential(transblock)

        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.mask_para = [cfg.MODEL.MASK_SHAPE, cfg.MODEL.MASK_SIZE, cfg.MODEL.MASK_MODE]
    def forward(self, inputs):
        trans_feat = []
        enc_feat, mask_box = inputs

        if self.training and self.use_feature_mask and self.feature_aug_type == 'exchange_patch':
            feature_mask = exchange_patch(self.mask_para[0], self.mask_para[1], self.mask_para[2])
            enc_feat = feature_mask(enc_feat)

        # Split the channels evenly across scales and run each chunk through its own branch.
        for tname, feat in zip(self.trans_names, torch.chunk(enc_feat, len(self.trans_names), dim=1)):
            feat = self.blocks[tname]((feat, mask_box))
            trans_feat.append(feat)

        trans_feat = torch.cat(trans_feat, 1)
        if self.use_output_layer:
            trans_feat = self.output_linear(trans_feat)
        if self.use_global_shortcut:
            trans_feat = enc_feat + trans_feat
        return trans_feat

class Transformer(nn.Module):
    def __init__(self, cfg, channel, kernel_size, padding, stride, hidden, use_feature_mask):
        super(Transformer, self).__init__()
        self.k = kernel_size[0]
        stack_num = cfg.MODEL.TRANSFORMER.ENCODER_LAYERS
        num_head = cfg.MODEL.TRANSFORMER.N_HEAD
        dropout = cfg.MODEL.TRANSFORMER.DROPOUT
        output_size = (14, 14)
        token_size = tuple(map(lambda x, y: x // y, output_size, stride))
        # Note: the same TransformerBlock instance is appended stack_num times,
        # so the stacked encoder layers share their weights.
        blocks = []
        self.transblock = TransformerBlock(token_size, hidden=hidden, num_head=num_head, dropout=dropout)
        for _ in range(stack_num):
            blocks.append(self.transblock)
        self.transformer = nn.Sequential(*blocks)
        self.patch2vec = nn.Conv2d(channel, hidden, kernel_size=kernel_size, stride=stride, padding=padding)
        self.vec2patch = Vec2Patch(channel, hidden, output_size, kernel_size, stride, padding)
        self.use_local_shortcut = cfg.MODEL.TRANSFORMER.USE_LOCAL_SHORTCUT
        self.use_feature_mask = use_feature_mask
        self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE
        self.use_patch2vec = cfg.MODEL.TRANSFORMER.USE_PATCH2VEC
    def forward(self, inputs):
        enc_feat, mask_box = inputs
        b, c, h, w = enc_feat.size()

        # Embed patches into tokens: (b, c, h, w) -> (b, h*w, hidden).
        trans_feat = self.patch2vec(enc_feat)

        _, c, h, w = trans_feat.size()
        trans_feat = trans_feat.view(b, c, -1).permute(0, 2, 1)

        # For 1x1 & 3x3 kernels, exchange tokens or apply other patch-level
        # augmentations during training.
        if self.training and self.use_feature_mask:
            if self.feature_aug_type == 'exchange_token':
                feature_mask = exchange_token()
                trans_feat = feature_mask(trans_feat, mask_box)
            elif self.feature_aug_type == 'cutout_patch':
                feature_mask = cutout_patch()
                trans_feat = feature_mask(trans_feat)
            elif self.feature_aug_type == 'erase_patch':
                feature_mask = erase_patch()
                trans_feat = feature_mask(trans_feat)
            elif self.feature_aug_type == 'mixup_patch':
                feature_mask = mixup_patch()
                trans_feat = feature_mask(trans_feat)

        # Jigsaw augmentations are applied whenever feature masking is enabled,
        # not only during training.
        if self.use_feature_mask:
            if self.feature_aug_type == 'jigsaw_patch':
                feature_mask = jigsaw_patch()
                trans_feat = feature_mask(trans_feat)
            elif self.feature_aug_type == 'jigsaw_token':
                feature_mask = jigsaw_token()
                trans_feat = feature_mask(trans_feat)

        trans_feat = self.transformer(trans_feat)
        trans_feat = self.vec2patch(trans_feat)
        if self.use_local_shortcut:
            trans_feat = enc_feat + trans_feat

        return trans_feat

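# Minimal shape-check sketch for a single-scale `Transformer` branch. The nested
# SimpleNamespace below only stands in for the handful of config fields this class
# reads; it is not COAT's real config object, and the chosen sizes are assumptions
# for illustration. With `use_feature_mask=False` and eval mode, no functions from
# `utils.mask` are exercised, so `mask_box` can be None.
def _example_transformer_scale():
    from types import SimpleNamespace

    dummy_cfg = SimpleNamespace(
        MODEL=SimpleNamespace(
            FEATURE_AUG_TYPE='exchange_token',
            TRANSFORMER=SimpleNamespace(
                ENCODER_LAYERS=1,
                N_HEAD=8,
                DROPOUT=0.0,
                USE_LOCAL_SHORTCUT=True,
                USE_PATCH2VEC=True,
            ),
        )
    )
    branch = Transformer(
        dummy_cfg, channel=256, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1),
        hidden=128, use_feature_mask=False,
    )
    branch.eval()
    enc_feat = torch.randn(2, 256, 14, 14)  # 14x14 box features, 256 channels
    out = branch((enc_feat, None))          # mask_box is unused when masking is off
    assert out.shape == enc_feat.shape      # the local shortcut preserves the shape
    return out
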
class TransformerBlock(nn.Module):
    """
    Transformer block = multi-head attention + feed-forward, each wrapped in a
    residual (sublayer) connection.
    """

    def __init__(self, tokensize, hidden=128, num_head=4, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadedAttention(tokensize, d_model=hidden, head=num_head, p=dropout)
        self.ffn = FeedForward(hidden, p=dropout)
        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.norm1(x)
        x = x + self.dropout(self.attention(x))
        y = self.norm2(x)
        x = x + self.ffn(y)

        return x

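# A small usage sketch for `TransformerBlock` on a batch of flattened patch tokens.
# The batch size, token count, and hidden width are illustrative assumptions; in COAT
# the tokens come from `Transformer.patch2vec`.
def _example_transformer_block():
    block = TransformerBlock(tokensize=(14, 14), hidden=128, num_head=4, dropout=0.1)
    tokens = torch.randn(2, 14 * 14, 128)  # (batch, num_tokens, hidden)
    out = block(tokens)
    assert out.shape == tokens.shape       # residual connections preserve the shape
    return out
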
class Attention(nn.Module):
    """
    Compute scaled dot-product attention.
    """

    def __init__(self, p=0.1):
        super(Attention, self).__init__()
        self.dropout = nn.Dropout(p=p)

    def forward(self, query, key, value):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.dropout(p_attn)
        p_val = torch.matmul(p_attn, value)
        return p_val, p_attn

class Vec2Patch(nn.Module):
    """Fold a sequence of tokens back into a (b, channel, h, w) feature map."""

    def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
        super(Vec2Patch, self).__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        c_out = reduce((lambda x, y: x * y), kernel_size) * channel
        self.embedding = nn.Linear(hidden, c_out)
        self.to_patch = torch.nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        feat = self.embedding(x)
        b, n, c = feat.size()
        feat = feat.permute(0, 2, 1)
        feat = self.to_patch(feat)

        return feat

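# A shape sketch for `Vec2Patch`: each token of width `hidden` is projected to
# kernel_size[0] * kernel_size[1] * channel values and folded back into a
# (b, channel, 14, 14) map. The sizes below are assumptions chosen to match the
# (14, 14) output used by `Transformer`.
def _example_vec2patch():
    to_patch = Vec2Patch(channel=256, hidden=128, output_size=(14, 14),
                         kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    tokens = torch.randn(2, 14 * 14, 128)  # one token per output location
    feat = to_patch(tokens)
    assert feat.shape == (2, 256, 14, 14)
    return feat
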
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, tokensize, d_model, head, p=0.1):
        super().__init__()
        self.query_embedding = nn.Linear(d_model, d_model)
        self.value_embedding = nn.Linear(d_model, d_model)
        self.key_embedding = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention(p=p)
        self.head = head
        self.h, self.w = tokensize

    def forward(self, x):
        b, n, c = x.size()
        c_h = c // self.head
        key = self.key_embedding(x)
        query = self.query_embedding(x)
        value = self.value_embedding(x)
        # Split the channel dimension into heads: (b, n, c) -> (b, head, n, c_h).
        key = key.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
        query = query.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
        value = value.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
        att, _ = self.attention(query, key, value)
        # Merge the heads back: (b, head, n, c_h) -> (b, n, c).
        att = att.permute(0, 2, 1, 3).contiguous().view(b, n, c)
        output = self.output_linear(att)

        return output

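# A usage sketch for `MultiHeadedAttention`. `d_model` must be divisible by `head`;
# the (14, 14) token grid and widths below are illustrative assumptions.
def _example_multihead_attention():
    mha = MultiHeadedAttention(tokensize=(14, 14), d_model=128, head=4, p=0.1)
    tokens = torch.randn(2, 14 * 14, 128)
    out = mha(tokens)
    assert out.shape == tokens.shape
    return out
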
class FeedForward(nn.Module):
    def __init__(self, d_model, p=0.1):
        super(FeedForward, self).__init__()
        # Position-wise feed-forward MLP; note the attribute is named `conv`
        # although it is built from Linear layers.
        self.conv = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(p=p),
        )

    def forward(self, x):
        x = self.conv(x)
        return x
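

# Optional smoke test: run scaled dot-product attention and the feed-forward MLP on
# random tensors and print the resulting shapes. The tensor sizes are arbitrary
# illustrative choices; this block is not part of COAT's training or inference path.
if __name__ == "__main__":
    q = k = v = torch.randn(2, 4, 196, 32)  # (batch, heads, tokens, head_dim)
    attended, attn_weights = Attention(p=0.0)(q, k, v)
    print("attention output:", tuple(attended.shape), "weights:", tuple(attn_weights.shape))

    ffn_out = FeedForward(d_model=128, p=0.0)(torch.randn(2, 196, 128))
    print("feed-forward output:", tuple(ffn_out.shape))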