326 lines
12 KiB
Python
326 lines
12 KiB
Python
# This file is part of COAT, and is distributed under the
|
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
|
|
|
import random
|
|
import torch
|
|
|
|
class exchange_token:
    """Swap a rectangular region of token features between batch samples.

    Token features of shape ``(b, hw, c)`` are laid out as a square
    ``side x side`` grid (``side * side == hw``; 14 x 14 for 196 tokens) and,
    for every sample ``i``, the region selected by ``mask_box`` is overwritten
    with the same region of sample ``new_idx[i]``.
    """

    def __init__(self):
        pass

    def __call__(self, features, mask_box):
        """Exchange the masked token region across the batch.

        Args:
            features: tensor of shape ``(b, hw, c)`` where ``hw`` is a perfect
                square.
            mask_box: tuple ``(new_idx, x1, x2, y1, y2)``; ``new_idx`` indexes
                the batch dimension and ``[x1:x2, y1:y2]`` selects the region
                on the token grid (rows, then columns).

        Returns:
            Tensor of shape ``(b, hw, c)``. When ``features`` is contiguous the
            exchange is done in place and the result shares its storage.

        Raises:
            ValueError: if ``hw`` is not a perfect square.
        """
        b, hw, c = features.size()
        side = int(round(hw ** 0.5))
        if side * side != hw:
            raise ValueError(f"exchange_token expects a square token grid, got {hw} tokens")

        new_idx, mask_x1, mask_x2, mask_y1, mask_y2 = mask_box
        # reshape (not view) also accepts non-contiguous input; for contiguous
        # input it is still a view, so the update below stays in place.
        grid = features.reshape(b, side, side, c)
        # Advanced indexing on the right-hand side materializes a copy, so the
        # in-place write never reads values it has already overwritten.
        grid[:, mask_x1:mask_x2, mask_y1:mask_y2, :] = grid[new_idx, mask_x1:mask_x2, mask_y1:mask_y2, :]
        return grid.reshape(b, hw, c)
|
|
|
|
class jigsaw_token:
    """Shift-and-shuffle ("jigsaw") rearrangement of a token sequence.

    The sequence is rotated left by ``shift`` positions, split into ``group``
    equal chunks, and the chunks are interleaved (read column-wise).

    Args:
        shift: number of positions the sequence is rotated left.
        group: number of shuffle chunks. ``__call__`` overwrites this from the
            token count (2 for 196 tokens, 5 for 25 tokens).
        begin: 1-based position where the shift starts; tokens before index
            ``begin - 1`` are dropped from the output (none with the default 1).
    """

    def __init__(self, shift=5, group=2, begin=1):
        self.shift = shift
        self.group = group
        self.begin = begin

    def __call__(self, features):
        """Shift and shuffle ``features`` along dim 1.

        Args:
            features: tensor indexed as ``(batch, tokens, dim)``; the token
                count must be 196 or 25.

        Returns:
            Tensor of shape ``(batch, kept_tokens, dim)``.

        Raises:
            Exception: for an unsupported token count, or when the shifted
                sequence cannot be split into ``group`` equal chunks.
        """
        batchsize = features.size(0)
        num_tokens = features.size(1)
        dim = features.size(2)

        # The group count is tied to the token count; the stored attribute is
        # updated so it reflects the value actually used.
        if num_tokens == 196:
            self.group = 2
        elif num_tokens == 25:
            self.group = 5
        else:
            raise Exception("Jigsaw - Unwanted number of tokens")

        # Shift operation: rotate the tokens from index begin - 1 onward.
        start = self.begin - 1
        x = torch.cat(
            [features[:, start + self.shift:, :], features[:, start:start + self.shift, :]],
            dim=1,
        )

        # Patch shuffle operation: only the reshape can fail (token count not
        # divisible by group), so catch just that and keep the root cause.
        try:
            x = x.view(batchsize, self.group, -1, dim)
        except RuntimeError as err:
            raise Exception("Jigsaw - Unwanted number of groups") from err

        x = torch.transpose(x, 1, 2).contiguous()
        return x.view(batchsize, -1, dim)
