COAT/main/utils/mask.py

326 lines
12 KiB
Python

# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import random
import torch
class exchange_token:
    """Swap a rectangular region of tokens between samples of a batch.

    The token sequence is interpreted as a 14 x 14 grid. Inside the box
    described by ``mask_box`` every sample receives the tokens of the
    sample selected for it by the box's batch index.
    """

    def __init__(self):
        pass

    def __call__(self, features, mask_box):
        """Apply the exchange and return the token tensor.

        Args:
            features: tensor of size ``(b, 196, c)``.
            mask_box: tuple ``(new_idx, mask_x1, mask_x2, mask_y1, mask_y2)``;
                ``new_idx`` indexes the batch dimension, the x bounds slice
                the first grid axis and the y bounds the second one.

        Returns:
            Tensor of size ``(b, 196, c)``. It is a view of ``features``,
            so the caller's tensor is modified in place.
        """
        batch, num_tokens, channels = features.size()
        assert num_tokens == 196
        perm, x_lo, x_hi, y_lo, y_hi = mask_box
        grid = features.view(batch, 14, 14, channels)
        # Indexing with `perm` yields a copy, so the write below never reads
        # values it has already overwritten.
        donor = grid[perm, x_lo:x_hi, y_lo:y_hi, :]
        grid[:, x_lo:x_hi, y_lo:y_hi, :] = donor
        return grid.view(batch, num_tokens, channels)
class jigsaw_token:
    """Shift-and-shuffle ("jigsaw") rearrangement of a token sequence.

    The sequence is first rotated by ``shift`` positions (starting at token
    ``begin - 1``), then split into ``group`` contiguous chunks whose tokens
    are interleaved.
    """

    def __init__(self, shift=5, group=2, begin=1):
        self.shift = shift
        self.group = group
        self.begin = begin

    def __call__(self, features):
        """Return the rearranged tokens.

        Args:
            features: tensor indexed as ``(batch, tokens, dim)``; only 196
                or 25 tokens are accepted.

        Returns:
            Tensor of size ``(batch, -1, dim)`` holding the shifted and
            interleaved tokens.

        Raises:
            Exception: if the token count is neither 196 nor 25, or if the
                shifted sequence cannot be split into ``self.group`` chunks.
        """
        batchsize = features.size(0)
        num_tokens = features.size(1)
        dim = features.size(2)

        # The group count is dictated by the token count; this overrides
        # whatever `group` was passed to the constructor.
        if num_tokens == 196:
            self.group = 2
        elif num_tokens == 25:
            self.group = 5
        else:
            raise Exception("Jigsaw - Unwanted number of tokens")

        # Shift operation: tokens [start, split) move behind the remainder.
        # Tokens before `start` are dropped when begin > 1.
        start = self.begin - 1
        split = start + self.shift
        x = torch.cat([features[:, split:, :], features[:, start:split, :]], dim=1)

        # Patch shuffle operation.
        try:
            x = x.view(batchsize, self.group, -1, dim)
        except RuntimeError as err:
            # Only the reshape failure is translated; interrupts and other
            # unrelated errors propagate untouched.
            raise Exception("Jigsaw - Unwanted number of groups") from err
        x = torch.transpose(x, 1, 2).contiguous()
        return x.view(batchsize, -1, dim)
