COAT/coat-pvtv2-b2/models/coat.py

766 lines
32 KiB
Python
Raw Permalink Normal View History

2023-10-10 21:52:30 +08:00
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import boxes as box_ops
from torchvision.models.detection import _utils as det_utils
from loss.oim import OIMLoss
from models.resnet import build_resnet
from models.transformer import TransformerHead
class COAT(nn.Module):
    """
    Top-level model: backbone, region proposal network and a three-stage
    cascaded ROI head that outputs boxes together with re-id embeddings.

    In training mode ``forward`` returns the weighted loss dictionary plus the
    re-id features/targets of the 2nd and 3rd stages. In evaluation mode it
    delegates to ``inference``.
    """

    def __init__(self, cfg):
        super(COAT, self).__init__()

        # Backbone; the second value returned by build_resnet is not used here.
        backbone, _ = build_resnet(name="resnet50", pretrained=True)

        # ---------------- Region proposal network ---------------- #
        anchor_generator = AnchorGenerator(
            sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
        )
        rpn_head = RPNHead(
            in_channels=backbone.out_channels,
            num_anchors=anchor_generator.num_anchors_per_location()[0],
        )
        rpn_cfg = cfg.MODEL.RPN
        rpn = RegionProposalNetwork(
            anchor_generator=anchor_generator,
            head=rpn_head,
            fg_iou_thresh=rpn_cfg.POS_THRESH_TRAIN,
            bg_iou_thresh=rpn_cfg.NEG_THRESH_TRAIN,
            batch_size_per_image=rpn_cfg.BATCH_SIZE_TRAIN,
            positive_fraction=rpn_cfg.POS_FRAC_TRAIN,
            pre_nms_top_n={
                "training": rpn_cfg.PRE_NMS_TOPN_TRAIN,
                "testing": rpn_cfg.PRE_NMS_TOPN_TEST,
            },
            post_nms_top_n={
                "training": rpn_cfg.POST_NMS_TOPN_TRAIN,
                "testing": rpn_cfg.POST_NMS_TOPN_TEST,
            },
            nms_thresh=rpn_cfg.NMS_THRESH,
        )

        # ---------------- Cascaded ROI heads ---------------- #
        # One transformer head per cascade stage, built in stage order so the
        # parameter initialisation sequence stays 1st -> 2nd -> 3rd.
        trans_cfg = cfg.MODEL.TRANSFORMER
        box_head, box_head_2nd, box_head_3rd = [
            TransformerHead(
                cfg=cfg,
                trans_names=getattr(trans_cfg, "NAMES_" + stage),
                kernel_size=getattr(trans_cfg, "KERNEL_SIZE_" + stage),
                use_feature_mask=getattr(trans_cfg, "USE_MASK_" + stage),
            )
            for stage in ("1ST", "2ND", "3RD")
        ]

        roi_cfg = cfg.MODEL.ROI_HEAD
        faster_rcnn_predictor = FastRCNNPredictor(2048, 2)
        box_roi_pool = MultiScaleRoIAlign(
            featmap_names=["feat_res4"], output_size=14, sampling_ratio=2
        )
        box_predictor = BBoxRegressor(2048, num_classes=2, bn_neck=roi_cfg.BN_NECK)
        roi_heads = CascadedROIHeads(
            cfg=cfg,
            # Cascade Transformer Head
            faster_rcnn_predictor=faster_rcnn_predictor,
            box_head_2nd=box_head_2nd,
            box_head_3rd=box_head_3rd,
            # arguments forwarded to the parent RoIHeads class
            box_roi_pool=box_roi_pool,
            box_head=box_head,
            box_predictor=box_predictor,
            fg_iou_thresh=roi_cfg.POS_THRESH_TRAIN,
            bg_iou_thresh=roi_cfg.NEG_THRESH_TRAIN,
            batch_size_per_image=roi_cfg.BATCH_SIZE_TRAIN,
            positive_fraction=roi_cfg.POS_FRAC_TRAIN,
            bbox_reg_weights=None,
            score_thresh=roi_cfg.SCORE_THRESH_TEST,
            nms_thresh=roi_cfg.NMS_THRESH_TEST,
            detections_per_img=roi_cfg.DETECTIONS_PER_IMAGE_TEST,
        )

        # ---------------- Input transform ---------------- #
        transform = GeneralizedRCNNTransform(
            min_size=cfg.INPUT.MIN_SIZE,
            max_size=cfg.INPUT.MAX_SIZE,
            image_mean=[0.485, 0.456, 0.406],
            image_std=[0.229, 0.224, 0.225],
        )

        # Register sub-modules (same order as before: backbone, rpn, heads, transform).
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        self.transform = transform

        # Which embedding is exported at evaluation: 'concat', 'stage2' or 'stage3'.
        self.eval_feat = cfg.EVAL_FEATURE

        # loss weights
        solver = cfg.SOLVER
        self.lw_rpn_reg = solver.LW_RPN_REG
        self.lw_rpn_cls = solver.LW_RPN_CLS
        self.lw_rcnn_reg_1st = solver.LW_RCNN_REG_1ST
        self.lw_rcnn_cls_1st = solver.LW_RCNN_CLS_1ST
        self.lw_rcnn_reg_2nd = solver.LW_RCNN_REG_2ND
        self.lw_rcnn_cls_2nd = solver.LW_RCNN_CLS_2ND
        self.lw_rcnn_reg_3rd = solver.LW_RCNN_REG_3RD
        self.lw_rcnn_cls_3rd = solver.LW_RCNN_CLS_3RD
        self.lw_rcnn_reid_2nd = solver.LW_RCNN_REID_2ND
        self.lw_rcnn_reid_3rd = solver.LW_RCNN_REID_3RD

    def inference(self, images, targets=None, query_img_as_gallery=False):
        """
        Evaluation entry point.

        * Query mode (``targets`` given and ``query_img_as_gallery`` False):
          embeds the boxes found in ``targets`` and returns one embedding
          tensor per box (``embeddings.split(1, 0)``).
        * Gallery mode (otherwise): runs RPN + ROI heads and returns the
          detections post-processed back to the original image sizes.

        Raises:
            AssertionError: ``query_img_as_gallery`` is set without targets.
            Exception: ``self.eval_feat`` is not 'concat', 'stage2' or 'stage3'
                (query mode only).
        """
        original_image_sizes = [img.shape[-2:] for img in images]
        images, targets = self.transform(images, targets)
        features = self.backbone(images.tensors)

        if query_img_as_gallery:
            assert targets is not None

        run_as_gallery = targets is None or query_img_as_gallery
        if run_as_gallery:
            # gallery
            proposals, _ = self.rpn(images, features, targets)
            detections = self.roi_heads(
                features, proposals, images.image_sizes, targets, query_img_as_gallery
            )[0]
            return self.transform.postprocess(
                detections, images.image_sizes, original_image_sizes
            )

        # query
        heads = self.roi_heads
        query_boxes = [target["boxes"] for target in targets]
        pooled = heads.box_roi_pool(features, query_boxes, images.image_sizes)
        # Both stage embeddings are always computed, whichever one is exported.
        embeddings_2nd, _ = heads.embedding_head_2nd(heads.box_head_2nd(pooled))
        embeddings_3rd, _ = heads.embedding_head_3rd(heads.box_head_3rd(pooled))

        if self.eval_feat == 'stage2':
            embeddings = embeddings_2nd
        elif self.eval_feat == 'stage3':
            embeddings = embeddings_3rd
        elif self.eval_feat == 'concat':
            embeddings = torch.cat((embeddings_2nd, embeddings_3rd), dim=1)
        else:
            raise Exception("Unknown evaluation feature name")
        return embeddings.split(1, 0)

    def forward(self, images, targets=None, query_img_as_gallery=False):
        """
        Training: returns ``(losses, feats_reid_2nd, targets_reid_2nd,
        feats_reid_3rd, targets_reid_3rd)`` where ``losses`` already has the
        configured loss weights applied.
        Evaluation: returns whatever ``inference`` returns.
        """
        if not self.training:
            return self.inference(images, targets, query_img_as_gallery)

        images, targets = self.transform(images, targets)
        features = self.backbone(images.tensors)
        proposals, rpn_losses = self.rpn(images, features, targets)
        roi_outputs = self.roi_heads(features, proposals, images.image_sizes, targets)
        (
            _,
            rcnn_losses,
            feats_reid_2nd,
            targets_reid_2nd,
            feats_reid_3rd,
            targets_reid_3rd,
        ) = roi_outputs

        # rename rpn losses to be consistent with detection losses
        rpn_losses["loss_rpn_reg"] = rpn_losses.pop("loss_rpn_box_reg")
        rpn_losses["loss_rpn_cls"] = rpn_losses.pop("loss_objectness")

        losses = dict(rcnn_losses)
        losses.update(rpn_losses)

        # apply loss weights (in place, one entry at a time)
        weight_table = (
            ("loss_rpn_reg", self.lw_rpn_reg),
            ("loss_rpn_cls", self.lw_rpn_cls),
            ("loss_rcnn_reg_1st", self.lw_rcnn_reg_1st),
            ("loss_rcnn_cls_1st", self.lw_rcnn_cls_1st),
            ("loss_rcnn_reg_2nd", self.lw_rcnn_reg_2nd),
            ("loss_rcnn_cls_2nd", self.lw_rcnn_cls_2nd),
            ("loss_rcnn_reg_3rd", self.lw_rcnn_reg_3rd),
            ("loss_rcnn_cls_3rd", self.lw_rcnn_cls_3rd),
            ("loss_rcnn_reid_2nd", self.lw_rcnn_reid_2nd),
            ("loss_rcnn_reid_3rd", self.lw_rcnn_reid_3rd),
        )
        for loss_name, loss_weight in weight_table:
            losses[loss_name] *= loss_weight

        return losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd
class CascadedROIHeads(RoIHeads):
    '''
    Three-stage cascaded ROI head built on torchvision's ``RoIHeads``.

    https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py

    Stage 1 classifies/regresses proposals with ``box_predictor_1st``.
    Stages 2 and 3 regress boxes with ``box_predictor_2nd`` / ``_3rd`` and take
    both the re-id embedding and the class logit from ``embedding_head_2nd`` /
    ``_3rd``.
    '''

    def __init__(
        self,
        cfg,
        faster_rcnn_predictor,
        box_head_2nd,
        box_head_3rd,
        *args,
        **kwargs
    ):
        """
        Args:
            cfg: configuration node (reads ``MODEL.ROI_HEAD``, ``MODEL.LOSS``,
                ``MODEL.EMBEDDING_DIM``, ``MODEL.USE_FEATURE_MASK``,
                ``MODEL.FEATURE_MASK_SIZE`` and ``EVAL_FEATURE``).
            faster_rcnn_predictor: first-stage classification/regression head.
            box_head_2nd: feature head of the second stage.
            box_head_3rd: feature head of the third stage.
            *args, **kwargs: forwarded unchanged to ``RoIHeads.__init__``.
        """
        super(CascadedROIHeads, self).__init__(*args, **kwargs)

        # ROI head
        self.use_diff_thresh = cfg.MODEL.ROI_HEAD.USE_DIFF_THRESH
        self.nms_thresh_1st = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST
        self.nms_thresh_2nd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND
        self.nms_thresh_3rd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD
        self.fg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN
        self.bg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN
        self.fg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND
        self.bg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND
        self.fg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD
        self.bg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD

        # Regression head
        self.box_predictor_1st = faster_rcnn_predictor
        # The 2nd-stage predictor is the parent's ``box_predictor`` (same object);
        # the 3rd stage starts from an independent copy of it.
        self.box_predictor_2nd = self.box_predictor
        self.box_predictor_3rd = deepcopy(self.box_predictor)

        # Transformer head
        self.box_head_1st = self.box_head
        self.box_head_2nd = box_head_2nd
        self.box_head_3rd = box_head_3rd

        # feature mask
        self.use_feature_mask = cfg.MODEL.USE_FEATURE_MASK
        self.feature_mask_size = cfg.MODEL.FEATURE_MASK_SIZE

        # Feature embedding
        embedding_dim = cfg.MODEL.EMBEDDING_DIM
        self.embedding_head_2nd = NormAwareEmbedding(
            featmap_names=["before_trans", "after_trans"],
            in_channels=[1024, 2048],
            dim=embedding_dim,
        )
        self.embedding_head_3rd = deepcopy(self.embedding_head_2nd)

        # OIM
        num_pids = cfg.MODEL.LOSS.LUT_SIZE
        num_cq_size = cfg.MODEL.LOSS.CQ_SIZE
        oim_momentum = cfg.MODEL.LOSS.OIM_MOMENTUM
        oim_scalar = cfg.MODEL.LOSS.OIM_SCALAR
        self.reid_loss_2nd = OIMLoss(embedding_dim, num_pids, num_cq_size, oim_momentum, oim_scalar)
        self.reid_loss_3rd = deepcopy(self.reid_loss_2nd)

        # rename the method inherited from parent class
        self.postprocess_proposals = self.postprocess_detections

        # evaluation: which embedding to export ('concat', 'stage2' or 'stage3')
        self.eval_feat = cfg.EVAL_FEATURE

    def _embed_gt_boxes(self, features, targets, image_shapes, box_head, embedding_head):
        """
        Embed the GT boxes of the first target with the given stage heads.

        Returns a dict with the GT ``boxes`` and their ``embeddings``.
        """
        gt_boxes = targets[0]["boxes"]
        gt_box_features = self.box_roi_pool(features, [gt_boxes], image_shapes)
        gt_box_features = box_head(gt_box_features)
        embeddings, _ = embedding_head(gt_box_features)
        return {"boxes": gt_boxes, "embeddings": embeddings}

    def _no_detection_result(self, reference, gt_det_2nd, gt_det_last):
        """
        Build the single-image result used when no detection is left at test time.

        Args:
            reference (Tensor): tensor whose device/dtype the empty outputs copy.
            gt_det_2nd (dict or None): 2nd-stage GT detection (boxes, embeddings).
            gt_det_last (dict or None): GT detection of the stage that ran out
                of boxes. When ``None`` an empty result is returned.

        Raises:
            Exception: ``self.eval_feat`` is not 'concat', 'stage2' or 'stage3'.
        """
        if self.eval_feat not in ('concat', 'stage2', 'stage3'):
            raise Exception("Unknown evaluation feature name")

        if gt_det_last is None:
            # Width follows the embedding head instead of a hard-coded 256/512,
            # and the tensors live on the same device as the detections.
            embedding_dim = self.embedding_head_2nd.dim
            out_dim = 2 * embedding_dim if self.eval_feat == 'concat' else embedding_dim
            return dict(
                boxes=reference.new_zeros((0, 4)),
                labels=reference.new_zeros((0,)),
                scores=reference.new_zeros((0,)),
                embeddings=reference.new_zeros((0, out_dim)),
            )

        gt_boxes = gt_det_last["boxes"]
        num_gt = gt_boxes.shape[0]
        if self.eval_feat == 'concat':
            embeddings = torch.cat(
                (gt_det_2nd["embeddings"], gt_det_last["embeddings"]), dim=1
            )
        elif self.eval_feat == 'stage2':
            embeddings = gt_det_2nd["embeddings"]
        else:  # 'stage3'
            embeddings = gt_det_last["embeddings"]
        # GT boxes get label 1 and confidence 1, one entry per GT box.
        return dict(
            boxes=gt_boxes,
            labels=torch.ones(num_gt).type_as(gt_boxes),
            scores=torch.ones(num_gt).type_as(gt_boxes),
            embeddings=embeddings,
        )

    def forward(self, features, boxes, image_shapes, targets=None, query_img_as_gallery=False):
        """
        Arguments:
            features (List[Tensor])
            boxes (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
            query_img_as_gallery (bool): at test time, force the GT boxes of
                ``targets[0]`` into the detections and disable CWS.

        Returns:
            ``(result, losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd,
            targets_reid_3rd)``. In training ``result`` is empty and ``losses``
            holds the detection and re-id losses. At test time ``result`` has one
            dict per image (boxes, labels, scores, embeddings), ``losses`` is
            empty and the re-id entries are ``None``.

            When no box survives a stage at test time the method returns early
            with the 2-tuple ``([single_result_dict], [])``.
        """
        cws = True
        gt_det_2nd = None
        gt_det_3rd = None
        feats_reid_2nd = None
        feats_reid_3rd = None
        targets_reid_2nd = None
        targets_reid_3rd = None

        if self.training:
            if self.use_diff_thresh:
                self.proposal_matcher = det_utils.Matcher(
                    self.fg_iou_thresh_1st,
                    self.bg_iou_thresh_1st,
                    allow_low_quality_matches=False)
            boxes, _, box_pid_labels_1st, box_reg_targets_1st = self.select_training_samples(
                boxes, targets
            )

        # ------------------- The first stage ------------------ #
        box_features_1st = self.box_roi_pool(features, boxes, image_shapes)
        box_features_1st = self.box_head_1st(box_features_1st)
        box_cls_scores_1st, box_regs_1st = self.box_predictor_1st(box_features_1st["after_trans"])

        if self.training:
            # Refined (detached) boxes become the proposals of the next stage.
            boxes = self.get_boxes(box_regs_1st, boxes, image_shapes)
            boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
            if self.use_diff_thresh:
                self.proposal_matcher = det_utils.Matcher(
                    self.fg_iou_thresh_2nd,
                    self.bg_iou_thresh_2nd,
                    allow_low_quality_matches=False)
            boxes, _, box_pid_labels_2nd, box_reg_targets_2nd = self.select_training_samples(boxes, targets)
        else:
            # self.nms_thresh is swapped per stage; remember the configured value
            # so it can be restored on every exit path.
            orig_thresh = self.nms_thresh
            self.nms_thresh = self.nms_thresh_1st
            boxes, scores, _ = self.postprocess_proposals(
                box_cls_scores_1st, box_regs_1st, boxes, image_shapes
            )

        if not self.training and query_img_as_gallery:
            # When regarding the query image as gallery, GT boxes may be excluded
            # from detected boxes. To avoid this, we compulsorily include GT in the
            # detection results. Additionally, CWS should be disabled as the
            # confidences of these people in query image are 1
            cws = False
            gt_det_2nd = self._embed_gt_boxes(
                features, targets, image_shapes, self.box_head_2nd, self.embedding_head_2nd
            )

        # no detection predicted by Faster R-CNN head in test phase
        if boxes[0].shape[0] == 0:
            assert not self.training
            self.nms_thresh = orig_thresh
            # NOTE(review): only 2nd-stage GT embeddings exist at this point, so
            # 'stage3' and 'concat' fall back to them here - confirm intended.
            return [self._no_detection_result(boxes[0], gt_det_2nd, gt_det_2nd)], []

        # --------------------- The second stage -------------------- #
        box_features = self.box_roi_pool(features, boxes, image_shapes)
        box_features = self.box_head_2nd(box_features)
        box_regs_2nd = self.box_predictor_2nd(box_features["after_trans"])
        box_embeddings_2nd, box_cls_scores_2nd = self.embedding_head_2nd(box_features)
        # The embedding head squeezes its logits; a single box yields a 0-dim tensor.
        if box_cls_scores_2nd.dim() == 0:
            box_cls_scores_2nd = box_cls_scores_2nd.unsqueeze(0)

        if self.training:
            boxes = self.get_boxes(box_regs_2nd, boxes, image_shapes)
            boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
            if self.use_diff_thresh:
                self.proposal_matcher = det_utils.Matcher(
                    self.fg_iou_thresh_3rd,
                    self.bg_iou_thresh_3rd,
                    allow_low_quality_matches=False)
            boxes, _, box_pid_labels_3rd, box_reg_targets_3rd = self.select_training_samples(boxes, targets)
        else:
            self.nms_thresh = self.nms_thresh_2nd
            if self.eval_feat != 'stage2':
                boxes, scores, _, _ = self.postprocess_boxes(
                    box_cls_scores_2nd,
                    box_regs_2nd,
                    box_embeddings_2nd,
                    boxes,
                    image_shapes,
                    fcs=scores,
                    gt_det=None,
                    cws=cws,
                )

        if not self.training and query_img_as_gallery and self.eval_feat != 'stage2':
            cws = False
            gt_det_3rd = self._embed_gt_boxes(
                features, targets, image_shapes, self.box_head_3rd, self.embedding_head_3rd
            )

        # no detection left after the second stage in test phase
        if boxes[0].shape[0] == 0 and self.eval_feat != 'stage2':
            assert not self.training
            self.nms_thresh = orig_thresh
            # 'stage3' now returns the 3rd-stage GT embeddings (was 2nd-stage).
            return [self._no_detection_result(boxes[0], gt_det_2nd, gt_det_3rd)], []

        # --------------------- The third stage -------------------- #
        box_features = self.box_roi_pool(features, boxes, image_shapes)

        if not self.training:
            # Re-embed the current boxes with the 2nd-stage heads so that 2nd- and
            # 3rd-stage embeddings refer to the same set of boxes.
            box_features_2nd = self.box_head_2nd(box_features)
            box_embeddings_2nd, _ = self.embedding_head_2nd(box_features_2nd)

        box_features = self.box_head_3rd(box_features)
        box_regs_3rd = self.box_predictor_3rd(box_features["after_trans"])
        box_embeddings_3rd, box_cls_scores_3rd = self.embedding_head_3rd(box_features)
        if box_cls_scores_3rd.dim() == 0:
            box_cls_scores_3rd = box_cls_scores_3rd.unsqueeze(0)

        result, losses = [], {}
        if self.training:
            # Detection labels are binary: any person id > 0 becomes 1.
            box_labels_1st = [y.clamp(0, 1) for y in box_pid_labels_1st]
            box_labels_2nd = [y.clamp(0, 1) for y in box_pid_labels_2nd]
            box_labels_3rd = [y.clamp(0, 1) for y in box_pid_labels_3rd]
            losses = detection_losses(
                box_cls_scores_1st,
                box_regs_1st,
                box_labels_1st,
                box_reg_targets_1st,
                box_cls_scores_2nd,
                box_regs_2nd,
                box_labels_2nd,
                box_reg_targets_2nd,
                box_cls_scores_3rd,
                box_regs_3rd,
                box_labels_3rd,
                box_reg_targets_3rd,
            )
            loss_rcnn_reid_2nd, feats_reid_2nd, targets_reid_2nd = self.reid_loss_2nd(box_embeddings_2nd, box_pid_labels_2nd)
            loss_rcnn_reid_3rd, feats_reid_3rd, targets_reid_3rd = self.reid_loss_3rd(box_embeddings_3rd, box_pid_labels_3rd)
            losses.update(loss_rcnn_reid_2nd=loss_rcnn_reid_2nd)
            losses.update(loss_rcnn_reid_3rd=loss_rcnn_reid_3rd)
        else:
            if self.eval_feat == 'stage2':
                boxes, scores, embeddings_2nd, labels = self.postprocess_boxes(
                    box_cls_scores_2nd,
                    box_regs_2nd,
                    box_embeddings_2nd,
                    boxes,
                    image_shapes,
                    fcs=scores,
                    gt_det=gt_det_2nd,
                    cws=cws,
                )
            else:
                self.nms_thresh = self.nms_thresh_3rd
                # Same boxes/scores in both calls, so the two embedding sets are
                # filtered identically and stay aligned.
                _, _, embeddings_2nd, _ = self.postprocess_boxes(
                    box_cls_scores_3rd,
                    box_regs_3rd,
                    box_embeddings_2nd,
                    boxes,
                    image_shapes,
                    fcs=scores,
                    gt_det=gt_det_2nd,
                    cws=cws,
                )
                boxes, scores, embeddings_3rd, labels = self.postprocess_boxes(
                    box_cls_scores_3rd,
                    box_regs_3rd,
                    box_embeddings_3rd,
                    boxes,
                    image_shapes,
                    fcs=scores,
                    gt_det=gt_det_3rd,
                    cws=cws,
                )

            # set to original thresh after finishing postprocess
            self.nms_thresh = orig_thresh

            num_images = len(boxes)
            for i in range(num_images):
                if self.eval_feat == 'concat':
                    embeddings = torch.cat((embeddings_2nd[i], embeddings_3rd[i]), dim=1)
                elif self.eval_feat == 'stage2':
                    embeddings = embeddings_2nd[i]
                elif self.eval_feat == 'stage3':
                    embeddings = embeddings_3rd[i]
                else:
                    raise Exception("Unknown evaluation feature name")
                result.append(
                    dict(
                        boxes=boxes[i],
                        labels=labels[i],
                        scores=scores[i],
                        embeddings=embeddings
                    )
                )

        return result, losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd

    def get_boxes(self, box_regression, proposals, image_shapes):
        """
        Get boxes from proposals.

        Decodes ``box_regression`` against ``proposals``, clips the boxes to each
        image and drops the first (background) class column.

        Returns:
            List[Tensor[M, 4]]: one tensor of refined boxes per image.
        """
        boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)
        pred_boxes = pred_boxes.split(boxes_per_image, 0)

        all_boxes = []
        for boxes, image_shape in zip(pred_boxes, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
            # remove predictions with the background label
            boxes = boxes[:, 1:].reshape(-1, 4)
            all_boxes.append(boxes)

        return all_boxes

    def postprocess_boxes(
        self,
        class_logits,
        box_regression,
        embeddings,
        proposals,
        image_shapes,
        fcs=None,
        gt_det=None,
        cws=True,
    ):
        """
        Similar to RoIHeads.postprocess_detections, but can handle embeddings and implement
        First Classification Score (FCS).

        Args:
            class_logits (Tensor): logits, used only when ``fcs`` is None.
            box_regression (Tensor): box deltas decoded against ``proposals``.
            embeddings (Tensor): one embedding row per proposal.
            proposals (List[Tensor]): boxes per image.
            image_shapes (List[Tuple[H, W]]): used to clip the boxes.
            fcs (List[Tensor], optional): scores from the first stage. Only
                ``fcs[0]`` is read, i.e. the scores of the first image.
            gt_det (dict, optional): GT ``boxes`` and ``embeddings`` appended to
                every image's detections with label and score 1.
            cws (bool): scale the embeddings by the scores (CWS).

        Returns:
            (all_boxes, all_scores, all_embeddings, all_labels), each a list
            with one tensor per image.
        """
        device = class_logits.device

        boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        if fcs is not None:
            # First Classification Score (FCS)
            pred_scores = fcs[0]
        else:
            pred_scores = torch.sigmoid(class_logits)

        if cws:
            # Confidence Weighted Similarity (CWS)
            embeddings = embeddings * pred_scores.view(-1, 1)

        # split boxes and scores per image
        pred_boxes = pred_boxes.split(boxes_per_image, 0)
        pred_scores = pred_scores.split(boxes_per_image, 0)
        pred_embeddings = embeddings.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        all_embeddings = []
        for boxes, scores, embeddings, image_shape in zip(
            pred_boxes, pred_scores, pred_embeddings, image_shapes
        ):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.ones(scores.size(0), device=device)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores.unsqueeze(1)
            labels = labels.unsqueeze(1)

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.flatten()
            labels = labels.flatten()
            embeddings = embeddings.reshape(-1, self.embedding_head_2nd.dim)

            # remove low scoring boxes
            inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
            boxes, scores, labels, embeddings = (
                boxes[inds],
                scores[inds],
                labels[inds],
                embeddings[inds],
            )

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels, embeddings = (
                boxes[keep],
                scores[keep],
                labels[keep],
                embeddings[keep],
            )

            if gt_det is not None:
                # include GT into the detection results; one label/score of 1.0
                # per GT box so all four tensors stay the same length
                num_gt = gt_det["boxes"].shape[0]
                boxes = torch.cat((boxes, gt_det["boxes"]), dim=0)
                labels = torch.cat((labels, torch.ones(num_gt, device=device)), dim=0)
                scores = torch.cat((scores, torch.ones(num_gt, device=device)), dim=0)
                embeddings = torch.cat((embeddings, gt_det["embeddings"]), dim=0)

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels, embeddings = (
                boxes[keep],
                scores[keep],
                labels[keep],
                embeddings[keep],
            )

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)
            all_embeddings.append(embeddings)

        return all_boxes, all_scores, all_embeddings, all_labels