|
|
|
|
class get_mask_box:
    """Randomly choose a mask region and a batch permutation.

    Calling an instance returns ``(new_idx, mask_x1, mask_x2, mask_y1, mask_y2)``
    where ``new_idx`` is a random permutation of the batch indices and the
    region is ``[mask_x1:mask_x2]`` along dim 2 and ``[mask_y1:mask_y2]`` along
    dim 3 of ``features``.

    Args:
        shape: 'stripe', 'square' or 'random'.
        mask_size: stripe width or square side (ignored by the 'random' shape
            and replaced per call by 'random_size' square mode).
        mode: for 'stripe', one of 'horizontal', 'vertical', or
            'random_direct' / 'random_direction' (either spelling picks a
            direction at random); for 'square', 'random_size' picks a side of
            4 or 5 per call.
    """

    # Both spellings select a random stripe direction; the constructor's
    # default has always been 'random_direct'.
    _RANDOM_DIRECTION_MODES = ('random_direct', 'random_direction')

    def __init__(self, shape='stripe', mask_size=2, mode='random_direct'):
        self.shape = shape
        self.mask_size = mask_size
        self.mode = mode

    def __call__(self, features):
        """Return a mask box for ``features`` (indexed as ``(B, C, H, W)``).

        Raises:
            Exception: for an unknown ``shape``, or an unknown ``mode`` when
                ``shape == 'stripe'``.
        """
        # Stripe mask
        if self.shape == 'stripe':
            if self.mode == 'horizontal':
                return self.hstripe(features, self.mask_size)
            if self.mode == 'vertical':
                return self.vstripe(features, self.mask_size)
            if self.mode in self._RANDOM_DIRECTION_MODES:
                if random.random() < 0.5:
                    return self.hstripe(features, self.mask_size)
                return self.vstripe(features, self.mask_size)
            raise Exception("Unknown stripe mask mode name")

        # Square mask; a local size keeps 'random_size' from overwriting the
        # configured mask_size.
        if self.shape == 'square':
            mask_size = self.mask_size
            if self.mode == 'random_size':
                mask_size = 4 if random.random() < 0.5 else 5
            return self.square(features, mask_size)

        # Random stripe/square mask
        if self.shape == 'random':
            random_num = random.random()
            if random_num < 0.25:
                return self.hstripe(features, 2)
            if random_num < 0.5:
                return self.vstripe(features, 2)
            if random_num < 0.75:
                return self.square(features, 4)
            return self.square(features, 5)

        raise Exception("Unknown mask shape name")

    @staticmethod
    def _random_start(limit):
        """Draw a start index uniformly from ``[0, limit]`` as a 1-element tensor.

        ``torch.randint(high, ...)`` excludes ``high``, so ``limit + 1`` is
        used to make the last valid position reachable. A negative limit (mask
        larger than the grid) is clamped so the mask starts at 0.
        """
        return torch.randint(max(limit, 0) + 1, (1,))

    def hstripe(self, features, mask_size):
        """Stripe of ``mask_size`` indices along dim 3 spanning all of dim 2."""
        mask_x1 = 0
        mask_x2 = features.shape[2]
        mask_y1 = self._random_start(features.shape[3] - mask_size)
        mask_y2 = mask_y1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        return (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)

    def vstripe(self, features, mask_size):
        """Stripe of ``mask_size`` indices along dim 2 spanning all of dim 3."""
        mask_y1 = 0
        mask_y2 = features.shape[3]
        mask_x1 = self._random_start(features.shape[2] - mask_size)
        mask_x2 = mask_x1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        return (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)

    def square(self, features, mask_size):
        """Square of side ``mask_size`` at a random position over dims 2 and 3."""
        mask_x1 = self._random_start(features.shape[2] - mask_size)
        mask_y1 = self._random_start(features.shape[3] - mask_size)
        mask_x2 = mask_x1 + mask_size
        mask_y2 = mask_y1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        return (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)