class get_mask_box:
    """Sample a random rectangular mask box plus a batch permutation.

    Every box is returned as ``(new_idx, mask_x1, mask_x2, mask_y1, mask_y2)``
    where ``new_idx`` is a random permutation of the batch indices, the x
    bounds refer to dimension 2 of ``features`` and the y bounds to
    dimension 3. Randomly drawn bounds are one-element integer tensors;
    fixed bounds are plain ints.

    Args:
        shape: ``'stripe'``, ``'square'`` or ``'random'``.
        mask_size: extent of the masked band / square side.
        mode: for ``'stripe'``: ``'horizontal'``, ``'vertical'`` or
            ``'random_direction'`` (the default spelling ``'random_direct'``
            is accepted as an alias). For ``'square'``: ``'random_size'``
            picks a side of 4 or 5 on every call.
    """

    # The constructor default is spelled 'random_direct'; treat it the same
    # as 'random_direction' so a default-constructed instance is usable.
    _RANDOM_DIRECTION_MODES = ('random_direction', 'random_direct')

    def __init__(self, shape='stripe', mask_size=2, mode='random_direct'):
        self.shape = shape
        self.mask_size = mask_size
        self.mode = mode

    def __call__(self, features):
        """Return a mask box for ``features``.

        Args:
            features: tensor with at least four dimensions; only its sizes
                along dimensions 0, 2 and 3 are read (presumably a
                ``(B, C, H, W)`` feature map -- confirm against the caller).

        Raises:
            Exception: on an unknown ``shape`` or an unknown stripe ``mode``.
        """
        # Stripe mask
        if self.shape == 'stripe':
            if self.mode == 'horizontal':
                mask_box = self.hstripe(features, self.mask_size)
            elif self.mode == 'vertical':
                mask_box = self.vstripe(features, self.mask_size)
            elif self.mode in self._RANDOM_DIRECTION_MODES:
                if random.random() < 0.5:
                    mask_box = self.hstripe(features, self.mask_size)
                else:
                    mask_box = self.vstripe(features, self.mask_size)
            else:
                raise Exception("Unknown stripe mask mode name")
        # Square mask
        elif self.shape == 'square':
            if self.mode == 'random_size':
                # Note: the sampled size is stored on the instance.
                self.mask_size = 4 if random.random() < 0.5 else 5
            mask_box = self.square(features, self.mask_size)
        # Random stripe/square mask, each of the four options equally likely
        elif self.shape == 'random':
            random_num = random.random()
            if random_num < 0.25:
                mask_box = self.hstripe(features, 2)
            elif random_num < 0.5:
                mask_box = self.vstripe(features, 2)
            elif random_num < 0.75:
                mask_box = self.square(features, 4)
            else:
                mask_box = self.square(features, 5)
        else:
            raise Exception("Unknown mask shape name")
        return mask_box

    def hstripe(self, features, mask_size):
        """Box spanning all of dimension 2 and ``mask_size`` indices of dimension 3."""
        # horizontal stripe
        mask_x1 = 0
        mask_x2 = features.shape[2]
        y1_max = features.shape[3] - mask_size
        # NOTE(review): randint's upper bound is exclusive, so the start
        # `y1_max` itself is never drawn -- confirm this is intended.
        mask_y1 = torch.randint(y1_max, (1,))
        mask_y2 = mask_y1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        return (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)

    def vstripe(self, features, mask_size):
        """Box spanning all of dimension 3 and ``mask_size`` indices of dimension 2."""
        # vertical stripe
        mask_y1 = 0
        mask_y2 = features.shape[3]
        x1_max = features.shape[2] - mask_size
        # Exclusive upper bound, see the note in `hstripe`.
        mask_x1 = torch.randint(x1_max, (1,))
        mask_x2 = mask_x1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        return (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)

    def square(self, features, mask_size):
        """Box of ``mask_size`` x ``mask_size`` at a random position in dimensions 2 and 3."""
        # square
        x1_max = features.shape[2] - mask_size
        y1_max = features.shape[3] - mask_size
        # Draw x before y to keep the random stream order stable.
        mask_x1 = torch.randint(x1_max, (1,))
        mask_y1 = torch.randint(y1_max, (1,))
        mask_x2 = mask_x1 + mask_size
        mask_y2 = mask_y1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        return (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)