class NormAwareEmbedding(nn.Module):
    """
    Implements the Norm-Aware Embedding proposed in
    Chen, Di, et al. "Norm-aware embedding for efficient person search." CVPR 2020.

    Each named feature map is projected by its own Linear + BatchNorm1d; the
    projections are concatenated into one ``dim``-wide vector. The vector's
    L2 norm, passed through a 1-channel BatchNorm, serves as the class logit
    and the unit-normalised vector as the embedding.
    """

    def __init__(self, featmap_names=("feat_res4", "feat_res5"), in_channels=(1024, 2048), dim=256):
        """
        Args:
            featmap_names (sequence of str): keys of the feature maps to project.
            in_channels (sequence of int): channel count of each feature map,
                in the same order as ``featmap_names``.
            dim (int): total embedding width, split across the feature maps.
        """
        super(NormAwareEmbedding, self).__init__()
        # Copy into lists so the (immutable) defaults or a caller's sequence are
        # never shared with the module.
        self.featmap_names = list(featmap_names)
        self.in_channels = list(in_channels)
        self.dim = dim

        self.projectors = nn.ModuleDict()
        indv_dims = self._split_embedding_dim()
        for ftname, in_channel, indv_dim in zip(self.featmap_names, self.in_channels, indv_dims):
            proj = nn.Sequential(nn.Linear(in_channel, indv_dim), nn.BatchNorm1d(indv_dim))
            init.normal_(proj[0].weight, std=0.01)
            init.normal_(proj[1].weight, std=0.01)
            init.constant_(proj[0].bias, 0)
            init.constant_(proj[1].bias, 0)
            self.projectors[ftname] = proj

        self.rescaler = nn.BatchNorm1d(1, affine=True)

    def forward(self, featmaps):
        """
        Arguments:
            featmaps: OrderedDict[Tensor], and in featmap_names you can choose which
                      featmaps to use
        Returns:
            tensor of size (BatchSize, dim), L2 normalized embeddings.
            tensor of size (BatchSize, ) rescaled norm of embeddings, as class_logits.
            Because of the final ``squeeze()``, a batch of one yields a 0-dim
            tensor for the logits.
        """
        assert len(featmaps) == len(self.featmap_names)
        # One code path for any number of feature maps. The former single-map
        # branch indexed ``featmaps.items()[0]``, which fails on Python 3.
        outputs = [
            self.projectors[name](self._flatten_fc_input(featmap))
            for name, featmap in featmaps.items()
        ]
        embeddings = torch.cat(outputs, dim=1)
        norms = embeddings.norm(2, 1, keepdim=True)
        # clamp guards against division by zero for all-zero embeddings
        embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
        norms = self.rescaler(norms).squeeze()
        return embeddings, norms

    def _flatten_fc_input(self, x):
        """Flatten an (N, C, 1, 1) tensor to (N, C); other ranks pass through."""
        if x.ndimension() == 4:
            assert list(x.shape[2:]) == [1, 1]
            return x.flatten(start_dim=1)
        return x

    def _split_embedding_dim(self):
        """
        Split ``self.dim`` into one width per feature map.

        Every map gets ``dim // parts``; the remainder is handed out one unit at
        a time starting from the last map, so the widths always sum to ``dim``.
        """
        parts = len(self.in_channels)
        base, remainder = divmod(self.dim, parts)
        widths = [base] * parts
        for i in range(1, remainder + 1):
            widths[-i] += 1
        assert sum(widths) == self.dim
        return widths