|
|
|
|
|
|
class exchange_patch:
    """Swap a random region of a feature map between batch samples, in place.

    Args:
        shape: 'stripe', 'square' or 'random'.
        mask_size: stripe width or square side (ignored by the 'random' shape
            and replaced per call by 'random_size' square mode).
        mode: for 'stripe', one of 'horizontal', 'vertical', or
            'random_direct' / 'random_direction' (either spelling picks a
            direction at random); for 'square', 'random_size' picks a side of
            4 or 5 per call.
    """

    # Both spellings select a random stripe direction; the constructor's
    # default has always been 'random_direct'.
    _RANDOM_DIRECTION_MODES = ('random_direct', 'random_direction')

    def __init__(self, shape='stripe', mask_size=2, mode='random_direct'):
        self.shape = shape
        self.mask_size = mask_size
        self.mode = mode

    def __call__(self, features):
        """Apply the configured exchange to ``features`` (indexed ``(B, C, H, W)``).

        Returns:
            ``features`` itself, modified in place.

        Raises:
            Exception: for an unknown ``shape``, or an unknown ``mode`` when
                ``shape == 'stripe'``.
        """
        # Stripe mask
        if self.shape == 'stripe':
            if self.mode == 'horizontal':
                return self.xpatch_hstripe(features, self.mask_size)
            if self.mode == 'vertical':
                return self.xpatch_vstripe(features, self.mask_size)
            if self.mode in self._RANDOM_DIRECTION_MODES:
                if random.random() < 0.5:
                    return self.xpatch_hstripe(features, self.mask_size)
                return self.xpatch_vstripe(features, self.mask_size)
            raise Exception("Unknown stripe mask mode name")

        # Square mask; a local size keeps 'random_size' from overwriting the
        # configured mask_size.
        if self.shape == 'square':
            mask_size = self.mask_size
            if self.mode == 'random_size':
                mask_size = 4 if random.random() < 0.5 else 5
            return self.xpatch_square(features, mask_size)

        # Random stripe/square mask
        if self.shape == 'random':
            random_num = random.random()
            if random_num < 0.25:
                return self.xpatch_hstripe(features, 2)
            if random_num < 0.5:
                return self.xpatch_vstripe(features, 2)
            if random_num < 0.75:
                return self.xpatch_square(features, 4)
            return self.xpatch_square(features, 5)

        raise Exception("Unknown mask shape name")

    @staticmethod
    def _random_start(size, mask_size):
        """Draw a start index uniformly from ``[0, max(size - mask_size, 0)]``.

        ``torch.randint(high, ...)`` excludes ``high``; adding 1 makes the last
        position where the mask still fits reachable.
        """
        return int(torch.randint(max(size - mask_size, 0) + 1, (1,)))

    def xpatch_hstripe(self, features, mask_size):
        """Swap a band of ``mask_size`` indices along dim 3 across the batch."""
        y1 = self._random_start(features.shape[3], mask_size)
        y2 = y1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        # Advanced indexing on the right-hand side copies before the write.
        features[:, :, :, y1:y2] = features[new_idx, :, :, y1:y2]
        return features

    def xpatch_vstripe(self, features, mask_size):
        """Swap a band of ``mask_size`` indices along dim 2 across the batch."""
        x1 = self._random_start(features.shape[2], mask_size)
        x2 = x1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        features[:, :, x1:x2, :] = features[new_idx, :, x1:x2, :]
        return features

    def xpatch_square(self, features, mask_size):
        """Swap a ``mask_size`` x ``mask_size`` square (dims 2, 3) across the batch."""
        x1 = self._random_start(features.shape[2], mask_size)
        y1 = self._random_start(features.shape[3], mask_size)
        x2 = x1 + mask_size
        y2 = y1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        features[:, :, x1:x2, y1:y2] = features[new_idx, :, x1:x2, y1:y2]
        return features
|
|
|
|
|
|
class cutout_patch:
    """Zero out a random band of ``mask_size`` rows or columns in each sample.

    One direction is chosen per call (50/50); each sample then gets its own
    random band position.
    """

    def __init__(self, mask_size=2):
        self.mask_size = mask_size

    def __call__(self, features):
        """Apply cutout to ``features`` (indexed ``(B, C, H, W)``), in place.

        Returns:
            ``features`` itself, modified in place.
        """
        batch = features.shape[0]
        if random.random() < 0.5:
            # Column band along dim 3 (W).
            width = features.shape[3]
            band = min(self.mask_size, width)
            # randint's high bound is exclusive: +1 makes the last fit reachable.
            starts = torch.randint(width - band + 1, (batch,))
            for k, start in enumerate(starts.tolist()):
                features[k, :, :, start:start + band] = 0
        else:
            # Row band along dim 2; the start range must come from H (dim 2),
            # otherwise non-square maps get truncated or empty bands.
            height = features.shape[2]
            band = min(self.mask_size, height)
            starts = torch.randint(height - band + 1, (batch,))
            for k, start in enumerate(starts.tolist()):
                features[k, :, start:start + band, :] = 0
        return features