class exchange_patch:
    """Exchange a random stripe or square of a feature map between samples.

    Inside the sampled region every sample is overwritten with the values of
    another sample chosen by a random permutation of the batch. The input
    tensor is modified in place and also returned.

    Args:
        shape: ``'stripe'``, ``'square'`` or ``'random'``.
        mask_size: extent of the exchanged band / square side.
        mode: for ``'stripe'``: ``'horizontal'``, ``'vertical'`` or
            ``'random_direction'`` (the default spelling ``'random_direct'``
            is accepted as an alias). For ``'square'``: ``'random_size'``
            picks a side of 4 or 5 on every call.
    """

    # The constructor default is spelled 'random_direct'; treat it the same
    # as 'random_direction' so a default-constructed instance is usable.
    _RANDOM_DIRECTION_MODES = ('random_direction', 'random_direct')

    def __init__(self, shape='stripe', mask_size=2, mode='random_direct'):
        self.shape = shape
        self.mask_size = mask_size
        self.mode = mode

    def __call__(self, features):
        """Exchange one region of ``features`` and return the tensor.

        Args:
            features: four-dimensional tensor; the region is taken along
                dimensions 2 and 3 and samples are permuted along
                dimension 0 (presumably ``(B, C, H, W)`` -- confirm against
                the caller).

        Raises:
            Exception: on an unknown ``shape`` or an unknown stripe ``mode``.
        """
        # Stripe mask
        if self.shape == 'stripe':
            if self.mode == 'horizontal':
                features = self.xpatch_hstripe(features, self.mask_size)
            elif self.mode == 'vertical':
                features = self.xpatch_vstripe(features, self.mask_size)
            elif self.mode in self._RANDOM_DIRECTION_MODES:
                if random.random() < 0.5:
                    features = self.xpatch_hstripe(features, self.mask_size)
                else:
                    features = self.xpatch_vstripe(features, self.mask_size)
            else:
                raise Exception("Unknown stripe mask mode name")
        # Square mask
        elif self.shape == 'square':
            if self.mode == 'random_size':
                # Note: the sampled size is stored on the instance.
                self.mask_size = 4 if random.random() < 0.5 else 5
            features = self.xpatch_square(features, self.mask_size)
        # Random stripe/square mask, each of the four options equally likely
        elif self.shape == 'random':
            random_num = random.random()
            if random_num < 0.25:
                features = self.xpatch_hstripe(features, 2)
            elif random_num < 0.5:
                features = self.xpatch_vstripe(features, 2)
            elif random_num < 0.75:
                features = self.xpatch_square(features, 4)
            else:
                features = self.xpatch_square(features, 5)
        else:
            raise Exception("Unknown mask shape name")
        return features

    def xpatch_hstripe(self, features, mask_size):
        """Exchange a band of ``mask_size`` indices along dimension 3."""
        # horizontal stripe
        y1_max = features.shape[3] - mask_size
        # NOTE(review): randint's upper bound is exclusive, so the start
        # `y1_max` itself is never drawn -- confirm this is intended.
        mask_y1 = torch.randint(y1_max, (1,))
        mask_y2 = mask_y1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        # Indexing with `new_idx` copies, so the in-place write is safe.
        features[:, :, :, mask_y1:mask_y2] = features[new_idx, :, :, mask_y1:mask_y2]
        return features

    def xpatch_vstripe(self, features, mask_size):
        """Exchange a band of ``mask_size`` indices along dimension 2."""
        # vertical stripe
        x1_max = features.shape[2] - mask_size
        # Exclusive upper bound, see the note in `xpatch_hstripe`.
        mask_x1 = torch.randint(x1_max, (1,))
        mask_x2 = mask_x1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        features[:, :, mask_x1:mask_x2, :] = features[new_idx, :, mask_x1:mask_x2, :]
        return features

    def xpatch_square(self, features, mask_size):
        """Exchange a ``mask_size`` x ``mask_size`` square in dimensions 2 and 3."""
        # square
        x1_max = features.shape[2] - mask_size
        y1_max = features.shape[3] - mask_size
        # Draw x before y to keep the random stream order stable.
        mask_x1 = torch.randint(x1_max, (1,))
        mask_y1 = torch.randint(y1_max, (1,))
        mask_x2 = mask_x1 + mask_size
        mask_y2 = mask_y1 + mask_size
        new_idx = torch.randperm(features.shape[0])
        features[:, :, mask_x1:mask_x2, mask_y1:mask_y2] = \
            features[new_idx, :, mask_x1:mask_x2, mask_y1:mask_y2]
        return features
class cutout_patch:
    """Zero out one random band of ``mask_size`` indices per sample.

    With probability 0.5 the band lies along dimension 3, otherwise along
    dimension 2. Each sample gets its own band position. The input tensor is
    modified in place and also returned.
    """

    def __init__(self, mask_size=2):
        self.mask_size = mask_size

    def __call__(self, features):
        """Apply the cutout.

        Args:
            features: four-dimensional tensor, indexed as
                ``features[sample, :, dim2, dim3]`` (presumably
                ``(B, C, H, W)`` -- confirm against the caller).

        Returns:
            The same tensor object with the bands set to 0.
        """
        batch = features.shape[0]
        size = self.mask_size
        if random.random() < 0.5:
            y1_max = features.shape[3] - size
            # NOTE(review): randint's upper bound is exclusive, so the start
            # `y1_max` itself is never drawn -- confirm this is intended.
            starts = torch.randint(y1_max, (batch,)).tolist()
            for k, y1 in enumerate(starts):
                features[k, :, :, y1:y1 + size] = 0
        else:
            # The band runs along dimension 2, so its start must be bounded
            # by that dimension's size (was shape[3], which only works for
            # square maps).
            x1_max = features.shape[2] - size
            starts = torch.randint(x1_max, (batch,)).tolist()
            for k, x1 in enumerate(starts):
                features[k, :, x1:x1 + size, :] = 0
        return features