class BBoxRegressor(nn.Module):
    """
    Bounding box regression layer.

    Maps a feature vector to ``4 * num_classes`` box deltas, optionally
    followed by batch normalisation.
    """

    def __init__(self, in_channels, num_classes=2, bn_neck=True):
        """
        Args:
            in_channels (int): Input channels.
            num_classes (int, optional): Defaults to 2 (background and pedestrian).
            bn_neck (bool, optional): Whether to use BN after Linear. Defaults to True.
        """
        super(BBoxRegressor, self).__init__()
        num_outputs = 4 * num_classes
        linear = nn.Linear(in_channels, num_outputs)

        if not bn_neck:
            # Plain linear regressor.
            init.normal_(linear.weight, std=0.01)
            init.constant_(linear.bias, 0)
            self.bbox_pred = linear
            return

        # Linear followed by BatchNorm; both get small-normal weights, zero bias.
        batch_norm = nn.BatchNorm1d(num_outputs)
        for layer in (linear, batch_norm):
            init.normal_(layer.weight, std=0.01)
        for layer in (linear, batch_norm):
            init.constant_(layer.bias, 0)
        self.bbox_pred = nn.Sequential(linear, batch_norm)

    def forward(self, x):
        """
        Predict box deltas for ``x``.

        A 4-D input is average-pooled to 1x1 when its spatial size is not
        already 1x1, then flattened to (N, C) before the regression layer.
        """
        is_feature_map = x.ndimension() == 4
        if is_feature_map and list(x.shape[2:]) != [1, 1]:
            x = F.adaptive_avg_pool2d(x, output_size=1)
        if is_feature_map:
            x = x.flatten(start_dim=1)
        return self.bbox_pred(x)
