Example-Research-COAT/Code/Python/models/coat.py

# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import boxes as box_ops
from torchvision.models.detection import _utils as det_utils
from loss.oim import OIMLoss
from models.resnet import build_resnet, build_network
from models.transformer import TransformerHead
class COAT(nn.Module):
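"""
Cascade Occluded Attention Transformer (COAT) for person search.
Wraps a ResNet-50 backbone, a torchvision RPN, and a three-stage cascaded
ROI head whose box heads are TransformerHead modules; stages 2 and 3 also
produce norm-aware re-ID embeddings trained with OIM losses.
"""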
def __init__(self, cfg):
super(COAT, self).__init__()
#backbone = build_network(name="pvtv2", pretrained=True)
backbone = build_network(name="resnet50", pretrained=True)
anchor_generator = AnchorGenerator(
sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
)
head = RPNHead(
# the RPN head's input channels equal the backbone's output channels
in_channels=backbone.out_channels,
# num_anchors per location is fixed once sizes and aspect_ratios are set
num_anchors=anchor_generator.num_anchors_per_location()[0],
)
pre_nms_top_n = dict(
training=cfg.MODEL.RPN.PRE_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.PRE_NMS_TOPN_TEST
)
post_nms_top_n = dict(
training=cfg.MODEL.RPN.POST_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.POST_NMS_TOPN_TEST
)
rpn = RegionProposalNetwork(
anchor_generator=anchor_generator,
head=head,
fg_iou_thresh=cfg.MODEL.RPN.POS_THRESH_TRAIN,
bg_iou_thresh=cfg.MODEL.RPN.NEG_THRESH_TRAIN,
batch_size_per_image=cfg.MODEL.RPN.BATCH_SIZE_TRAIN,
positive_fraction=cfg.MODEL.RPN.POS_FRAC_TRAIN,
pre_nms_top_n=pre_nms_top_n,
post_nms_top_n=post_nms_top_n,
nms_thresh=cfg.MODEL.RPN.NMS_THRESH,
)
box_head = TransformerHead(
cfg=cfg,
trans_names=cfg.MODEL.TRANSFORMER.NAMES_1ST,
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_1ST,
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_1ST,
)
box_head_2nd = TransformerHead(
cfg=cfg,
trans_names=cfg.MODEL.TRANSFORMER.NAMES_2ND,
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_2ND,
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_2ND,
)
box_head_3rd = TransformerHead(
cfg=cfg,
trans_names=cfg.MODEL.TRANSFORMER.NAMES_3RD,
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_3RD,
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_3RD,
)
# Faster R-CNN predictor: binary (person vs. background) classification with 2048 input channels;
# 2048 matches the channel dim of the transformer head's "after_trans" output
faster_rcnn_predictor = FastRCNNPredictor(2048, 2)
box_roi_pool = MultiScaleRoIAlign(
featmap_names=["feat_res4"], output_size=14, sampling_ratio=2
)
box_predictor = BBoxRegressor(2048, num_classes=2, bn_neck=cfg.MODEL.ROI_HEAD.BN_NECK)
roi_heads = CascadedROIHeads(
cfg=cfg,
# Cascade Transformer Head
faster_rcnn_predictor=faster_rcnn_predictor,
box_head_2nd=box_head_2nd,
box_head_3rd=box_head_3rd,
# parent class
box_roi_pool=box_roi_pool,
box_head=box_head,
box_predictor=box_predictor,
fg_iou_thresh=cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN,
bg_iou_thresh=cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN,
batch_size_per_image=cfg.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN,
positive_fraction=cfg.MODEL.ROI_HEAD.POS_FRAC_TRAIN,
bbox_reg_weights=None,
score_thresh=cfg.MODEL.ROI_HEAD.SCORE_THRESH_TEST,
nms_thresh=cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST,
detections_per_img=cfg.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST,
)
transform = GeneralizedRCNNTransform(
min_size=cfg.INPUT.MIN_SIZE,
max_size=cfg.INPUT.MAX_SIZE,
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
)
self.backbone = backbone
self.rpn = rpn
self.roi_heads = roi_heads
self.transform = transform
self.eval_feat = cfg.EVAL_FEATURE
# loss weights
self.lw_rpn_reg = cfg.SOLVER.LW_RPN_REG
self.lw_rpn_cls = cfg.SOLVER.LW_RPN_CLS
self.lw_rcnn_reg_1st = cfg.SOLVER.LW_RCNN_REG_1ST
self.lw_rcnn_cls_1st = cfg.SOLVER.LW_RCNN_CLS_1ST
self.lw_rcnn_reg_2nd = cfg.SOLVER.LW_RCNN_REG_2ND
self.lw_rcnn_cls_2nd = cfg.SOLVER.LW_RCNN_CLS_2ND
self.lw_rcnn_reg_3rd = cfg.SOLVER.LW_RCNN_REG_3RD
self.lw_rcnn_cls_3rd = cfg.SOLVER.LW_RCNN_CLS_3RD
self.lw_rcnn_reid_2nd = cfg.SOLVER.LW_RCNN_REID_2ND
self.lw_rcnn_reid_3rd = cfg.SOLVER.LW_RCNN_REID_3RD
def inference(self, images, targets=None, query_img_as_gallery=False):
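"""
Test-time entry point.
With targets given and query_img_as_gallery=False, the GT boxes are treated
as query boxes and their stage-2/stage-3 embeddings are returned (concatenated
or single-stage, depending on self.eval_feat). Otherwise the RPN and the
cascaded ROI heads produce gallery detections with embeddings, mapped back
to the original image sizes.
"""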
original_image_sizes = [img.shape[-2:] for img in images]
images, targets = self.transform(images, targets)
features = self.backbone(images.tensors)
if query_img_as_gallery:
assert targets is not None
if targets is not None and not query_img_as_gallery:
# query
boxes = [t["boxes"] for t in targets]
box_features = self.roi_heads.box_roi_pool(features, boxes, images.image_sizes)
box_features_2nd = self.roi_heads.box_head_2nd(box_features)
embeddings_2nd, _ = self.roi_heads.embedding_head_2nd(box_features_2nd)
box_features_3rd = self.roi_heads.box_head_3rd(box_features)
embeddings_3rd, _ = self.roi_heads.embedding_head_3rd(box_features_3rd)
if self.eval_feat == 'concat':
embeddings = torch.cat((embeddings_2nd, embeddings_3rd), dim=1)
elif self.eval_feat == 'stage2':
embeddings = embeddings_2nd
elif self.eval_feat == 'stage3':
embeddings = embeddings_3rd
else:
raise Exception("Unknown evaluation feature name")
return embeddings.split(1, 0)
else:
# gallery
boxes, _ = self.rpn(images, features, targets)
detections = self.roi_heads(features, boxes, images.image_sizes, targets, query_img_as_gallery)[0]
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
return detections
def forward(self, images, targets=None, query_img_as_gallery=False):
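"""
Training returns a dict of weighted RPN/R-CNN losses together with the
stage-2 and stage-3 re-ID features and labels; in eval mode this simply
delegates to self.inference().
"""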
if not self.training:
return self.inference(images, targets, query_img_as_gallery)
# apply the GeneralizedRCNNTransform to the input images and ground-truth targets
images, targets = self.transform(images, targets)
features = self.backbone(images.tensors)
# debug: print(features["feat_res4"].shape)
# there may be an issue here
boxes, rpn_losses = self.rpn(images, features, targets)
_, rcnn_losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd = self.roi_heads(features, boxes, images.image_sizes, targets)
# rename rpn losses to be consistent with detection losses
rpn_losses["loss_rpn_reg"] = rpn_losses.pop("loss_rpn_box_reg")
rpn_losses["loss_rpn_cls"] = rpn_losses.pop("loss_objectness")
losses = {}
losses.update(rcnn_losses)
losses.update(rpn_losses)
# apply loss weights
losses["loss_rpn_reg"] *= self.lw_rpn_reg
losses["loss_rpn_cls"] *= self.lw_rpn_cls
losses["loss_rcnn_reg_1st"] *= self.lw_rcnn_reg_1st
losses["loss_rcnn_cls_1st"] *= self.lw_rcnn_cls_1st
losses["loss_rcnn_reg_2nd"] *= self.lw_rcnn_reg_2nd
losses["loss_rcnn_cls_2nd"] *= self.lw_rcnn_cls_2nd
losses["loss_rcnn_reg_3rd"] *= self.lw_rcnn_reg_3rd
losses["loss_rcnn_cls_3rd"] *= self.lw_rcnn_cls_3rd
losses["loss_rcnn_reid_2nd"] *= self.lw_rcnn_reid_2nd
losses["loss_rcnn_reid_3rd"] *= self.lw_rcnn_reid_3rd
return losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd
class CascadedROIHeads(RoIHeads):
'''
Three-stage cascaded ROI heads with transformer box heads and norm-aware
re-ID embeddings; extends torchvision's RoIHeads:
https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py
'''
def __init__(
self,
cfg,
faster_rcnn_predictor,
box_head_2nd,
box_head_3rd,
*args,
**kwargs
):
super(CascadedROIHeads, self).__init__(*args, **kwargs)
# ROI head
self.use_diff_thresh = cfg.MODEL.ROI_HEAD.USE_DIFF_THRESH
self.nms_thresh_1st = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST
self.nms_thresh_2nd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND
self.nms_thresh_3rd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD
self.fg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN
self.bg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN
self.fg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND
self.bg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND
self.fg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD
self.bg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD
# Regression head
self.box_predictor_1st = faster_rcnn_predictor
self.box_predictor_2nd = self.box_predictor
self.box_predictor_3rd = deepcopy(self.box_predictor)
# Transformer head
self.box_head_1st = self.box_head
self.box_head_2nd = box_head_2nd
self.box_head_3rd = box_head_3rd
# feature mask
self.use_feature_mask = cfg.MODEL.USE_FEATURE_MASK
self.feature_mask_size = cfg.MODEL.FEATURE_MASK_SIZE
# Feature embedding
embedding_dim = cfg.MODEL.EMBEDDING_DIM
self.embedding_head_2nd = NormAwareEmbedding(featmap_names=["before_trans", "after_trans"], in_channels=[1024, 2048], dim=embedding_dim)
self.embedding_head_3rd = deepcopy(self.embedding_head_2nd)
# OIM
num_pids = cfg.MODEL.LOSS.LUT_SIZE
num_cq_size = cfg.MODEL.LOSS.CQ_SIZE
oim_momentum = cfg.MODEL.LOSS.OIM_MOMENTUM
oim_scalar = cfg.MODEL.LOSS.OIM_SCALAR
self.reid_loss_2nd = OIMLoss(embedding_dim, num_pids, num_cq_size, oim_momentum, oim_scalar)
self.reid_loss_3rd = deepcopy(self.reid_loss_2nd)
# rename the method inherited from parent class
self.postprocess_proposals = self.postprocess_detections
# evaluation
self.eval_feat = cfg.EVAL_FEATURE
def forward(self, features, boxes, image_shapes, targets=None, query_img_as_gallery=False):
"""
Arguments:
features (List[Tensor])
boxes (List[Tensor[N, 4]])
image_shapes (List[Tuple[H, W]])
targets (List[Dict])
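Returns:
result (List[Dict]): per-image detections (boxes, labels, scores, embeddings); empty during training
losses (Dict[str, Tensor]): detection and re-ID losses; empty at test time
feats_reid_2nd / targets_reid_2nd, feats_reid_3rd / targets_reid_3rd:
re-ID features and labels from the OIM losses (None at test time)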
"""
cws = True
gt_det_2nd = None
gt_det_3rd = None
feats_reid_2nd = None
feats_reid_3rd = None
targets_reid_2nd = None
targets_reid_3rd = None
if self.training:
if self.use_diff_thresh:
self.proposal_matcher = det_utils.Matcher(
self.fg_iou_thresh_1st,
self.bg_iou_thresh_1st,
allow_low_quality_matches=False)
boxes, _, box_pid_labels_1st, box_reg_targets_1st = self.select_training_samples(
boxes, targets
)
# ------------------- The first stage ------------------ #
# print(features["feat_res4"].shape)
# torch.Size([2, 1024, 58, 94])
# torch.Size([2, 1024, 54, 94])
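# Stage 1: pool features for the sampled proposals, refine them with the
# first transformer head, and classify/regress with the Faster R-CNN predictor.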
box_features_1st = self.box_roi_pool(features, boxes, image_shapes)
box_features_1st = self.box_head_1st(box_features_1st)
box_cls_scores_1st, box_regs_1st = self.box_predictor_1st(box_features_1st["after_trans"])
if self.training:
boxes = self.get_boxes(box_regs_1st, boxes, image_shapes)
boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
if self.use_diff_thresh:
self.proposal_matcher = det_utils.Matcher(
self.fg_iou_thresh_2nd,
self.bg_iou_thresh_2nd,
allow_low_quality_matches=False)
boxes, _, box_pid_labels_2nd, box_reg_targets_2nd = self.select_training_samples(boxes, targets)
else:
orig_thresh = self.nms_thresh # 0.4
self.nms_thresh = self.nms_thresh_1st
boxes, scores, _ = self.postprocess_proposals(
box_cls_scores_1st, box_regs_1st, boxes, image_shapes
)
if not self.training and query_img_as_gallery:
# When regarding the query image as gallery, GT boxes may be excluded
# from detected boxes. To avoid this, we compulsorily include GT in the
# detection results. Additionally, CWS should be disabled as the
# confidences of these people in query image are 1
cws = False
gt_box = [targets[0]["boxes"]]
gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
gt_box_features = self.box_head_2nd(gt_box_features)
embeddings, _ = self.embedding_head_2nd(gt_box_features)
gt_det_2nd = {"boxes": targets[0]["boxes"], "embeddings": embeddings}
# no detection predicted by Faster R-CNN head in test phase
if boxes[0].shape[0] == 0:
assert not self.training
boxes = gt_det_2nd["boxes"] if gt_det_2nd else torch.zeros(0, 4)
labels = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0)
scores = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0)
if self.eval_feat == 'concat':
embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_2nd["embeddings"]), dim=1) if gt_det_2nd else torch.zeros(0, 512)
elif self.eval_feat == 'stage2' or self.eval_feat == 'stage3':
embeddings = gt_det_2nd["embeddings"] if gt_det_2nd else torch.zeros(0, 256)
else:
raise Exception("Unknown evaluation feature name")
return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []
# --------------------- The second stage -------------------- #
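# Stage 2: re-pool features for the stage-1 boxes, run the second transformer
# head, and produce regression deltas plus norm-aware embeddings whose norms
# act as the stage-2 classification scores.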
box_features = self.box_roi_pool(features, boxes, image_shapes)
box_features = self.box_head_2nd(box_features)
box_regs_2nd = self.box_predictor_2nd(box_features["after_trans"])
box_embeddings_2nd, box_cls_scores_2nd = self.embedding_head_2nd(box_features)
if box_cls_scores_2nd.dim() == 0:
box_cls_scores_2nd = box_cls_scores_2nd.unsqueeze(0)
if self.training:
boxes = self.get_boxes(box_regs_2nd, boxes, image_shapes)
boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
if self.use_diff_thresh:
self.proposal_matcher = det_utils.Matcher(
self.fg_iou_thresh_3rd,
self.bg_iou_thresh_3rd,
allow_low_quality_matches=False)
boxes, _, box_pid_labels_3rd, box_reg_targets_3rd = self.select_training_samples(boxes, targets)
else:
self.nms_thresh = self.nms_thresh_2nd
if self.eval_feat != 'stage2':
boxes, scores, _, _ = self.postprocess_boxes(
box_cls_scores_2nd,
box_regs_2nd,
box_embeddings_2nd,
boxes,
image_shapes,
fcs=scores,
gt_det=None,
cws=cws,
)
if not self.training and query_img_as_gallery and self.eval_feat != 'stage2':
cws = False
gt_box = [targets[0]["boxes"]]
gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
gt_box_features = self.box_head_3rd(gt_box_features)
embeddings, _ = self.embedding_head_3rd(gt_box_features)
gt_det_3rd = {"boxes": targets[0]["boxes"], "embeddings": embeddings}
# no detection predicted by Faster R-CNN head in test phase
if boxes[0].shape[0] == 0 and self.eval_feat != 'stage2':
assert not self.training
boxes = gt_det_3rd["boxes"] if gt_det_3rd else torch.zeros(0, 4)
labels = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0)
scores = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0)
if self.eval_feat == 'concat':
embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_3rd["embeddings"]), dim=1) if gt_det_3rd else torch.zeros(0, 512)
elif self.eval_feat == 'stage3':
embeddings = gt_det_3rd["embeddings"] if gt_det_3rd else torch.zeros(0, 256)
else:
raise Exception("Unknown evaluation feature name")
return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []
# --------------------- The third stage -------------------- #
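# Stage 3: re-pool features for the stage-2 boxes; at test time the same
# pooled features are also passed through the stage-2 head so both embeddings
# are available for the 'concat' evaluation feature.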
box_features = self.box_roi_pool(features, boxes, image_shapes)
if not self.training:
box_features_2nd = self.box_head_2nd(box_features)
box_embeddings_2nd, _ = self.embedding_head_2nd(box_features_2nd)
box_features = self.box_head_3rd(box_features)
box_regs_3rd = self.box_predictor_3rd(box_features["after_trans"])
box_embeddings_3rd, box_cls_scores_3rd = self.embedding_head_3rd(box_features)
if box_cls_scores_3rd.dim() == 0:
box_cls_scores_3rd = box_cls_scores_3rd.unsqueeze(0)
result, losses = [], {}
if self.training:
box_labels_1st = [y.clamp(0, 1) for y in box_pid_labels_1st]
box_labels_2nd = [y.clamp(0, 1) for y in box_pid_labels_2nd]
box_labels_3rd = [y.clamp(0, 1) for y in box_pid_labels_3rd]
losses = detection_losses(
box_cls_scores_1st,
box_regs_1st,
box_labels_1st,
box_reg_targets_1st,
box_cls_scores_2nd,
box_regs_2nd,
box_labels_2nd,
box_reg_targets_2nd,
box_cls_scores_3rd,
box_regs_3rd,
box_labels_3rd,
box_reg_targets_3rd,
)
loss_rcnn_reid_2nd, feats_reid_2nd, targets_reid_2nd = self.reid_loss_2nd(box_embeddings_2nd, box_pid_labels_2nd)
loss_rcnn_reid_3rd, feats_reid_3rd, targets_reid_3rd = self.reid_loss_3rd(box_embeddings_3rd, box_pid_labels_3rd)
losses.update(loss_rcnn_reid_2nd=loss_rcnn_reid_2nd)
losses.update(loss_rcnn_reid_3rd=loss_rcnn_reid_3rd)
else:
if self.eval_feat == 'stage2':
boxes, scores, embeddings_2nd, labels = self.postprocess_boxes(
box_cls_scores_2nd,
box_regs_2nd,
box_embeddings_2nd,
boxes,
image_shapes,
fcs=scores,
gt_det=gt_det_2nd,
cws=cws,
)
else:
self.nms_thresh = self.nms_thresh_3rd
_, _, embeddings_2nd, _ = self.postprocess_boxes(
box_cls_scores_3rd,
box_regs_3rd,
box_embeddings_2nd,
boxes,
image_shapes,
fcs=scores,
gt_det=gt_det_2nd,
cws=cws,
)
boxes, scores, embeddings_3rd, labels = self.postprocess_boxes(
box_cls_scores_3rd,
box_regs_3rd,
box_embeddings_3rd,
boxes,
image_shapes,
fcs=scores,
gt_det=gt_det_3rd,
cws=cws,
)
# set to original thresh after finishing postprocess
self.nms_thresh = orig_thresh
num_images = len(boxes)
for i in range(num_images):
if self.eval_feat == 'concat':
embeddings = torch.cat((embeddings_2nd[i], embeddings_3rd[i]), dim=1)
elif self.eval_feat == 'stage2':
embeddings = embeddings_2nd[i]
elif self.eval_feat == 'stage3':
embeddings = embeddings_3rd[i]
else:
raise Exception("Unknown evaluation feature name")
result.append(
dict(
boxes=boxes[i],
labels=labels[i],
scores=scores[i],
embeddings=embeddings
)
)
return result, losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd
def get_boxes(self, box_regression, proposals, image_shapes):
"""
Decode the regression deltas w.r.t. the proposals, clip the boxes to the
image, and drop the background-class predictions.
"""
boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
pred_boxes = self.box_coder.decode(box_regression, proposals)
pred_boxes = pred_boxes.split(boxes_per_image, 0)
all_boxes = []
for boxes, image_shape in zip(pred_boxes, image_shapes):
boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
# remove predictions with the background label
boxes = boxes[:, 1:].reshape(-1, 4)
all_boxes.append(boxes)
return all_boxes
def postprocess_boxes(
self,
class_logits,
box_regression,
embeddings,
proposals,
image_shapes,
fcs=None,
gt_det=None,
cws=True,
):
"""
Similar to RoIHeads.postprocess_detections, but can handle embeddings and implement
First Classification Score (FCS).
"""
device = class_logits.device
boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
pred_boxes = self.box_coder.decode(box_regression, proposals)
if fcs is not None:
# First Classification Score (FCS)
pred_scores = fcs[0]
else:
pred_scores = torch.sigmoid(class_logits)
if cws:
# Confidence Weighted Similarity (CWS)
embeddings = embeddings * pred_scores.view(-1, 1)
# split boxes and scores per image
pred_boxes = pred_boxes.split(boxes_per_image, 0)
pred_scores = pred_scores.split(boxes_per_image, 0)
pred_embeddings = embeddings.split(boxes_per_image, 0)
all_boxes = []
all_scores = []
all_labels = []
all_embeddings = []
for boxes, scores, embeddings, image_shape in zip(
pred_boxes, pred_scores, pred_embeddings, image_shapes
):
boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
# create labels for each prediction
labels = torch.ones(scores.size(0), device=device)
# remove predictions with the background label
boxes = boxes[:, 1:]
scores = scores.unsqueeze(1)
labels = labels.unsqueeze(1)
# batch everything, by making every class prediction be a separate instance
boxes = boxes.reshape(-1, 4)
scores = scores.flatten()
labels = labels.flatten()
embeddings = embeddings.reshape(-1, self.embedding_head_2nd.dim)
# remove low scoring boxes
inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
boxes, scores, labels, embeddings = (
boxes[inds],
scores[inds],
labels[inds],
embeddings[inds],
)
# remove empty boxes
keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
boxes, scores, labels, embeddings = (
boxes[keep],
scores[keep],
labels[keep],
embeddings[keep],
)
if gt_det is not None:
# include GT into the detection results
boxes = torch.cat((boxes, gt_det["boxes"]), dim=0)
labels = torch.cat((labels, torch.tensor([1.0]).to(device)), dim=0)
scores = torch.cat((scores, torch.tensor([1.0]).to(device)), dim=0)
embeddings = torch.cat((embeddings, gt_det["embeddings"]), dim=0)
# non-maximum suppression, independently done per class
keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
# keep only topk scoring predictions
keep = keep[: self.detections_per_img]
boxes, scores, labels, embeddings = (
boxes[keep],
scores[keep],
labels[keep],
embeddings[keep],
)
all_boxes.append(boxes)
all_scores.append(scores)
all_labels.append(labels)
all_embeddings.append(embeddings)
return all_boxes, all_scores, all_embeddings, all_labels
class NormAwareEmbedding(nn.Module):
"""
Implements the Norm-Aware Embedding proposed in
Chen, Di, et al. "Norm-aware embedding for efficient person search." CVPR 2020.
"""
def __init__(self, featmap_names=["feat_res4", "feat_res5"], in_channels=[1024, 2048], dim=256):
super(NormAwareEmbedding, self).__init__()
self.featmap_names = featmap_names
self.in_channels = in_channels
self.dim = dim
self.projectors = nn.ModuleDict()
indv_dims = self._split_embedding_dim()
for ftname, in_channel, indv_dim in zip(self.featmap_names, self.in_channels, indv_dims):
proj = nn.Sequential(nn.Linear(in_channel, indv_dim), nn.BatchNorm1d(indv_dim))
init.normal_(proj[0].weight, std=0.01)
init.normal_(proj[1].weight, std=0.01)
init.constant_(proj[0].bias, 0)
init.constant_(proj[1].bias, 0)
self.projectors[ftname] = proj
self.rescaler = nn.BatchNorm1d(1, affine=True)
def forward(self, featmaps):
"""
Arguments:
featmaps: OrderedDict[Tensor], and in featmap_names you can choose which
featmaps to use
Returns:
tensor of size (BatchSize, dim), L2 normalized embeddings.
tensor of size (BatchSize, ) rescaled norm of embeddings, as class_logits.
"""
assert len(featmaps) == len(self.featmap_names)
if len(featmaps) == 1:
k, v = list(featmaps.items())[0]
v = self._flatten_fc_input(v)
embeddings = self.projectors[k](v)
norms = embeddings.norm(2, 1, keepdim=True)
embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
norms = self.rescaler(norms).squeeze()
return embeddings, norms
else:
outputs = []
for k, v in featmaps.items():
v = self._flatten_fc_input(v)
outputs.append(self.projectors[k](v))
embeddings = torch.cat(outputs, dim=1)
norms = embeddings.norm(2, 1, keepdim=True)
embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
norms = self.rescaler(norms).squeeze()
return embeddings, norms
def _flatten_fc_input(self, x):
if x.ndimension() == 4:
assert list(x.shape[2:]) == [1, 1]
return x.flatten(start_dim=1)
return x
def _split_embedding_dim(self):
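"""
Split self.dim as evenly as possible across the input feature maps; any
remainder is added to the last projectors (e.g. dim=256 with two feature
maps gives [128, 128]).
"""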
parts = len(self.in_channels)
tmp = [self.dim // parts] * parts
if sum(tmp) == self.dim:
return tmp
else:
res = self.dim % parts
for i in range(1, res + 1):
tmp[-i] += 1
assert sum(tmp) == self.dim
return tmp
class BBoxRegressor(nn.Module):
"""
Bounding box regression layer.
"""
def __init__(self, in_channels, num_classes=2, bn_neck=True):
"""
Args:
in_channels (int): Input channels.
num_classes (int, optional): Defaults to 2 (background and pedestrian).
bn_neck (bool, optional): Whether to use BN after Linear. Defaults to True.
"""
super(BBoxRegressor, self).__init__()
if bn_neck:
self.bbox_pred = nn.Sequential(
nn.Linear(in_channels, 4 * num_classes), nn.BatchNorm1d(4 * num_classes)
)
init.normal_(self.bbox_pred[0].weight, std=0.01)
init.normal_(self.bbox_pred[1].weight, std=0.01)
init.constant_(self.bbox_pred[0].bias, 0)
init.constant_(self.bbox_pred[1].bias, 0)
else:
self.bbox_pred = nn.Linear(in_channels, 4 * num_classes)
init.normal_(self.bbox_pred.weight, std=0.01)
init.constant_(self.bbox_pred.bias, 0)
def forward(self, x):
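# collapse 4D inputs to (N, C) via global average pooling before the linear layer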
if x.ndimension() == 4:
if list(x.shape[2:]) != [1, 1]:
x = F.adaptive_avg_pool2d(x, output_size=1)
x = x.flatten(start_dim=1)
bbox_deltas = self.bbox_pred(x)
return bbox_deltas
def detection_losses(
box_cls_scores_1st,
box_regs_1st,
box_labels_1st,
box_reg_targets_1st,
box_cls_scores_2nd,
box_regs_2nd,
box_labels_2nd,
box_reg_targets_2nd,
box_cls_scores_3rd,
box_regs_3rd,
box_labels_3rd,
box_reg_targets_3rd,
):
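"""
Detection losses for the three cascade stages. Stage 1 uses softmax
cross-entropy over its two-class scores; stages 2 and 3 use binary
cross-entropy on the norm-aware scores. Box regression uses smooth L1 over
positive samples, normalized by the total number of sampled boxes.
"""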
# --------------------- The first stage -------------------- #
box_labels_1st = torch.cat(box_labels_1st, dim=0)
box_reg_targets_1st = torch.cat(box_reg_targets_1st, dim=0)
loss_rcnn_cls_1st = F.cross_entropy(box_cls_scores_1st, box_labels_1st)
# get indices that correspond to the regression targets for the
# corresponding ground truth labels, to be used with advanced indexing
sampled_pos_inds_subset = torch.nonzero(box_labels_1st > 0).squeeze(1)
labels_pos = box_labels_1st[sampled_pos_inds_subset]
N = box_cls_scores_1st.size(0)
box_regs_1st = box_regs_1st.reshape(N, -1, 4)
loss_rcnn_reg_1st = F.smooth_l1_loss(
box_regs_1st[sampled_pos_inds_subset, labels_pos],
box_reg_targets_1st[sampled_pos_inds_subset],
reduction="sum",
)
loss_rcnn_reg_1st = loss_rcnn_reg_1st / box_labels_1st.numel()
# --------------------- The second stage -------------------- #
box_labels_2nd = torch.cat(box_labels_2nd, dim=0)
box_reg_targets_2nd = torch.cat(box_reg_targets_2nd, dim=0)
loss_rcnn_cls_2nd = F.binary_cross_entropy_with_logits(box_cls_scores_2nd, box_labels_2nd.float())
sampled_pos_inds_subset = torch.nonzero(box_labels_2nd > 0).squeeze(1)
labels_pos = box_labels_2nd[sampled_pos_inds_subset]
N = box_cls_scores_2nd.size(0)
box_regs_2nd = box_regs_2nd.reshape(N, -1, 4)
loss_rcnn_reg_2nd = F.smooth_l1_loss(
box_regs_2nd[sampled_pos_inds_subset, labels_pos],
box_reg_targets_2nd[sampled_pos_inds_subset],
reduction="sum",
)
loss_rcnn_reg_2nd = loss_rcnn_reg_2nd / box_labels_2nd.numel()
# --------------------- The third stage -------------------- #
box_labels_3rd = torch.cat(box_labels_3rd, dim=0)
box_reg_targets_3rd = torch.cat(box_reg_targets_3rd, dim=0)
loss_rcnn_cls_3rd = F.binary_cross_entropy_with_logits(box_cls_scores_3rd, box_labels_3rd.float())
sampled_pos_inds_subset = torch.nonzero(box_labels_3rd > 0).squeeze(1)
labels_pos = box_labels_3rd[sampled_pos_inds_subset]
N = box_cls_scores_3rd.size(0)
box_regs_3rd = box_regs_3rd.reshape(N, -1, 4)
loss_rcnn_reg_3rd = F.smooth_l1_loss(
box_regs_3rd[sampled_pos_inds_subset, labels_pos],
box_reg_targets_3rd[sampled_pos_inds_subset],
reduction="sum",
)
loss_rcnn_reg_3rd = loss_rcnn_reg_3rd / box_labels_3rd.numel()
return dict(
loss_rcnn_cls_1st=loss_rcnn_cls_1st,
loss_rcnn_reg_1st=loss_rcnn_reg_1st,
loss_rcnn_cls_2nd=loss_rcnn_cls_2nd,
loss_rcnn_reg_2nd=loss_rcnn_reg_2nd,
loss_rcnn_cls_3rd=loss_rcnn_cls_3rd,
loss_rcnn_reg_3rd=loss_rcnn_reg_3rd,
)
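
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes a YACS-style cfg exposing the fields referenced above
# (cfg.MODEL.RPN.*, cfg.MODEL.ROI_HEAD.*, cfg.SOLVER.LW_*, cfg.EVAL_FEATURE);
# `get_default_cfg` below is a hypothetical loader name, not a confirmed API.
#
#   from defaults import get_default_cfg  # hypothetical config helper
#
#   cfg = get_default_cfg()
#   model = COAT(cfg).cuda()
#
#   # training: images is a list of CHW float tensors, targets a list of dicts
#   # with "boxes" (N x 4) and person-id "labels" as expected by OIMLoss
#   losses, feats_2nd, tgts_2nd, feats_3rd, tgts_3rd = model(images, targets)
#   sum(losses.values()).backward()
#
#   # gallery inference: a list of dicts with boxes, labels, scores, embeddings
#   model.eval()
#   with torch.no_grad():
#       detections = model(gallery_images)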