class erase_patch:
    """Overwrite one random band per sample with Gaussian noise.

    The noise uses the mean and standard deviation of the whole (detached)
    feature tensor. With probability 0.5 the band lies along dimension 3,
    otherwise along dimension 2. Each sample gets its own band position and
    its own noise draw. The input tensor is modified in place and also
    returned.
    """

    def __init__(self, mask_size=2):
        self.mask_size = mask_size

    def __call__(self, features):
        """Apply the random erasing.

        Args:
            features: four-dimensional tensor, indexed as
                ``features[sample, :, dim2, dim3]`` (presumably
                ``(B, C, H, W)`` -- confirm against the caller).

        Returns:
            The same tensor object with the bands replaced by noise.
        """
        std, mean = torch.std_mean(features.detach())
        batch, dim, extent_x, extent_y = features.shape
        size = self.mask_size
        if random.random() < 0.5:
            y1_max = extent_y - size
            # NOTE(review): randint's upper bound is exclusive, so the start
            # `y1_max` itself is never drawn -- confirm this is intended.
            starts = torch.randint(y1_max, (batch,)).tolist()
            # Noise parameters sized to the actual band (was hard-coded to
            # a 14-wide map and a band of 2).
            band_mean = mean.repeat(dim, extent_x, size)
            band_std = std.repeat(dim, extent_x, size)
            for k, y1 in enumerate(starts):
                features[k, :, :, y1:y1 + size] = torch.normal(band_mean, band_std)
        else:
            # The band runs along dimension 2, so bound its start by that
            # dimension's size (was shape[3]).
            x1_max = extent_x - size
            starts = torch.randint(x1_max, (batch,)).tolist()
            band_mean = mean.repeat(dim, size, extent_y)
            band_std = std.repeat(dim, size, extent_y)
            for k, x1 in enumerate(starts):
                features[k, :, x1:x1 + size, :] = torch.normal(band_mean, band_std)
        return features
class mixup_patch:
    """Blend one random band of a feature map with another sample's band.

    A mixing weight ``lam`` is drawn uniformly from [0, 1]. Inside a band of
    ``mask_size`` indices (along dimension 3 with probability 0.5, otherwise
    along dimension 2) every sample becomes
    ``lam * own + (1 - lam) * other``, where ``other`` is chosen by a random
    permutation of the batch. The input tensor is modified in place and also
    returned.
    """

    def __init__(self, mask_size=2):
        self.mask_size = mask_size

    def __call__(self, features):
        """Apply the band mixup to a four-dimensional ``features`` tensor."""
        lam = random.uniform(0, 1)
        # Dimension 3 is used when the coin flip is below 0.5, else dimension 2.
        axis = 3 if random.random() < 0.5 else 2

        start = torch.randint(features.shape[axis] - self.mask_size, (1,))
        stop = start + self.mask_size
        perm = torch.randperm(features.shape[0])

        # Full slices everywhere except the banded axis.
        band = [slice(None)] * 4
        band[axis] = slice(start, stop)
        band = tuple(band)

        # Same band, but with the batch axis permuted (indexing copies).
        partner = features[(perm,) + band[1:]]
        features[band] = lam * features[band] + (1 - lam) * partner
        return features
class jigsaw_patch:
    """Shift-and-shuffle ("jigsaw") rearrangement of a feature map.

    The spatial positions are flattened, rotated by ``shift`` positions,
    split into ``group`` contiguous chunks whose positions are interleaved,
    and finally folded back into a spatial map.
    """

    def __init__(self, shift=5, group=2):
        self.shift = shift
        self.group = group

    def __call__(self, features):
        """Return the rearranged feature map as a new tensor.

        Args:
            features: tensor whose first two dimensions are kept and whose
                remaining dimensions are flattened (presumably
                ``(B, C, H, W)`` -- confirm against the caller).

        Returns:
            Tensor with the spatial size of the input when ``features`` is
            four-dimensional; other inputs are folded to 14 x 14 as before.

        Raises:
            RuntimeError: if the number of flattened positions cannot be
                split evenly into ``self.group`` chunks.
        """
        batchsize = features.size(0)
        dim = features.size(1)
        # Restore the caller's own spatial layout instead of assuming 14 x 14.
        out_hw = tuple(features.shape[2:]) if features.dim() == 4 else (14, 14)

        # reshape (not view) so non-contiguous inputs are accepted as well.
        flat = features.reshape(batchsize, dim, -1)
        num_positions = flat.size(2)
        if self.group <= 0 or num_positions % self.group != 0:
            # The previous fallback padded the channel axis and could never
            # yield a valid map; fail clearly instead.
            raise RuntimeError("Jigsaw - Unwanted number of groups")

        # Shift operation
        x = torch.cat([flat[:, :, self.shift:], flat[:, :, :self.shift]], dim=2)

        # Patch shuffle operation
        x = x.view(batchsize, dim, self.group, -1)
        x = torch.transpose(x, 2, 3).contiguous()
        return x.view(batchsize, dim, *out_hw)