77 lines
2.8 KiB
Python
77 lines
2.8 KiB
Python
|
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||
|
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from torch import autograd, nn
|
||
|
|
||
|
class OIM(autograd.Function):
    """Autograd function that matches features against two memory banks.

    The forward pass multiplies every row of ``inputs`` with the rows of
    ``lut`` and of ``cq`` and concatenates both score matrices.  The backward
    pass returns a gradient for ``inputs`` only and, as a side effect,
    rewrites rows of ``lut`` and ``cq`` in place from the saved inputs.
    """

    # Floor applied to a row's L2 norm before dividing by it.  Without it an
    # all-zero blended row (zero-initialised ``lut`` row + zero feature) would
    # compute 0 / 0 and leave NaN in the lookup table.
    _NORM_EPS = 1e-12

    @staticmethod
    def forward(ctx, inputs, targets, lut, cq, header, momentum):
        """Compute the scores of ``inputs`` against ``lut`` and ``cq``.

        Args:
            ctx: Autograd context; all six arguments are saved on it.
            inputs: 2-D tensor, one feature row per sample.
            targets: One index per row of ``inputs``; only used in backward.
            lut: 2-D tensor with the same number of columns as ``inputs``.
            cq: 2-D tensor with the same number of columns as ``inputs``.
            header: Tensor holding the next ``cq`` row to overwrite.
            momentum: Tensor holding the blend factor for ``lut`` updates.

        Returns:
            Tensor of shape ``[inputs.size(0), lut.size(0) + cq.size(0)]``;
            the ``lut`` scores come first, then the ``cq`` scores.
        """
        ctx.save_for_backward(inputs, targets, lut, cq, header, momentum)
        outputs_labeled = inputs.mm(lut.t())
        outputs_unlabeled = inputs.mm(cq.t())
        return torch.cat([outputs_labeled, outputs_unlabeled], dim=1)

    @staticmethod
    def backward(ctx, grad_outputs):
        """Return the gradient for ``inputs`` and update ``lut`` / ``cq``.

        For each saved ``(feature, target)`` pair:

        * ``target < len(lut)``: the matching ``lut`` row becomes
          ``momentum * old + (1 - momentum) * feature`` and is then divided
          by its L2 norm (floored at ``_NORM_EPS``).
        * otherwise: the feature overwrites ``cq[header]`` and ``header``
          advances by one, wrapping at ``cq.size(0)``.

        The advanced ``header`` is a local value only; it is not returned, so
        the caller has to keep its own write position in sync.

        Returns:
            A 6-tuple whose first item is the gradient w.r.t. ``inputs``
            (``None`` when it is not required); the other five are ``None``.
        """
        inputs, targets, lut, cq, header, momentum = ctx.saved_tensors

        grad_inputs = None
        if ctx.needs_input_grad[0]:
            # The gradient is taken against the banks as they were in the
            # forward pass, i.e. before the in-place updates below.
            grad_inputs = grad_outputs.mm(torch.cat([lut, cq], dim=0))
            if grad_inputs.dtype == torch.float16:
                grad_inputs = grad_inputs.to(torch.float32)

        num_labeled = len(lut)
        queue_size = cq.size(0)
        # Rows are processed sequentially on purpose: repeated targets must
        # see the result of the previous update of the same ``lut`` row.
        for x, y in zip(inputs, targets):
            if y < num_labeled:
                lut[y] = momentum * lut[y] + (1.0 - momentum) * x
                lut[y] /= lut[y].norm().clamp(min=OIM._NORM_EPS)
            else:
                cq[header] = x
                header = (header + 1) % queue_size
        return grad_inputs, None, None, None, None, None
|
||
|
|
||
|
|
||
|
def oim(inputs, targets, lut, cq, header, momentum=0.5):
    """Run the ``OIM`` autograd function on plain Python scalars.

    ``header`` and ``momentum`` are wrapped with ``torch.tensor`` before the
    call so that ``OIM.forward`` can store them via ``save_for_backward``;
    the remaining arguments are forwarded unchanged.

    Returns:
        Whatever ``OIM.apply`` returns for the given arguments.
    """
    header_tensor = torch.tensor(header)
    momentum_tensor = torch.tensor(momentum)
    return OIM.apply(inputs, targets, lut, cq, header_tensor, momentum_tensor)
|
||
|
|
||
|
|
||
|
class OIMLoss(nn.Module):
    """Loss module built on the ``oim`` projection and cross-entropy.

    The module owns two zero-initialised buffers, ``lut`` of shape
    ``[num_pids, num_features]`` and ``cq`` of shape
    ``[num_cq_size, num_features]``, plus a Python-int write position
    ``header_cq`` into ``cq``.

    Args:
        num_features: Number of columns of every feature row.
        num_pids: Number of rows in ``lut``.
        num_cq_size: Number of rows in ``cq``.
        oim_momentum: Momentum forwarded to ``oim``.
        oim_scalar: Factor the projected scores are multiplied by before the
            cross-entropy.
        ignore_index: Shifted label value that ``F.cross_entropy`` skips.
            Defaults to 5554, the value that used to be hard-coded here
            (presumably the shifted id of unlabeled samples - confirm against
            the dataset code).
    """

    def __init__(
        self,
        num_features,
        num_pids,
        num_cq_size,
        oim_momentum,
        oim_scalar,
        ignore_index=5554,
    ):
        super().__init__()
        self.num_features = num_features
        self.num_pids = num_pids
        self.num_unlabeled = num_cq_size
        self.momentum = oim_momentum
        self.oim_scalar = oim_scalar
        self.ignore_index = ignore_index

        self.register_buffer("lut", torch.zeros(self.num_pids, self.num_features))
        self.register_buffer("cq", torch.zeros(self.num_unlabeled, self.num_features))

        # Plain attribute, not a buffer: it is not part of ``state_dict``.
        self.header_cq = 0

    def forward(self, inputs, roi_label):
        """Compute the OIM loss for one batch.

        Args:
            inputs: 2-D feature tensor with ``num_features`` columns, one row
                per entry of the concatenated ``roi_label``.
            roi_label: Sequence of label tensors that ``torch.cat`` can join;
                the value 0 marks background.

        Returns:
            Tuple ``(loss_oim, inputs, label)`` where ``inputs`` and ``label``
            are the non-background rows and their labels shifted down by one.
        """
        # merge into one batch, background label = 0
        targets = torch.cat(roi_label)
        label = targets - 1  # background label = -1

        inds = label >= 0
        label = label[inds]
        # Expanding the row mask over all columns keeps whole rows; the view
        # restores the [M, num_features] layout after boolean indexing.
        inputs = inputs[inds.unsqueeze(1).expand_as(inputs)].view(-1, self.num_features)

        projected = oim(inputs, label, self.lut, self.cq, self.header_cq, momentum=self.momentum)
        # projected - Tensor [M, lut+cq], e.g., [M, 482+500]=[M, 982]
        projected *= self.oim_scalar

        # Advance the queue position by the number of rows whose label lies
        # outside ``lut``; those are the rows OIM writes into ``cq``.
        num_queued = (label >= self.num_pids).long().sum().item()
        self.header_cq = (self.header_cq + num_queued) % self.num_unlabeled

        loss_oim = F.cross_entropy(projected, label, ignore_index=self.ignore_index)
        return loss_oim, inputs, label
|