# COAT/models/transformer.py

# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import math
import random
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.mask import exchange_token, exchange_patch, get_mask_box, jigsaw_token, cutout_patch, erase_patch, mixup_patch, jigsaw_patch
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 2D convolution with a 1x1 kernel.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Spatial stride of the convolution.

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=1`` and ``bias=False``.
    """
    conv_options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class TransformerHead(nn.Module):
    """Box-feature head that refines RoI features with a multi-scale transformer.

    ``forward`` returns a dict with two globally max-pooled descriptors:

    * ``"before_trans"``: the skip path. This is the raw input, or ``conv0``
      of it (patch-exchanged in training) when ``use_feature_mask`` is set.
    * ``"after_trans"``: ``conv1`` -> ``Transformers`` encoder -> ``conv2``.

    ``conv0``/``conv1`` are built for 1024 input channels and ``conv2``
    outputs 2048 channels (hard-coded below).
    """

    def __init__(
        self,
        cfg,
        trans_names,
        kernel_size,
        use_feature_mask,
    ):
        super(TransformerHead, self).__init__()
        d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL
        self.use_feature_mask = use_feature_mask

        # Both mask helpers are configured with the same (shape, size, mode).
        mask_args = (cfg.MODEL.MASK_SHAPE, cfg.MODEL.MASK_SIZE, cfg.MODEL.MASK_MODE)
        self.bypass_mask = exchange_patch(*mask_args)
        self.get_mask_box = get_mask_box(*mask_args)

        self.transformer_encoder = Transformers(
            cfg=cfg,
            trans_names=trans_names,
            kernel_size=kernel_size,
            use_feature_mask=use_feature_mask,
        )
        self.conv0 = conv1x1(1024, 1024)
        self.conv1 = conv1x1(1024, d_model)
        self.conv2 = conv1x1(d_model, 2048)

    def forward(self, box_features):
        """Return ``{"before_trans": ..., "after_trans": ...}`` pooled features."""
        mask_box = self.get_mask_box(box_features)

        # Skip path: optionally projected by conv0, and patch-exchanged only
        # while training.
        skip_features = self.conv0(box_features) if self.use_feature_mask else box_features
        if self.use_feature_mask and self.training:
            skip_features = self.bypass_mask(skip_features)
        pooled_skip = F.adaptive_max_pool2d(skip_features, 1)

        # Transformer path.
        encoded = self.conv2(self.transformer_encoder((self.conv1(box_features), mask_box)))

        return {
            "before_trans": pooled_skip,
            "after_trans": F.adaptive_max_pool2d(encoded, 1),
        }
class Transformers(nn.Module):
    """Multi-scale transformer encoder.

    The input map is split channel-wise into ``len(trans_names)`` equal
    chunks. Chunk ``i`` is encoded by the ``Transformer`` branch
    ``self.blocks[trans_names[i]]``, which is built with ``kernel_size[i]``.
    The branch outputs are concatenated back along the channel dimension.

    Args:
        cfg: Config node; reads ``MODEL.TRANSFORMER.*``,
            ``MODEL.FEATURE_AUG_TYPE`` and ``MODEL.MASK_*``.
        trans_names: One branch name per scale. Replaced by ``['scale1']``
            (with kernel ``(1, 1)``) when ``MODEL.TRANSFORMER.USE_PATCH2VEC``
            is false.
        kernel_size: Patch kernel per branch; each must be ``(1, 1)`` or
            ``(3, 3)``.
        use_feature_mask: Enables the feature augmentation selected by
            ``MODEL.FEATURE_AUG_TYPE``.

    Raises:
        ValueError: If ``trans_names`` is empty, or fewer kernel sizes than
            branch names are given, or ``DIM_MODEL`` is not divisible by the
            number of branches, or a kernel size is unsupported.
    """

    def __init__(
        self,
        cfg,
        trans_names,
        kernel_size,
        use_feature_mask,
    ):
        super(Transformers, self).__init__()
        d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL
        self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE
        self.use_feature_mask = use_feature_mask
        # If no conv before transformer, we do not use scales
        if not cfg.MODEL.TRANSFORMER.USE_PATCH2VEC:
            trans_names = ['scale1']
            kernel_size = [(1, 1)]
        self.trans_names = trans_names
        self.scale_size = len(self.trans_names)

        # Fail fast on configs that would otherwise break later:
        # - zip() below would silently skip branches that have no kernel size,
        #   which surfaces as a KeyError in forward();
        # - an uneven torch.chunk() in forward() would hand branches a channel
        #   count that does not match their patch2vec conv.
        if self.scale_size == 0:
            raise ValueError('trans_names must not be empty.')
        if len(kernel_size) < self.scale_size:
            raise ValueError(
                'Expected at least {} kernel sizes (one per branch), got {}.'.format(
                    self.scale_size, len(kernel_size)))
        if d_model % self.scale_size != 0:
            raise ValueError(
                'DIM_MODEL ({}) must be divisible by the number of branches ({}).'.format(
                    d_model, self.scale_size))

        hidden = d_model // (2 * self.scale_size)

        # Map each kernel size to [padding, stride] for the patch2vec conv and
        # the vec2patch fold.
        kernels = {
            (1, 1): [(0, 0), (1, 1)],
            (3, 3): [(1, 1), (1, 1)],
        }
        padding = []
        stride = []
        for ksize in kernel_size:
            if ksize not in [(1, 1), (3, 3)]:
                raise ValueError('Undefined kernel size.')
            padding.append(kernels[ksize][0])
            stride.append(kernels[ksize][1])

        self.use_output_layer = cfg.MODEL.TRANSFORMER.USE_OUTPUT_LAYER
        self.use_global_shortcut = cfg.MODEL.TRANSFORMER.USE_GLOBAL_SHORTCUT
        self.blocks = nn.ModuleDict()
        for tname, ksize, psize, ssize in zip(self.trans_names, kernel_size, padding, stride):
            transblock = Transformer(
                cfg, d_model // self.scale_size, ksize, psize, ssize, hidden, use_feature_mask
            )
            self.blocks[tname] = nn.Sequential(transblock)
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.mask_para = [cfg.MODEL.MASK_SHAPE, cfg.MODEL.MASK_SIZE, cfg.MODEL.MASK_MODE]

    def forward(self, inputs):
        """Encode ``inputs = (enc_feat, mask_box)``.

        ``enc_feat`` is split into ``len(self.trans_names)`` chunks along
        dim 1. ``mask_box`` is passed unchanged to every branch.

        Returns:
            The branch outputs concatenated on dim 1. They are optionally
            passed through ``output_linear`` and/or added to ``enc_feat``.
        """
        trans_feat = []
        enc_feat, mask_box = inputs
        if self.training and self.use_feature_mask and self.feature_aug_type == 'exchange_patch':
            feature_mask = exchange_patch(self.mask_para[0], self.mask_para[1], self.mask_para[2])
            enc_feat = feature_mask(enc_feat)
        for tname, feat in zip(self.trans_names, torch.chunk(enc_feat, len(self.trans_names), dim=1)):
            feat = self.blocks[tname]((feat, mask_box))
            trans_feat.append(feat)
        trans_feat = torch.cat(trans_feat, 1)
        if self.use_output_layer:
            trans_feat = self.output_linear(trans_feat)
        if self.use_global_shortcut:
            # Uses enc_feat after the optional exchange_patch augmentation above.
            trans_feat = enc_feat + trans_feat
        return trans_feat
class Transformer(nn.Module):
    """Single-scale transformer branch.

    Pipeline:

    1. ``patch2vec`` (a Conv2d) turns the feature map into ``hidden``-dim
       tokens.
    2. Optional token-level augmentation is applied.
    3. The tokens pass through ``ENCODER_LAYERS`` ``TransformerBlock`` layers.
    4. ``vec2patch`` folds them back to a ``channel``-channel map with a fixed
       ``(14, 14)`` output size.

    Args:
        cfg: Config node; reads ``MODEL.TRANSFORMER.*`` and
            ``MODEL.FEATURE_AUG_TYPE``.
        channel: Input/output channels of this branch.
        kernel_size, padding, stride: Patch geometry shared by ``patch2vec``
            and ``vec2patch``.
        hidden: Token embedding size.
        use_feature_mask: Enables the augmentation selected by
            ``MODEL.FEATURE_AUG_TYPE``.
    """

    def __init__(self, cfg, channel, kernel_size, padding, stride, hidden, use_feature_mask):
        super(Transformer, self).__init__()
        self.k = kernel_size[0]
        stack_num = cfg.MODEL.TRANSFORMER.ENCODER_LAYERS
        num_head = cfg.MODEL.TRANSFORMER.N_HEAD
        dropout = cfg.MODEL.TRANSFORMER.DROPOUT
        # vec2patch folds to this fixed size, so inputs are expected to
        # produce a matching token grid (e.g. 14x14 maps).
        output_size = (14, 14)
        token_size = tuple(x // y for x, y in zip(output_size, stride))

        def make_block():
            return TransformerBlock(token_size, hidden=hidden, num_head=num_head, dropout=dropout)

        # Each encoder layer gets its own parameters. Appending a single
        # instance repeatedly would tie the weights of every layer.
        # `transblock` aliases the first layer, so the state_dict keys
        # ("transblock.*", "transformer.<i>.*") are the same as before.
        self.transblock = make_block()
        blocks = []
        if stack_num > 0:
            blocks.append(self.transblock)
            blocks.extend(make_block() for _ in range(stack_num - 1))
        self.transformer = nn.Sequential(*blocks)
        self.patch2vec = nn.Conv2d(channel, hidden, kernel_size=kernel_size, stride=stride, padding=padding)
        self.vec2patch = Vec2Patch(channel, hidden, output_size, kernel_size, stride, padding)
        self.use_local_shortcut = cfg.MODEL.TRANSFORMER.USE_LOCAL_SHORTCUT
        self.use_feature_mask = use_feature_mask
        self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE
        self.use_patch2vec = cfg.MODEL.TRANSFORMER.USE_PATCH2VEC

    def forward(self, inputs):
        """Encode ``inputs = (enc_feat, mask_box)``.

        ``enc_feat`` must be rank 4; the unpacking below enforces it.
        ``mask_box`` is only consumed by the ``'exchange_token'`` augmentation.

        Returns:
            The ``vec2patch`` output, plus ``enc_feat`` when the local
            shortcut is enabled.
        """
        enc_feat, mask_box = inputs
        b, _, _, _ = enc_feat.size()
        trans_feat = self.patch2vec(enc_feat)
        _, c, _, _ = trans_feat.size()
        # (b, hidden, h', w') -> (b, h'*w', hidden): one token per position.
        trans_feat = trans_feat.view(b, c, -1).permute(0, 2, 1)
        # Training-only token augmentations (for 1x1 and 3x3 kernels alike).
        if self.training and self.use_feature_mask:
            if self.feature_aug_type == 'exchange_token':
                feature_mask = exchange_token()
                trans_feat = feature_mask(trans_feat, mask_box)
            elif self.feature_aug_type == 'cutout_patch':
                feature_mask = cutout_patch()
                trans_feat = feature_mask(trans_feat)
            elif self.feature_aug_type == 'erase_patch':
                feature_mask = erase_patch()
                trans_feat = feature_mask(trans_feat)
            elif self.feature_aug_type == 'mixup_patch':
                feature_mask = mixup_patch()
                trans_feat = feature_mask(trans_feat)
        # NOTE(review): the jigsaw augmentations are not gated on
        # self.training, so they also run at inference. This is presumably
        # intentional (a deterministic shuffle); confirm against utils.mask.
        if self.use_feature_mask:
            if self.feature_aug_type == 'jigsaw_patch':
                feature_mask = jigsaw_patch()
                trans_feat = feature_mask(trans_feat)
            elif self.feature_aug_type == 'jigsaw_token':
                feature_mask = jigsaw_token()
                trans_feat = feature_mask(trans_feat)
        trans_feat = self.transformer(trans_feat)
        trans_feat = self.vec2patch(trans_feat)
        if self.use_local_shortcut:
            trans_feat = enc_feat + trans_feat
        return trans_feat
class TransformerBlock(nn.Module):
    """
    Transformer layer: multi-head self-attention followed by a feed-forward
    network, each added back through a residual connection.

    The input is normalized by ``norm1`` *before* the first residual branch.
    The first residual is therefore taken around the normalized tokens, not
    the raw input.
    """

    def __init__(self, tokensize, hidden=128, num_head=4, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadedAttention(tokensize, d_model=hidden, head=num_head, p=dropout)
        self.ffn = FeedForward(hidden, p=dropout)
        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        normed = self.norm1(x)
        after_attn = normed + self.dropout(self.attention(normed))
        return after_attn + self.ffn(self.norm2(after_attn))
class Attention(nn.Module):
    """
    Scaled dot-product attention: ``softmax(Q K^T / sqrt(d)) V``.

    ``d`` is the size of the last dimension of ``query``. ``forward`` returns
    ``(output, weights)``. Dropout is applied to the attention weights before
    they multiply ``value``, and the returned weights are the post-dropout
    ones.
    """

    def __init__(self, p=0.1):
        super(Attention, self).__init__()
        self.dropout = nn.Dropout(p=p)

    def forward(self, query, key, value):
        scale = math.sqrt(query.size(-1))
        logits = query @ key.transpose(-2, -1) / scale
        weights = self.dropout(logits.softmax(dim=-1))
        return weights @ value, weights
class Vec2Patch(nn.Module):
    """Project token vectors back to a spatial feature map.

    Each ``hidden``-dim token is linearly mapped to a flattened
    ``channel * prod(kernel_size)`` patch. ``nn.Fold`` then reassembles the
    patches into a ``(channel, *output_size)`` map, summing overlapping
    contributions.

    Args:
        channel: Output channels of the reconstructed map.
        hidden: Token embedding size.
        output_size: Spatial size of the folded map (int or ``(h, w)``).
        kernel_size: Patch kernel, as a tuple of ints.
        stride, padding: Fold geometry; should match the conv that produced
            the tokens.
    """

    def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
        super(Vec2Patch, self).__init__()
        # Not used in forward; kept so the module's attribute set is unchanged.
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        c_out = reduce((lambda x, y: x * y), kernel_size) * channel
        self.embedding = nn.Linear(hidden, c_out)
        self.to_patch = torch.nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        # (b, n_tokens, hidden) -> (b, n_tokens, c_out) -> (b, c_out, n_tokens),
        # the (N, C*prod(kernel), L) layout that nn.Fold expects.
        feat = self.embedding(x).permute(0, 2, 1)
        return self.to_patch(feat)
class MultiHeadedAttention(nn.Module):
    """
    Multi-head self-attention over a ``(batch, tokens, d_model)`` sequence.

    Args:
        tokensize: ``(h, w)`` token grid; stored as ``self.h``/``self.w`` but
            not used by ``forward``.
        d_model: Token embedding size, split evenly across ``head`` heads
            (``forward`` requires it to be divisible by ``head``).
        head: Number of attention heads.
        p: Dropout probability applied to the attention weights.
    """

    def __init__(self, tokensize, d_model, head, p=0.1):
        super().__init__()
        self.query_embedding = nn.Linear(d_model, d_model)
        self.value_embedding = nn.Linear(d_model, d_model)
        self.key_embedding = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention(p=p)
        self.head = head
        self.h, self.w = tokensize

    def forward(self, x):
        batch, tokens, dim = x.size()
        head_dim = dim // self.head

        def split_heads(t):
            # (b, n, c) -> (b, head, n, c // head)
            return t.view(batch, tokens, self.head, head_dim).permute(0, 2, 1, 3)

        key = split_heads(self.key_embedding(x))
        query = split_heads(self.query_embedding(x))
        value = split_heads(self.value_embedding(x))
        context, _ = self.attention(query, key, value)
        # (b, head, n, c // head) -> (b, n, c)
        merged = context.permute(0, 2, 1, 3).contiguous().view(batch, tokens, dim)
        return self.output_linear(merged)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d, 4d) -> ReLU -> Dropout -> Linear(4d, d) -> Dropout."""

    def __init__(self, d_model, p=0.1):
        super(FeedForward, self).__init__()
        expanded = 4 * d_model
        layers = [
            nn.Linear(d_model, expanded),
            nn.ReLU(inplace=True),
            nn.Dropout(p=p),
            nn.Linear(expanded, d_model),
            nn.Dropout(p=p),
        ]
        # The attribute name `conv` fixes the state_dict keys (conv.0.*, conv.3.*).
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)