# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.

from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import boxes as box_ops
from torchvision.models.detection import _utils as det_utils

from loss.oim import OIMLoss
from models.resnet import build_resnet
from models.transformer import TransformerHead


class COAT(nn.Module):
    def __init__(self, cfg):
        super(COAT, self).__init__()

        backbone, _ = build_resnet(name="resnet50", pretrained=True)
        anchor_generator = AnchorGenerator(
            sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
        )
        head = RPNHead(
            in_channels=backbone.out_channels,
            num_anchors=anchor_generator.num_anchors_per_location()[0],
        )
        pre_nms_top_n = dict(
            training=cfg.MODEL.RPN.PRE_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.PRE_NMS_TOPN_TEST
        )
        post_nms_top_n = dict(
            training=cfg.MODEL.RPN.POST_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.POST_NMS_TOPN_TEST
        )
        rpn = RegionProposalNetwork(
            anchor_generator=anchor_generator,
            head=head,
            fg_iou_thresh=cfg.MODEL.RPN.POS_THRESH_TRAIN,
            bg_iou_thresh=cfg.MODEL.RPN.NEG_THRESH_TRAIN,
            batch_size_per_image=cfg.MODEL.RPN.BATCH_SIZE_TRAIN,
            positive_fraction=cfg.MODEL.RPN.POS_FRAC_TRAIN,
            pre_nms_top_n=pre_nms_top_n,
            post_nms_top_n=post_nms_top_n,
            nms_thresh=cfg.MODEL.RPN.NMS_THRESH,
        )

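        # The three TransformerHead modules below are the per-stage heads of the
        # cascade: box_head drives the first (detection) stage via the parent
        # RoIHeads, while box_head_2nd and box_head_3rd re-encode features for the
        # second and third stages and feed their norm-aware embedding heads.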
        box_head = TransformerHead(
            cfg=cfg,
            trans_names=cfg.MODEL.TRANSFORMER.NAMES_1ST,
            kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_1ST,
            use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_1ST,
        )
        box_head_2nd = TransformerHead(
            cfg=cfg,
            trans_names=cfg.MODEL.TRANSFORMER.NAMES_2ND,
            kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_2ND,
            use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_2ND,
        )
        box_head_3rd = TransformerHead(
            cfg=cfg,
            trans_names=cfg.MODEL.TRANSFORMER.NAMES_3RD,
            kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_3RD,
            use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_3RD,
        )

        faster_rcnn_predictor = FastRCNNPredictor(2048, 2)
        box_roi_pool = MultiScaleRoIAlign(
            featmap_names=["feat_res4"], output_size=14, sampling_ratio=2
        )
        box_predictor = BBoxRegressor(2048, num_classes=2, bn_neck=cfg.MODEL.ROI_HEAD.BN_NECK)
        roi_heads = CascadedROIHeads(
            cfg=cfg,
            # Cascade Transformer Head
            faster_rcnn_predictor=faster_rcnn_predictor,
            box_head_2nd=box_head_2nd,
            box_head_3rd=box_head_3rd,
            # parent class
            box_roi_pool=box_roi_pool,
            box_head=box_head,
            box_predictor=box_predictor,
            fg_iou_thresh=cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN,
            bg_iou_thresh=cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN,
            batch_size_per_image=cfg.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN,
            positive_fraction=cfg.MODEL.ROI_HEAD.POS_FRAC_TRAIN,
            bbox_reg_weights=None,
            score_thresh=cfg.MODEL.ROI_HEAD.SCORE_THRESH_TEST,
            nms_thresh=cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST,
            detections_per_img=cfg.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST,
        )

        transform = GeneralizedRCNNTransform(
            min_size=cfg.INPUT.MIN_SIZE,
            max_size=cfg.INPUT.MAX_SIZE,
            image_mean=[0.485, 0.456, 0.406],
            image_std=[0.229, 0.224, 0.225],
        )

        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        self.transform = transform
        self.eval_feat = cfg.EVAL_FEATURE

        # loss weights
        self.lw_rpn_reg = cfg.SOLVER.LW_RPN_REG
        self.lw_rpn_cls = cfg.SOLVER.LW_RPN_CLS
        self.lw_rcnn_reg_1st = cfg.SOLVER.LW_RCNN_REG_1ST
        self.lw_rcnn_cls_1st = cfg.SOLVER.LW_RCNN_CLS_1ST
        self.lw_rcnn_reg_2nd = cfg.SOLVER.LW_RCNN_REG_2ND
        self.lw_rcnn_cls_2nd = cfg.SOLVER.LW_RCNN_CLS_2ND
        self.lw_rcnn_reg_3rd = cfg.SOLVER.LW_RCNN_REG_3RD
        self.lw_rcnn_cls_3rd = cfg.SOLVER.LW_RCNN_CLS_3RD
        self.lw_rcnn_reid_2nd = cfg.SOLVER.LW_RCNN_REID_2ND
        self.lw_rcnn_reid_3rd = cfg.SOLVER.LW_RCNN_REID_3RD

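    # Test-time entry point: for a query image (targets given), features are pooled on
    # the ground-truth boxes and their embeddings are returned (stage 2, stage 3, or
    # their concatenation, per cfg.EVAL_FEATURE); for a gallery image, the full RPN +
    # cascaded RoI heads pipeline runs and detections with embeddings are returned.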
    def inference(self, images, targets=None, query_img_as_gallery=False):
        original_image_sizes = [img.shape[-2:] for img in images]
        images, targets = self.transform(images, targets)
        features = self.backbone(images.tensors)

        if query_img_as_gallery:
            assert targets is not None

        if targets is not None and not query_img_as_gallery:
            # query
            boxes = [t["boxes"] for t in targets]
            box_features = self.roi_heads.box_roi_pool(features, boxes, images.image_sizes)
            box_features_2nd = self.roi_heads.box_head_2nd(box_features)
            embeddings_2nd, _ = self.roi_heads.embedding_head_2nd(box_features_2nd)
            box_features_3rd = self.roi_heads.box_head_3rd(box_features)
            embeddings_3rd, _ = self.roi_heads.embedding_head_3rd(box_features_3rd)
            if self.eval_feat == 'concat':
                embeddings = torch.cat((embeddings_2nd, embeddings_3rd), dim=1)
            elif self.eval_feat == 'stage2':
                embeddings = embeddings_2nd
            elif self.eval_feat == 'stage3':
                embeddings = embeddings_3rd
            else:
                raise Exception("Unknown evaluation feature name")
            return embeddings.split(1, 0)
        else:
            # gallery
            boxes, _ = self.rpn(images, features, targets)
            detections = self.roi_heads(features, boxes, images.image_sizes, targets, query_img_as_gallery)[0]
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
            return detections

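    # Training/eval entry point: in eval mode this dispatches to inference(); in
    # training mode it returns the weighted detection and re-ID losses together with
    # the stage-2/3 re-ID features and labels produced inside the RoI heads.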
    def forward(self, images, targets=None, query_img_as_gallery=False):
        if not self.training:
            return self.inference(images, targets, query_img_as_gallery)

        images, targets = self.transform(images, targets)
        features = self.backbone(images.tensors)
        boxes, rpn_losses = self.rpn(images, features, targets)

        _, rcnn_losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd = self.roi_heads(features, boxes, images.image_sizes, targets)

        # rename rpn losses to be consistent with detection losses
        rpn_losses["loss_rpn_reg"] = rpn_losses.pop("loss_rpn_box_reg")
        rpn_losses["loss_rpn_cls"] = rpn_losses.pop("loss_objectness")

        losses = {}
        losses.update(rcnn_losses)
        losses.update(rpn_losses)

        # apply loss weights
        losses["loss_rpn_reg"] *= self.lw_rpn_reg
        losses["loss_rpn_cls"] *= self.lw_rpn_cls
        losses["loss_rcnn_reg_1st"] *= self.lw_rcnn_reg_1st
        losses["loss_rcnn_cls_1st"] *= self.lw_rcnn_cls_1st
        losses["loss_rcnn_reg_2nd"] *= self.lw_rcnn_reg_2nd
        losses["loss_rcnn_cls_2nd"] *= self.lw_rcnn_cls_2nd
        losses["loss_rcnn_reg_3rd"] *= self.lw_rcnn_reg_3rd
        losses["loss_rcnn_cls_3rd"] *= self.lw_rcnn_cls_3rd
        losses["loss_rcnn_reid_2nd"] *= self.lw_rcnn_reid_2nd
        losses["loss_rcnn_reid_3rd"] *= self.lw_rcnn_reid_3rd

        return losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd


class CascadedROIHeads(RoIHeads):
    """
    https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py
    """
    def __init__(
        self,
        cfg,
        faster_rcnn_predictor,
        box_head_2nd,
        box_head_3rd,
        *args,
        **kwargs
    ):
        super(CascadedROIHeads, self).__init__(*args, **kwargs)

        # ROI head
        self.use_diff_thresh = cfg.MODEL.ROI_HEAD.USE_DIFF_THRESH
        self.nms_thresh_1st = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST
        self.nms_thresh_2nd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND
        self.nms_thresh_3rd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD
        self.fg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN
        self.bg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN
        self.fg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND
        self.bg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND
        self.fg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD
        self.bg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD

        # Regression head
        self.box_predictor_1st = faster_rcnn_predictor
        self.box_predictor_2nd = self.box_predictor
        self.box_predictor_3rd = deepcopy(self.box_predictor)

        # Transformer head
        self.box_head_1st = self.box_head
        self.box_head_2nd = box_head_2nd
        self.box_head_3rd = box_head_3rd

        # feature mask
        self.use_feature_mask = cfg.MODEL.USE_FEATURE_MASK
        self.feature_mask_size = cfg.MODEL.FEATURE_MASK_SIZE

        # Feature embedding
        embedding_dim = cfg.MODEL.EMBEDDING_DIM
        self.embedding_head_2nd = NormAwareEmbedding(featmap_names=["before_trans", "after_trans"], in_channels=[1024, 2048], dim=embedding_dim)
        self.embedding_head_3rd = deepcopy(self.embedding_head_2nd)

        # OIM
        num_pids = cfg.MODEL.LOSS.LUT_SIZE
        num_cq_size = cfg.MODEL.LOSS.CQ_SIZE
        oim_momentum = cfg.MODEL.LOSS.OIM_MOMENTUM
        oim_scalar = cfg.MODEL.LOSS.OIM_SCALAR
        self.reid_loss_2nd = OIMLoss(embedding_dim, num_pids, num_cq_size, oim_momentum, oim_scalar)
        self.reid_loss_3rd = deepcopy(self.reid_loss_2nd)

        # rename the method inherited from parent class
        self.postprocess_proposals = self.postprocess_detections

        # evaluation
        self.eval_feat = cfg.EVAL_FEATURE

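    # Three-stage cascade: stage 1 scores and regresses the RPN proposals
    # (Faster R-CNN style); stages 2 and 3 re-pool features on the refined boxes,
    # regress them again and produce norm-aware embeddings for re-ID. At test time
    # each stage is post-processed with its own NMS threshold.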
    def forward(self, features, boxes, image_shapes, targets=None, query_img_as_gallery=False):
        """
        Arguments:
            features (List[Tensor])
            boxes (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])
        """
        cws = True
        gt_det_2nd = None
        gt_det_3rd = None
        feats_reid_2nd = None
        feats_reid_3rd = None
        targets_reid_2nd = None
        targets_reid_3rd = None

        if self.training:
            if self.use_diff_thresh:
                self.proposal_matcher = det_utils.Matcher(
                    self.fg_iou_thresh_1st,
                    self.bg_iou_thresh_1st,
                    allow_low_quality_matches=False)
            boxes, _, box_pid_labels_1st, box_reg_targets_1st = self.select_training_samples(
                boxes, targets
            )

        # ------------------- The first stage ------------------ #
        box_features_1st = self.box_roi_pool(features, boxes, image_shapes)
        box_features_1st = self.box_head_1st(box_features_1st)
        box_cls_scores_1st, box_regs_1st = self.box_predictor_1st(box_features_1st["after_trans"])

        if self.training:
            boxes = self.get_boxes(box_regs_1st, boxes, image_shapes)
            boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
            if self.use_diff_thresh:
                self.proposal_matcher = det_utils.Matcher(
                    self.fg_iou_thresh_2nd,
                    self.bg_iou_thresh_2nd,
                    allow_low_quality_matches=False)
            boxes, _, box_pid_labels_2nd, box_reg_targets_2nd = self.select_training_samples(boxes, targets)
        else:
            orig_thresh = self.nms_thresh  # 0.4
            self.nms_thresh = self.nms_thresh_1st
            boxes, scores, _ = self.postprocess_proposals(
                box_cls_scores_1st, box_regs_1st, boxes, image_shapes
            )

        if not self.training and query_img_as_gallery:
            # When the query image is treated as a gallery image, its GT boxes may be
            # missing from the detected boxes. To avoid this, we always include the GT
            # boxes in the detection results. CWS is also disabled, since the
            # confidences of these people in the query image are 1.
            cws = False
            gt_box = [targets[0]["boxes"]]
            gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
            gt_box_features = self.box_head_2nd(gt_box_features)
            embeddings, _ = self.embedding_head_2nd(gt_box_features)
            gt_det_2nd = {"boxes": targets[0]["boxes"], "embeddings": embeddings}

        # no detection predicted by Faster R-CNN head in test phase
        if boxes[0].shape[0] == 0:
            assert not self.training
            boxes = gt_det_2nd["boxes"] if gt_det_2nd else torch.zeros(0, 4)
            labels = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0)
            scores = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0)
            if self.eval_feat == 'concat':
                embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_2nd["embeddings"]), dim=1) if gt_det_2nd else torch.zeros(0, 512)
            elif self.eval_feat == 'stage2' or self.eval_feat == 'stage3':
                embeddings = gt_det_2nd["embeddings"] if gt_det_2nd else torch.zeros(0, 256)
            else:
                raise Exception("Unknown evaluation feature name")
            return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []

        # --------------------- The second stage -------------------- #
        box_features = self.box_roi_pool(features, boxes, image_shapes)
        box_features = self.box_head_2nd(box_features)
        box_regs_2nd = self.box_predictor_2nd(box_features["after_trans"])
        box_embeddings_2nd, box_cls_scores_2nd = self.embedding_head_2nd(box_features)
        if box_cls_scores_2nd.dim() == 0:
            box_cls_scores_2nd = box_cls_scores_2nd.unsqueeze(0)

        if self.training:
            boxes = self.get_boxes(box_regs_2nd, boxes, image_shapes)
            boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
            if self.use_diff_thresh:
                self.proposal_matcher = det_utils.Matcher(
                    self.fg_iou_thresh_3rd,
                    self.bg_iou_thresh_3rd,
                    allow_low_quality_matches=False)
            boxes, _, box_pid_labels_3rd, box_reg_targets_3rd = self.select_training_samples(boxes, targets)
        else:
            self.nms_thresh = self.nms_thresh_2nd
            if self.eval_feat != 'stage2':
                boxes, scores, _, _ = self.postprocess_boxes(
                    box_cls_scores_2nd,
                    box_regs_2nd,
                    box_embeddings_2nd,
                    boxes,
                    image_shapes,
                    fcs=scores,
                    gt_det=None,
                    cws=cws,
                )

        if not self.training and query_img_as_gallery and self.eval_feat != 'stage2':
            cws = False
            gt_box = [targets[0]["boxes"]]
            gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
            gt_box_features = self.box_head_3rd(gt_box_features)
            embeddings, _ = self.embedding_head_3rd(gt_box_features)
            gt_det_3rd = {"boxes": targets[0]["boxes"], "embeddings": embeddings}

        # no detection predicted by Faster R-CNN head in test phase
        if boxes[0].shape[0] == 0 and self.eval_feat != 'stage2':
            assert not self.training
            boxes = gt_det_3rd["boxes"] if gt_det_3rd else torch.zeros(0, 4)
            labels = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0)
            scores = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0)
            if self.eval_feat == 'concat':
                embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_3rd["embeddings"]), dim=1) if gt_det_3rd else torch.zeros(0, 512)
            elif self.eval_feat == 'stage3':
                embeddings = gt_det_2nd["embeddings"] if gt_det_3rd else torch.zeros(0, 256)
            else:
                raise Exception("Unknown evaluation feature name")
            return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []

        # --------------------- The third stage -------------------- #
        box_features = self.box_roi_pool(features, boxes, image_shapes)

        if not self.training:
            box_features_2nd = self.box_head_2nd(box_features)
            box_embeddings_2nd, _ = self.embedding_head_2nd(box_features_2nd)

        box_features = self.box_head_3rd(box_features)
        box_regs_3rd = self.box_predictor_3rd(box_features["after_trans"])
        box_embeddings_3rd, box_cls_scores_3rd = self.embedding_head_3rd(box_features)
        if box_cls_scores_3rd.dim() == 0:
            box_cls_scores_3rd = box_cls_scores_3rd.unsqueeze(0)

        result, losses = [], {}
        if self.training:
            box_labels_1st = [y.clamp(0, 1) for y in box_pid_labels_1st]
            box_labels_2nd = [y.clamp(0, 1) for y in box_pid_labels_2nd]
            box_labels_3rd = [y.clamp(0, 1) for y in box_pid_labels_3rd]
            losses = detection_losses(
                box_cls_scores_1st,
                box_regs_1st,
                box_labels_1st,
                box_reg_targets_1st,
                box_cls_scores_2nd,
                box_regs_2nd,
                box_labels_2nd,
                box_reg_targets_2nd,
                box_cls_scores_3rd,
                box_regs_3rd,
                box_labels_3rd,
                box_reg_targets_3rd,
            )

            loss_rcnn_reid_2nd, feats_reid_2nd, targets_reid_2nd = self.reid_loss_2nd(box_embeddings_2nd, box_pid_labels_2nd)
            loss_rcnn_reid_3rd, feats_reid_3rd, targets_reid_3rd = self.reid_loss_3rd(box_embeddings_3rd, box_pid_labels_3rd)
            losses.update(loss_rcnn_reid_2nd=loss_rcnn_reid_2nd)
            losses.update(loss_rcnn_reid_3rd=loss_rcnn_reid_3rd)
        else:
            if self.eval_feat == 'stage2':
                boxes, scores, embeddings_2nd, labels = self.postprocess_boxes(
                    box_cls_scores_2nd,
                    box_regs_2nd,
                    box_embeddings_2nd,
                    boxes,
                    image_shapes,
                    fcs=scores,
                    gt_det=gt_det_2nd,
                    cws=cws,
                )
            else:
                self.nms_thresh = self.nms_thresh_3rd
                _, _, embeddings_2nd, _ = self.postprocess_boxes(
                    box_cls_scores_3rd,
                    box_regs_3rd,
                    box_embeddings_2nd,
                    boxes,
                    image_shapes,
                    fcs=scores,
                    gt_det=gt_det_2nd,
                    cws=cws,
                )
                boxes, scores, embeddings_3rd, labels = self.postprocess_boxes(
                    box_cls_scores_3rd,
                    box_regs_3rd,
                    box_embeddings_3rd,
                    boxes,
                    image_shapes,
                    fcs=scores,
                    gt_det=gt_det_3rd,
                    cws=cws,
                )
            # set to original thresh after finishing postprocess
            self.nms_thresh = orig_thresh

            num_images = len(boxes)
            for i in range(num_images):
                if self.eval_feat == 'concat':
                    embeddings = torch.cat((embeddings_2nd[i], embeddings_3rd[i]), dim=1)
                elif self.eval_feat == 'stage2':
                    embeddings = embeddings_2nd[i]
                elif self.eval_feat == 'stage3':
                    embeddings = embeddings_3rd[i]
                else:
                    raise Exception("Unknown evaluation feature name")
                result.append(
                    dict(
                        boxes=boxes[i],
                        labels=labels[i],
                        scores=scores[i],
                        embeddings=embeddings
                    )
                )

        return result, losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd

    def get_boxes(self, box_regression, proposals, image_shapes):
        """
        Get boxes from proposals.
        """
        boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)
        pred_boxes = pred_boxes.split(boxes_per_image, 0)

        all_boxes = []
        for boxes, image_shape in zip(pred_boxes, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
            # remove predictions with the background label
            boxes = boxes[:, 1:].reshape(-1, 4)
            all_boxes.append(boxes)

        return all_boxes

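    # FCS reuses the first-stage classification scores for the later stages instead of
    # re-estimating them; CWS multiplies each embedding by its detection confidence so
    # that gallery similarities are confidence-weighted.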
    def postprocess_boxes(
        self,
        class_logits,
        box_regression,
        embeddings,
        proposals,
        image_shapes,
        fcs=None,
        gt_det=None,
        cws=True,
    ):
        """
        Similar to RoIHeads.postprocess_detections, but can handle embeddings and implement
        First Classification Score (FCS).
        """
        device = class_logits.device

        boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        if fcs is not None:
            # First Classification Score (FCS)
            pred_scores = fcs[0]
        else:
            pred_scores = torch.sigmoid(class_logits)
        if cws:
            # Confidence Weighted Similarity (CWS)
            embeddings = embeddings * pred_scores.view(-1, 1)

        # split boxes and scores per image
        pred_boxes = pred_boxes.split(boxes_per_image, 0)
        pred_scores = pred_scores.split(boxes_per_image, 0)
        pred_embeddings = embeddings.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        all_embeddings = []
        for boxes, scores, embeddings, image_shape in zip(
            pred_boxes, pred_scores, pred_embeddings, image_shapes
        ):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.ones(scores.size(0), device=device)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores.unsqueeze(1)
            labels = labels.unsqueeze(1)

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.flatten()
            labels = labels.flatten()
            embeddings = embeddings.reshape(-1, self.embedding_head_2nd.dim)

            # remove low scoring boxes
            inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
            boxes, scores, labels, embeddings = (
                boxes[inds],
                scores[inds],
                labels[inds],
                embeddings[inds],
            )

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels, embeddings = (
                boxes[keep],
                scores[keep],
                labels[keep],
                embeddings[keep],
            )

            if gt_det is not None:
                # include GT into the detection results
                boxes = torch.cat((boxes, gt_det["boxes"]), dim=0)
                labels = torch.cat((labels, torch.tensor([1.0]).to(device)), dim=0)
                scores = torch.cat((scores, torch.tensor([1.0]).to(device)), dim=0)
                embeddings = torch.cat((embeddings, gt_det["embeddings"]), dim=0)

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels, embeddings = (
                boxes[keep],
                scores[keep],
                labels[keep],
                embeddings[keep],
            )

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)
            all_embeddings.append(embeddings)

        return all_boxes, all_scores, all_embeddings, all_labels


class NormAwareEmbedding(nn.Module):
    """
    Implements the Norm-Aware Embedding proposed in
    Chen, Di, et al. "Norm-aware embedding for efficient person search." CVPR 2020.
    """

    def __init__(self, featmap_names=["feat_res4", "feat_res5"], in_channels=[1024, 2048], dim=256):
        super(NormAwareEmbedding, self).__init__()
        self.featmap_names = featmap_names
        self.in_channels = in_channels
        self.dim = dim

        self.projectors = nn.ModuleDict()
        indv_dims = self._split_embedding_dim()
        for ftname, in_channel, indv_dim in zip(self.featmap_names, self.in_channels, indv_dims):
            proj = nn.Sequential(nn.Linear(in_channel, indv_dim), nn.BatchNorm1d(indv_dim))
            init.normal_(proj[0].weight, std=0.01)
            init.normal_(proj[1].weight, std=0.01)
            init.constant_(proj[0].bias, 0)
            init.constant_(proj[1].bias, 0)
            self.projectors[ftname] = proj

        self.rescaler = nn.BatchNorm1d(1, affine=True)

    def forward(self, featmaps):
        """
        Arguments:
            featmaps: OrderedDict[Tensor]; featmap_names selects which feature maps are used.
        Returns:
            tensor of size (BatchSize, dim), L2 normalized embeddings.
            tensor of size (BatchSize, ), rescaled norm of embeddings, used as class_logits.
        """
        assert len(featmaps) == len(self.featmap_names)
        if len(featmaps) == 1:
            k, v = list(featmaps.items())[0]
            v = self._flatten_fc_input(v)
            embeddings = self.projectors[k](v)
            norms = embeddings.norm(2, 1, keepdim=True)
            embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
            norms = self.rescaler(norms).squeeze()
            return embeddings, norms
        else:
            outputs = []
            for k, v in featmaps.items():
                v = self._flatten_fc_input(v)
                outputs.append(self.projectors[k](v))
            embeddings = torch.cat(outputs, dim=1)
            norms = embeddings.norm(2, 1, keepdim=True)
            embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
            norms = self.rescaler(norms).squeeze()
            return embeddings, norms

    def _flatten_fc_input(self, x):
        if x.ndimension() == 4:
            assert list(x.shape[2:]) == [1, 1]
            return x.flatten(start_dim=1)
        return x

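    # Splits the embedding dimension as evenly as possible across the feature maps,
    # giving any remainder to the last parts. E.g. dim=256 with two feature maps
    # yields [128, 128]; dim=255 yields [127, 128].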
    def _split_embedding_dim(self):
        parts = len(self.in_channels)
        tmp = [self.dim // parts] * parts
        if sum(tmp) == self.dim:
            return tmp
        else:
            res = self.dim % parts
            for i in range(1, res + 1):
                tmp[-i] += 1
            assert sum(tmp) == self.dim
            return tmp


class BBoxRegressor(nn.Module):
    """
    Bounding box regression layer.
    """

    def __init__(self, in_channels, num_classes=2, bn_neck=True):
        """
        Args:
            in_channels (int): Input channels.
            num_classes (int, optional): Defaults to 2 (background and pedestrian).
            bn_neck (bool, optional): Whether to use BN after Linear. Defaults to True.
        """
        super(BBoxRegressor, self).__init__()
        if bn_neck:
            self.bbox_pred = nn.Sequential(
                nn.Linear(in_channels, 4 * num_classes), nn.BatchNorm1d(4 * num_classes)
            )
            init.normal_(self.bbox_pred[0].weight, std=0.01)
            init.normal_(self.bbox_pred[1].weight, std=0.01)
            init.constant_(self.bbox_pred[0].bias, 0)
            init.constant_(self.bbox_pred[1].bias, 0)
        else:
            self.bbox_pred = nn.Linear(in_channels, 4 * num_classes)
            init.normal_(self.bbox_pred.weight, std=0.01)
            init.constant_(self.bbox_pred.bias, 0)

    def forward(self, x):
        if x.ndimension() == 4:
            if list(x.shape[2:]) != [1, 1]:
                x = F.adaptive_avg_pool2d(x, output_size=1)
        x = x.flatten(start_dim=1)
        bbox_deltas = self.bbox_pred(x)
        return bbox_deltas


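# Detection losses for the three cascade stages. Stage 1 uses a softmax cross-entropy
# over the two classes from the Faster R-CNN predictor; stages 2 and 3 use binary
# cross-entropy on the norm-aware embedding scores. All stages use a Smooth-L1
# regression loss summed over positives and averaged over the sampled proposals.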
def detection_losses(
    box_cls_scores_1st,
    box_regs_1st,
    box_labels_1st,
    box_reg_targets_1st,
    box_cls_scores_2nd,
    box_regs_2nd,
    box_labels_2nd,
    box_reg_targets_2nd,
    box_cls_scores_3rd,
    box_regs_3rd,
    box_labels_3rd,
    box_reg_targets_3rd,
):
    # --------------------- The first stage -------------------- #
    box_labels_1st = torch.cat(box_labels_1st, dim=0)
    box_reg_targets_1st = torch.cat(box_reg_targets_1st, dim=0)
    loss_rcnn_cls_1st = F.cross_entropy(box_cls_scores_1st, box_labels_1st)

    # get indices that correspond to the regression targets for the
    # corresponding ground truth labels, to be used with advanced indexing
    sampled_pos_inds_subset = torch.nonzero(box_labels_1st > 0).squeeze(1)
    labels_pos = box_labels_1st[sampled_pos_inds_subset]
    N = box_cls_scores_1st.size(0)
    box_regs_1st = box_regs_1st.reshape(N, -1, 4)

    loss_rcnn_reg_1st = F.smooth_l1_loss(
        box_regs_1st[sampled_pos_inds_subset, labels_pos],
        box_reg_targets_1st[sampled_pos_inds_subset],
        reduction="sum",
    )
    loss_rcnn_reg_1st = loss_rcnn_reg_1st / box_labels_1st.numel()

    # --------------------- The second stage -------------------- #
    box_labels_2nd = torch.cat(box_labels_2nd, dim=0)
    box_reg_targets_2nd = torch.cat(box_reg_targets_2nd, dim=0)
    loss_rcnn_cls_2nd = F.binary_cross_entropy_with_logits(box_cls_scores_2nd, box_labels_2nd.float())

    sampled_pos_inds_subset = torch.nonzero(box_labels_2nd > 0).squeeze(1)
    labels_pos = box_labels_2nd[sampled_pos_inds_subset]
    N = box_cls_scores_2nd.size(0)
    box_regs_2nd = box_regs_2nd.reshape(N, -1, 4)

    loss_rcnn_reg_2nd = F.smooth_l1_loss(
        box_regs_2nd[sampled_pos_inds_subset, labels_pos],
        box_reg_targets_2nd[sampled_pos_inds_subset],
        reduction="sum",
    )
    loss_rcnn_reg_2nd = loss_rcnn_reg_2nd / box_labels_2nd.numel()

    # --------------------- The third stage -------------------- #
    box_labels_3rd = torch.cat(box_labels_3rd, dim=0)
    box_reg_targets_3rd = torch.cat(box_reg_targets_3rd, dim=0)
    loss_rcnn_cls_3rd = F.binary_cross_entropy_with_logits(box_cls_scores_3rd, box_labels_3rd.float())

    sampled_pos_inds_subset = torch.nonzero(box_labels_3rd > 0).squeeze(1)
    labels_pos = box_labels_3rd[sampled_pos_inds_subset]
    N = box_cls_scores_3rd.size(0)
    box_regs_3rd = box_regs_3rd.reshape(N, -1, 4)

    loss_rcnn_reg_3rd = F.smooth_l1_loss(
        box_regs_3rd[sampled_pos_inds_subset, labels_pos],
        box_reg_targets_3rd[sampled_pos_inds_subset],
        reduction="sum",
    )
    loss_rcnn_reg_3rd = loss_rcnn_reg_3rd / box_labels_3rd.numel()

    return dict(
        loss_rcnn_cls_1st=loss_rcnn_cls_1st,
        loss_rcnn_reg_1st=loss_rcnn_reg_1st,
        loss_rcnn_cls_2nd=loss_rcnn_cls_2nd,
        loss_rcnn_reg_2nd=loss_rcnn_reg_2nd,
        loss_rcnn_cls_3rd=loss_rcnn_cls_3rd,
        loss_rcnn_reg_3rd=loss_rcnn_reg_3rd,
    )
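

# Typical usage (sketch; `cfg` must be a yacs-style config providing the cfg.MODEL.*,
# cfg.SOLVER.*, cfg.INPUT.* and cfg.EVAL_FEATURE keys referenced above; variable names
# below are illustrative only):
#
#   model = COAT(cfg)
#   model.train()
#   losses, feats_2nd, tgts_2nd, feats_3rd, tgts_3rd = model(images, targets)
#   model.eval()
#   detections = model(images)  # gallery mode: boxes, labels, scores, embeddings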