|
|
|
|
|
|
class erase_patch:
    """Overwrite a random band of each sample with Gaussian noise.

    The noise uses the mean and standard deviation of the whole (detached)
    input. One direction (rows or columns) is chosen per call; each sample gets
    its own band position.
    """

    def __init__(self, mask_size=2):
        self.mask_size = mask_size

    def __call__(self, features):
        """Erase one band per sample of ``features`` (indexed ``(B, C, H, W)``).

        Returns:
            ``features`` itself, modified in place.
        """
        std, mean = torch.std_mean(features.detach())
        batch = features.shape[0]
        dim = features.shape[1]
        height = features.shape[2]
        width = features.shape[3]

        def noise(*shape):
            # Sized from the real band and map extents, not a fixed 14 x 2.
            return torch.normal(mean.repeat(*shape), std.repeat(*shape))

        if random.random() < 0.5:
            # Column band along dim 3 (W).
            band = min(self.mask_size, width)
            # randint's high bound is exclusive: +1 makes the last fit reachable.
            starts = torch.randint(width - band + 1, (batch,))
            for k, start in enumerate(starts.tolist()):
                features[k, :, :, start:start + band] = noise(dim, height, band)
        else:
            # Row band along dim 2; the start range comes from H (dim 2).
            band = min(self.mask_size, height)
            starts = torch.randint(height - band + 1, (batch,))
            for k, start in enumerate(starts.tolist()):
                features[k, :, start:start + band, :] = noise(dim, band, width)
        return features
|
|
|
|
class mixup_patch:
    """Blend a random band of the feature map with the same band of another sample.

    A single mixing weight ``lam ~ U(0, 1)``, a single band position, and a
    single batch permutation are drawn per call.
    """

    def __init__(self, mask_size=2):
        self.mask_size = mask_size

    def __call__(self, features):
        """Mix one band of ``features`` (indexed ``(B, C, H, W)``) across the batch.

        The band becomes ``lam * own + (1 - lam) * permuted``.

        Returns:
            ``features`` itself, modified in place.
        """
        lam = random.uniform(0, 1)
        if random.random() < 0.5:
            # Column band along dim 3 (W).
            width = features.shape[3]
            band = min(self.mask_size, width)
            # randint's high bound is exclusive: +1 makes the last fit reachable.
            y1 = int(torch.randint(width - band + 1, (1,)))
            y2 = y1 + band
            new_idx = torch.randperm(features.shape[0])
            features[:, :, :, y1:y2] = lam * features[:, :, :, y1:y2] + (1 - lam) * features[new_idx, :, :, y1:y2]
        else:
            # Row band along dim 2 (H).
            height = features.shape[2]
            band = min(self.mask_size, height)
            x1 = int(torch.randint(height - band + 1, (1,)))
            x2 = x1 + band
            new_idx = torch.randperm(features.shape[0])
            features[:, :, x1:x2, :] = lam * features[:, :, x1:x2, :] + (1 - lam) * features[new_idx, :, x1:x2, :]
        return features
|
|
|
|
|
|
class jigsaw_patch:
    """Shift-and-shuffle ("jigsaw") rearrangement of a spatial feature map.

    The flattened spatial positions of each channel are rotated left by
    ``shift``, split into ``group`` equal chunks, interleaved, and folded back
    into the spatial grid.
    """

    def __init__(self, shift=5, group=2):
        self.shift = shift
        self.group = group

    def __call__(self, features):
        """Shift and shuffle the spatial positions of ``features``.

        Args:
            features: tensor indexed as ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, C, H, W)``. Inputs that are not 4-D are
            folded into a 14 x 14 grid.

        Raises:
            RuntimeError: if ``H * W`` is not divisible by ``group``.
        """
        batchsize = features.size(0)
        dim = features.size(1)
        # Restore the input's own spatial layout instead of a fixed 14 x 14.
        out_hw = tuple(features.shape[2:]) if features.dim() == 4 else (14, 14)
        flat = features.reshape(batchsize, dim, -1)

        # Shift operation: rotate the flattened positions left by `shift`.
        x = torch.cat([flat[:, :, self.shift:], flat[:, :, :self.shift]], dim=2)

        # Patch shuffle operation: split into `group` chunks and interleave.
        try:
            x = x.view(batchsize, dim, self.group, -1)
        except RuntimeError as err:
            raise RuntimeError("Jigsaw - Unwanted number of groups") from err
        x = torch.transpose(x, 2, 3).contiguous()

        return x.view(batchsize, dim, *out_hw)
|
|
|