def detection_losses(
    box_cls_scores_1st,
    box_regs_1st,
    box_labels_1st,
    box_reg_targets_1st,
    box_cls_scores_2nd,
    box_regs_2nd,
    box_labels_2nd,
    box_reg_targets_2nd,
    box_cls_scores_3rd,
    box_regs_3rd,
    box_labels_3rd,
    box_reg_targets_3rd,
):
    """
    Classification and box-regression losses of the three cascade stages.

    For every stage the per-image label / regression-target lists are
    concatenated. Stage 1 uses cross-entropy on its class scores; stages 2 and
    3 use binary cross-entropy with logits. The regression loss is a summed
    smooth-L1 over positive boxes, divided by the number of sampled boxes.

    Returns:
        dict with keys ``loss_rcnn_cls_{1st,2nd,3rd}`` and
        ``loss_rcnn_reg_{1st,2nd,3rd}``.
    """

    def regression_loss(cls_scores, box_regs, labels, reg_targets):
        # Indices of positive samples and their class ids; used together for
        # advanced indexing into the per-class regression outputs.
        positive_inds = torch.nonzero(labels > 0).squeeze(1)
        positive_labels = labels[positive_inds]
        num_boxes = cls_scores.size(0)
        per_class_regs = box_regs.reshape(num_boxes, -1, 4)
        summed = F.smooth_l1_loss(
            per_class_regs[positive_inds, positive_labels],
            reg_targets[positive_inds],
            reduction="sum",
        )
        return summed / labels.numel()

    # --------------------- The first stage -------------------- #
    labels_1st = torch.cat(box_labels_1st, dim=0)
    reg_targets_1st = torch.cat(box_reg_targets_1st, dim=0)
    loss_rcnn_cls_1st = F.cross_entropy(box_cls_scores_1st, labels_1st)
    loss_rcnn_reg_1st = regression_loss(
        box_cls_scores_1st, box_regs_1st, labels_1st, reg_targets_1st
    )

    # --------------------- The second stage -------------------- #
    labels_2nd = torch.cat(box_labels_2nd, dim=0)
    reg_targets_2nd = torch.cat(box_reg_targets_2nd, dim=0)
    loss_rcnn_cls_2nd = F.binary_cross_entropy_with_logits(
        box_cls_scores_2nd, labels_2nd.float()
    )
    loss_rcnn_reg_2nd = regression_loss(
        box_cls_scores_2nd, box_regs_2nd, labels_2nd, reg_targets_2nd
    )

    # --------------------- The third stage -------------------- #
    labels_3rd = torch.cat(box_labels_3rd, dim=0)
    reg_targets_3rd = torch.cat(box_reg_targets_3rd, dim=0)
    loss_rcnn_cls_3rd = F.binary_cross_entropy_with_logits(
        box_cls_scores_3rd, labels_3rd.float()
    )
    loss_rcnn_reg_3rd = regression_loss(
        box_cls_scores_3rd, box_regs_3rd, labels_3rd, reg_targets_3rd
    )

    return {
        "loss_rcnn_cls_1st": loss_rcnn_cls_1st,
        "loss_rcnn_reg_1st": loss_rcnn_reg_1st,
        "loss_rcnn_cls_2nd": loss_rcnn_cls_2nd,
        "loss_rcnn_reg_2nd": loss_rcnn_reg_2nd,
        "loss_rcnn_cls_3rd": loss_rcnn_cls_3rd,
        "loss_rcnn_reg_3rd": loss_rcnn_reg_3rd,
    }