更新pvtv2-b2的分支
parent
bbe2e5694b
commit
b44c1bfb16
|
@ -0,0 +1,3 @@
|
||||||
|
**/__pycache__
|
||||||
|
*.pth
|
||||||
|
**/logs
|
|
@ -0,0 +1,26 @@
|
||||||
|
name: coat
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- conda-forge
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- cudatoolkit=11.0
|
||||||
|
- numpy=1.19.2
|
||||||
|
- pillow=8.2.0
|
||||||
|
- pip=21.0.1
|
||||||
|
- python=3.8.8
|
||||||
|
- pytorch=1.7.1
|
||||||
|
- scipy=1.6.2
|
||||||
|
- torchvision=0.8.2
|
||||||
|
- tqdm=4.60.0
|
||||||
|
- scikit-learn=0.24.1
|
||||||
|
- black=21.5b0
|
||||||
|
- flake8=3.9.0
|
||||||
|
- isort=5.8.0
|
||||||
|
- tabulate=0.8.9
|
||||||
|
- future=0.18.2
|
||||||
|
- tensorboard=2.4.1
|
||||||
|
- tensorboardx=2.2
|
||||||
|
- pip:
|
||||||
|
- ipython==7.5.0
|
||||||
|
- yacs==0.1.8
|
|
@ -0,0 +1,150 @@
|
||||||
|
# **COAT代码使用说明**
|
||||||
|
|
||||||
|
这个存储库托管了论文的源代码:[[CVPR 2022] Cascade Transformers for End-to-End Person Search](https://arxiv.org/abs/2203.09642)。在这项工作中,我们开发了一种新颖的级联遮挡感知Transformer(COAT)模型,用于端到端的人物搜索。COAT模型在PRW基准数据集上以显著的优势胜过了最先进的方法,并在CUHK-SYSU数据集上取得了最先进的性能。
|
||||||
|
|
||||||
|
| 数据集(Datasets) | mAP | Top-1 | Model |
|
||||||
|
| ---------------- | ---- | ----- | ------------------------------------------------------------ |
|
||||||
|
| CUHK-SYSU | 94.2 | 94.7 | [model](https://drive.google.com/file/d/1LkEwXYaJg93yk4Kfhyk3m6j8v3i9s1B7/view?usp=sharing) |
|
||||||
|
| PRW | 53.3 | 87.4 | [model](https://drive.google.com/file/d/1vEd_zzFN88RgxbRMG5-WfJZgD3vmP0Xg/view?usp=sharing) |
|
||||||
|
|
||||||
|
**Abstract**: The goal of person search is to localize a target person from a gallery set of scene images, which is extremely challenging due to large scale variations, pose/viewpoint changes, and occlusions. In this paper, we propose the Cascade Occluded Attention Transformer (COAT) for end-to-end person search. Specifically, our three-stage cascade design focuses on detecting people at the first stage, then progressively refines the representation for person detection and re-identification simultaneously at the following stages. The occluded attention transformer at each stage applies tighter intersection over union thresholds, forcing the network to learn coarse-to-fine pose/scale invariant features. Meanwhile, we calculate the occluded attention across instances in a mini-batch to differentiate tokens from other people or the background. In this way, we simulate the effect of other objects occluding a person of interest at the token-level. Through comprehensive experiments, we demonstrate the benefits of our method by achieving state-of-the-art performance on two benchmark datasets.
|
||||||
|
|
||||||
|
![COAT](doc/framework.png)
|
||||||
|
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
1. Download the datasets in your path `$DATA_DIR`. Change the dataset paths in L4 in [cuhk_sysu.yaml](configs/cuhk_sysu.yaml) and [prw.yaml](configs/prw.yaml).
|
||||||
|
|
||||||
|
**PRW**:
|
||||||
|
|
||||||
|
```
|
||||||
|
cd $DATA_DIR
|
||||||
|
pip install gdown
|
||||||
|
gdown https://drive.google.com/uc?id=0B6tjyrV1YrHeYnlhNnhEYTh5MUU
|
||||||
|
unzip PRW-v16.04.20.zip
|
||||||
|
mv PRW-v16.04.20 PRW
|
||||||
|
```
|
||||||
|
|
||||||
|
**CUHK-SYSU**:
|
||||||
|
|
||||||
|
```
|
||||||
|
cd $DATA_DIR
|
||||||
|
gdown https://drive.google.com/uc?id=1z3LsFrJTUeEX3-XjSEJMOBrslxD2T5af
|
||||||
|
tar -xzvf cuhk_sysu.tar.gz
|
||||||
|
mv cuhk_sysu CUHK-SYSU
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Our method is tested with PyTorch 1.7.1. You can install the required packages by anaconda/miniconda with the following commands:
|
||||||
|
|
||||||
|
```
|
||||||
|
cd COAT
|
||||||
|
conda env create -f COAT_pt171.yml
|
||||||
|
conda activate coat
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to install another version of PyTorch, you can modify the versions in `coat_pt171.yml`. Just make sure the dependencies have the appropriate version.
|
||||||
|
|
||||||
|
|
||||||
|
## CUHK-SYSU数据集实验
|
||||||
|
**训练**: 目前代码只支持单GPU. The default training script for CUHK-SYSU is as follows:
|
||||||
|
|
||||||
|
**在本地GTX4090训练**:
|
||||||
|
|
||||||
|
``` bash
|
||||||
|
cd COAT
|
||||||
|
# 说明:4090显存较小,所以batchsize只能设置为2, 实测可以运行
|
||||||
|
python train.py --cfg configs/cuhk_sysu-local.yaml INPUT.BATCH_SIZE_TRAIN 2 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 14 SOLVER.LR_DECAY_MILESTONES [11] MODEL.LOSS.USE_SOFTMAX True SOLVER.LW_RCNN_SOFTMAX_2ND 0.1 SOLVER.LW_RCNN_SOFTMAX_3RD 0.1 OUTPUT_DIR ./logs/cuhk-sysu
|
||||||
|
```
|
||||||
|
|
||||||
|
**在本地UESTC训练**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd COAT
|
||||||
|
# 说明:RTX8000显存48G,所以batchsize只能设置为3
|
||||||
|
python train.py --cfg configs/cuhk_sysu.yaml INPUT.BATCH_SIZE_TRAIN 2 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 14 SOLVER.LR_DECAY_MILESTONES [11] MODEL.LOSS.USE_SOFTMAX True SOLVER.LW_RCNN_SOFTMAX_2ND 0.1 SOLVER.LW_RCNN_SOFTMAX_3RD 0.1 OUTPUT_DIR ./logs/cuhk-sysu
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that the dataset-specific parameters are defined in `configs/cuhk_sysu.yaml`. When the batch size (`INPUT.BATCH_SIZE_TRAIN`) is 3, the training will take about 23GB GPU memory, being suitable for GPUs like RTX6000. When the batch size is 5, the training will take about 38GB GPU memory, being able to run on A100 GPU. The larger batch size usually results in better performance on CUHK-SYSU.
|
||||||
|
|
||||||
|
For the CUHK-SYSU dataset, we use a relative low weight for softmax loss (`SOLVER.LW_RCNN_SOFTMAX_2ND` 0.1 and `SOLVER.LW_RCNN_SOFTMAX_3RD` 0.1). The trained models and TF logs will be saved in the folder `OUTPUT_DIR`. Other important training parameters can be found in the file `COAT/defaults.py`. For example, `CKPT_PERIOD` is the frequency of saving a checkpoint model.
|
||||||
|
|
||||||
|
**Testing**: The test script is very simple. You just need to add the flag `--eval` and provide the folder `--ckpt` where the [model](https://drive.google.com/file/d/1LkEwXYaJg93yk4Kfhyk3m6j8v3i9s1B7/view?usp=sharing) was saved.
|
||||||
|
|
||||||
|
测试:这个测试脚本非常简单,你只需要添加flag --eval以及对应提供--ckpt当模型已经保存的时候
|
||||||
|
|
||||||
|
```
|
||||||
|
python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth
|
||||||
|
```
|
||||||
|
|
||||||
|
**Testing with CBGM**: Context Bipartite Graph Matching ([CBGM](https://github.com/serend1p1ty/SeqNet)) is an optimized matching algorithm in test phase. The detail can be found in the paper [[AAAI 2021] Sequential End-to-end Network for Efficient Person Search](https://arxiv.org/abs/2103.10148). We can use CBGM to further improve the person search accuracy. In test script, we just set the flag `EVAL_USE_CBGM` to True (default is False).
|
||||||
|
|
||||||
|
```
|
||||||
|
python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth EVAL_USE_CB GM True
|
||||||
|
```
|
||||||
|
|
||||||
|
**Testing with different gallery sizes on CUHK-SYSU**: The default gallery size for evaluating CUHK-SYSU is 100. If you want to test with other pre-defined gallery sizes (50, 100, 500, 1000, 2000, 4000) for drawing the CUHK-SYSU gallery size curve, please set the parameter `EVAL_GALLERY_SIZE` with a gallery size.
|
||||||
|
|
||||||
|
```
|
||||||
|
python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth EVAL_GALLER Y_SIZE 500
|
||||||
|
```
|
||||||
|
|
||||||
|
## Experiments on PRW
|
||||||
|
**Training**: The script is similar to CUHK-SYSU. The code currently only supports single GPU. The default training script for PRW is as follows:
|
||||||
|
|
||||||
|
**在本地GTX4090训练**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd COAT
|
||||||
|
# PRW数据集较小,可以RTX4090的bs可以设置为3
|
||||||
|
python train.py --cfg ./configs/prw-local.yaml INPUT.BATCH_SIZE_TRAIN 3 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 13 MODEL.LOSS.USE_SOFTMAX True OUTPUT_DIR ./logs/prw
|
||||||
|
```
|
||||||
|
|
||||||
|
**在本地UESTC训练**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd COAT
|
||||||
|
# PRW数据集较小,可以RTX4090的bs可以设置为3
|
||||||
|
python train.py --cfg ./configs/prw.yaml INPUT.BATCH_SIZE_TRAIN 3 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 13 MODEL.LOSS.USE_SOFTMAX True OUTPUT_DIR ./logs/prw
|
||||||
|
```
|
||||||
|
|
||||||
|
The dataset-specific parameters are defined in `configs/prw.yaml`. When the batch size (`INPUT.BATCH_SIZE_TRAIN`) is 3, the training will take about 19GB GPU memory, being suitable for GPUs like RTX6000. The larger batch size does not necessarily result in better accuracy on the PRW dataset.
|
||||||
|
Softmax loss is effective on PRW. The default weights of softmax loss at Stage 2 and Stage 3 (`SOLVER.LW_RCNN_SOFTMAX_2ND` and `SOLVER.LW_RCNN_SOFTMAX_3RD`) are 0.5, which can be found in the file `COAT/defaults.py`. If you want to run a model without Softmax loss for comparison, just set `MODEL.LOSS.USE_SOFTMAX` to False in the script.
|
||||||
|
|
||||||
|
|
||||||
|
**Testing**: The test script is similar to CUHK-SYSU. Make sure the path of pre-trained model [model](https://drive.google.com/file/d/1vEd_zzFN88RgxbRMG5-WfJZgD3vmP0Xg/view?usp=sharing) is correct.
|
||||||
|
|
||||||
|
```
|
||||||
|
python train.py --cfg ./logs/prw/config.yaml --eval --ckpt ./logs/prw/prw_COAT.pth
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
**Testing with CBGM**: Similar to CUHK-SYSU, set the flag `EVAL_USE_CBGM` to True (default is False).
|
||||||
|
|
||||||
|
```
|
||||||
|
python train.py --cfg ./logs/prw/config.yaml --eval --ckpt ./logs/prw/prw_COAT.pth EVAL_USE_CBGM True
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Acknowledgement
|
||||||
|
This code borrows from [SeqNet](https://github.com/serend1p1ty/SeqNet), [TransReID](https://github.com/damo-cv/TransReID), and [DSTT](https://github.com/ruiliu-ai/DSTT).
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
If you use this code in your research, please cite this project as follows:
|
||||||
|
|
||||||
|
```
|
||||||
|
@inproceedings{yu2022coat,
|
||||||
|
title = {Cascade Transformers for End-to-End Person Search},
|
||||||
|
author = {Rui Yu and
|
||||||
|
Dawei Du and
|
||||||
|
Rodney LaLonde and
|
||||||
|
Daniel Davila and
|
||||||
|
Christopher Funk and
|
||||||
|
Anthony Hoogs and
|
||||||
|
Brian Clipp},
|
||||||
|
booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition},
|
||||||
|
year = {2022}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
This work is distributed under the OSI-approved BSD 3-Clause [License](https://github.com/Kitware/COAT/blob/master/LICENSE).
|
|
@ -0,0 +1,15 @@
|
||||||
|
OUTPUT_DIR: "./logs/cuhk_coat"
|
||||||
|
INPUT:
|
||||||
|
DATASET: "CUHK-SYSU"
|
||||||
|
DATA_ROOT: "E:/DeepLearning/PersonSearch/COAT/datasets/CUHK-SYSU"
|
||||||
|
BATCH_SIZE_TRAIN: 3
|
||||||
|
SOLVER:
|
||||||
|
MAX_EPOCHS: 14
|
||||||
|
BASE_LR: 0.003
|
||||||
|
LW_RCNN_SOFTMAX_2ND: 0.1
|
||||||
|
LW_RCNN_SOFTMAX_3RD: 0.1
|
||||||
|
MODEL:
|
||||||
|
LOSS:
|
||||||
|
LUT_SIZE: 5532
|
||||||
|
CQ_SIZE: 5000
|
||||||
|
DISP_PERIOD: 100
|
|
@ -0,0 +1,15 @@
|
||||||
|
OUTPUT_DIR: "./logs/cuhk_coat"
|
||||||
|
INPUT:
|
||||||
|
DATASET: "CUHK-SYSU"
|
||||||
|
DATA_ROOT: "/home/logzhan/datasets/CUHK-SYSU"
|
||||||
|
BATCH_SIZE_TRAIN: 4
|
||||||
|
SOLVER:
|
||||||
|
MAX_EPOCHS: 14
|
||||||
|
BASE_LR: 0.003
|
||||||
|
LW_RCNN_SOFTMAX_2ND: 0.1
|
||||||
|
LW_RCNN_SOFTMAX_3RD: 0.1
|
||||||
|
MODEL:
|
||||||
|
LOSS:
|
||||||
|
LUT_SIZE: 5532
|
||||||
|
CQ_SIZE: 5000
|
||||||
|
DISP_PERIOD: 100
|
|
@ -0,0 +1,13 @@
|
||||||
|
OUTPUT_DIR: "./logs/prw_coat"
|
||||||
|
INPUT:
|
||||||
|
DATASET: "PRW"
|
||||||
|
DATA_ROOT: "E:/DeepLearning/PersonSearch/COAT/datasets/PRW"
|
||||||
|
BATCH_SIZE_TRAIN: 3
|
||||||
|
SOLVER:
|
||||||
|
MAX_EPOCHS: 13
|
||||||
|
BASE_LR: 0.003
|
||||||
|
MODEL:
|
||||||
|
LOSS:
|
||||||
|
LUT_SIZE: 482
|
||||||
|
CQ_SIZE: 500
|
||||||
|
DISP_PERIOD: 100
|
|
@ -0,0 +1,13 @@
|
||||||
|
OUTPUT_DIR: "./logs/prw_coat"
|
||||||
|
INPUT:
|
||||||
|
DATASET: "PRW"
|
||||||
|
DATA_ROOT: "../../datasets/PRW"
|
||||||
|
BATCH_SIZE_TRAIN: 3
|
||||||
|
SOLVER:
|
||||||
|
MAX_EPOCHS: 13
|
||||||
|
BASE_LR: 0.003
|
||||||
|
MODEL:
|
||||||
|
LOSS:
|
||||||
|
LUT_SIZE: 482
|
||||||
|
CQ_SIZE: 500
|
||||||
|
DISP_PERIOD: 100
|
|
@ -0,0 +1,5 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
from .build import build_test_loader, build_train_loader
|
|
@ -0,0 +1,42 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
class BaseDataset:
|
||||||
|
"""
|
||||||
|
Base class of person search dataset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, root, transforms, split):
|
||||||
|
self.root = root
|
||||||
|
self.transforms = transforms
|
||||||
|
self.split = split
|
||||||
|
assert self.split in ("train", "gallery", "query")
|
||||||
|
self.annotations = self._load_annotations()
|
||||||
|
|
||||||
|
def _load_annotations(self):
|
||||||
|
"""
|
||||||
|
For each image, load its annotation that is a dictionary with the following keys:
|
||||||
|
img_name (str): image name
|
||||||
|
img_path (str): image path
|
||||||
|
boxes (np.array[N, 4]): ground-truth boxes in (x1, y1, x2, y2) format
|
||||||
|
pids (np.array[N]): person IDs corresponding to these boxes
|
||||||
|
cam_id (int): camera ID (only for PRW dataset)
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
anno = self.annotations[index]
|
||||||
|
img = Image.open(anno["img_path"]).convert("RGB")
|
||||||
|
boxes = torch.as_tensor(anno["boxes"], dtype=torch.float32)
|
||||||
|
labels = torch.as_tensor(anno["pids"], dtype=torch.int64)
|
||||||
|
target = {"img_name": anno["img_name"], "boxes": boxes, "labels": labels}
|
||||||
|
if self.transforms is not None:
|
||||||
|
img, target = self.transforms(img, target)
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.annotations)
|
|
@ -0,0 +1,104 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from utils.transforms import build_transforms
|
||||||
|
from utils.utils import create_small_table
|
||||||
|
from .cuhk_sysu import CUHKSYSU
|
||||||
|
from .prw import PRW
|
||||||
|
|
||||||
|
def print_statistics(dataset):
|
||||||
|
"""
|
||||||
|
Print dataset statistics.
|
||||||
|
"""
|
||||||
|
num_imgs = len(dataset.annotations)
|
||||||
|
num_boxes = 0
|
||||||
|
pid_set = set()
|
||||||
|
for anno in dataset.annotations:
|
||||||
|
num_boxes += anno["boxes"].shape[0]
|
||||||
|
for pid in anno["pids"]:
|
||||||
|
pid_set.add(pid)
|
||||||
|
statistics = {
|
||||||
|
"dataset": dataset.name,
|
||||||
|
"split": dataset.split,
|
||||||
|
"num_images": num_imgs,
|
||||||
|
"num_boxes": num_boxes,
|
||||||
|
}
|
||||||
|
if dataset.name != "CUHK-SYSU" or dataset.split != "query":
|
||||||
|
pid_list = sorted(list(pid_set))
|
||||||
|
if dataset.split == "query":
|
||||||
|
num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list)
|
||||||
|
statistics.update(
|
||||||
|
{
|
||||||
|
"num_labeled_pids": num_pids,
|
||||||
|
"min_labeled_pid": int(min_pid),
|
||||||
|
"max_labeled_pid": int(max_pid),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
unlabeled_pid = pid_list[-1]
|
||||||
|
pid_list = pid_list[:-1] # remove unlabeled pid
|
||||||
|
num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list)
|
||||||
|
statistics.update(
|
||||||
|
{
|
||||||
|
"num_labeled_pids": num_pids,
|
||||||
|
"min_labeled_pid": int(min_pid),
|
||||||
|
"max_labeled_pid": int(max_pid),
|
||||||
|
"unlabeled_pid": int(unlabeled_pid),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(f"=> {dataset.name}-{dataset.split} loaded:\n" + create_small_table(statistics))
|
||||||
|
|
||||||
|
|
||||||
|
def build_dataset(dataset_name, root, transforms, split, verbose=True):
|
||||||
|
if dataset_name == "CUHK-SYSU":
|
||||||
|
dataset = CUHKSYSU(root, transforms, split)
|
||||||
|
elif dataset_name == "PRW":
|
||||||
|
dataset = PRW(root, transforms, split)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unknow dataset: {dataset_name}")
|
||||||
|
if verbose:
|
||||||
|
print_statistics(dataset)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
return tuple(zip(*batch))
|
||||||
|
|
||||||
|
|
||||||
|
def build_train_loader(cfg):
|
||||||
|
transforms = build_transforms(cfg, is_train=True)
|
||||||
|
dataset = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "train")
|
||||||
|
return torch.utils.data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=cfg.INPUT.BATCH_SIZE_TRAIN,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=cfg.INPUT.NUM_WORKERS_TRAIN,
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_test_loader(cfg):
|
||||||
|
transforms = build_transforms(cfg, is_train=False)
|
||||||
|
gallery_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "gallery")
|
||||||
|
query_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "query")
|
||||||
|
gallery_loader = torch.utils.data.DataLoader(
|
||||||
|
gallery_set,
|
||||||
|
batch_size=cfg.INPUT.BATCH_SIZE_TEST,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=cfg.INPUT.NUM_WORKERS_TEST,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
query_loader = torch.utils.data.DataLoader(
|
||||||
|
query_set,
|
||||||
|
batch_size=cfg.INPUT.BATCH_SIZE_TEST,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=cfg.INPUT.NUM_WORKERS_TEST,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
return gallery_loader, query_loader
|
|
@ -0,0 +1,121 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import os.path as osp
|
||||||
|
import numpy as np
|
||||||
|
from scipy.io import loadmat
|
||||||
|
from .base import BaseDataset
|
||||||
|
|
||||||
|
class CUHKSYSU(BaseDataset):
|
||||||
|
def __init__(self, root, transforms, split):
|
||||||
|
self.name = "CUHK-SYSU"
|
||||||
|
self.img_prefix = osp.join(root, "Image", "SSM")
|
||||||
|
super(CUHKSYSU, self).__init__(root, transforms, split)
|
||||||
|
|
||||||
|
def _load_queries(self):
|
||||||
|
# TestG50: a test protocol, 50 gallery images per query
|
||||||
|
protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat"))
|
||||||
|
protoc = protoc["TestG50"].squeeze()
|
||||||
|
queries = []
|
||||||
|
for item in protoc["Query"]:
|
||||||
|
img_name = str(item["imname"][0, 0][0])
|
||||||
|
roi = item["idlocate"][0, 0][0].astype(np.int32)
|
||||||
|
roi[2:] += roi[:2]
|
||||||
|
queries.append(
|
||||||
|
{
|
||||||
|
"img_name": img_name,
|
||||||
|
"img_path": osp.join(self.img_prefix, img_name),
|
||||||
|
"boxes": roi[np.newaxis, :],
|
||||||
|
"pids": np.array([-100]), # dummy pid
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return queries
|
||||||
|
|
||||||
|
def _load_split_img_names(self):
|
||||||
|
"""
|
||||||
|
Load the image names for the specific split.
|
||||||
|
"""
|
||||||
|
assert self.split in ("train", "gallery")
|
||||||
|
# gallery images
|
||||||
|
gallery_imgs = loadmat(osp.join(self.root, "annotation", "pool.mat"))
|
||||||
|
gallery_imgs = gallery_imgs["pool"].squeeze()
|
||||||
|
gallery_imgs = [str(a[0]) for a in gallery_imgs]
|
||||||
|
if self.split == "gallery":
|
||||||
|
return gallery_imgs
|
||||||
|
# all images
|
||||||
|
all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat"))
|
||||||
|
all_imgs = all_imgs["Img"].squeeze()
|
||||||
|
all_imgs = [str(a[0][0]) for a in all_imgs]
|
||||||
|
# training images = all images - gallery images
|
||||||
|
training_imgs = sorted(list(set(all_imgs) - set(gallery_imgs)))
|
||||||
|
return training_imgs
|
||||||
|
|
||||||
|
def _load_annotations(self):
|
||||||
|
if self.split == "query":
|
||||||
|
return self._load_queries()
|
||||||
|
|
||||||
|
# load all images and build a dict from image to boxes
|
||||||
|
all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat"))
|
||||||
|
all_imgs = all_imgs["Img"].squeeze()
|
||||||
|
name_to_boxes = {}
|
||||||
|
name_to_pids = {}
|
||||||
|
unlabeled_pid = 5555 # default pid for unlabeled people
|
||||||
|
for img_name, _, boxes in all_imgs:
|
||||||
|
img_name = str(img_name[0])
|
||||||
|
boxes = np.asarray([b[0] for b in boxes[0]])
|
||||||
|
boxes = boxes.reshape(boxes.shape[0], 4) # (x1, y1, w, h)
|
||||||
|
valid_index = np.where((boxes[:, 2] > 0) & (boxes[:, 3] > 0))[0]
|
||||||
|
assert valid_index.size > 0, "Warning: {} has no valid boxes.".format(img_name)
|
||||||
|
boxes = boxes[valid_index]
|
||||||
|
name_to_boxes[img_name] = boxes.astype(np.int32)
|
||||||
|
name_to_pids[img_name] = unlabeled_pid * np.ones(boxes.shape[0], dtype=np.int32)
|
||||||
|
|
||||||
|
def set_box_pid(boxes, box, pids, pid):
|
||||||
|
for i in range(boxes.shape[0]):
|
||||||
|
if np.all(boxes[i] == box):
|
||||||
|
pids[i] = pid
|
||||||
|
return
|
||||||
|
|
||||||
|
# assign a unique pid from 1 to N for each identity
|
||||||
|
if self.split == "train":
|
||||||
|
train = loadmat(osp.join(self.root, "annotation/test/train_test/Train.mat"))
|
||||||
|
train = train["Train"].squeeze()
|
||||||
|
for index, item in enumerate(train):
|
||||||
|
scenes = item[0, 0][2].squeeze()
|
||||||
|
for img_name, box, _ in scenes:
|
||||||
|
img_name = str(img_name[0])
|
||||||
|
box = box.squeeze().astype(np.int32)
|
||||||
|
set_box_pid(name_to_boxes[img_name], box, name_to_pids[img_name], index + 1)
|
||||||
|
else:
|
||||||
|
protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat"))
|
||||||
|
protoc = protoc["TestG50"].squeeze()
|
||||||
|
for index, item in enumerate(protoc):
|
||||||
|
# query
|
||||||
|
im_name = str(item["Query"][0, 0][0][0])
|
||||||
|
box = item["Query"][0, 0][1].squeeze().astype(np.int32)
|
||||||
|
set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1)
|
||||||
|
# gallery
|
||||||
|
gallery = item["Gallery"].squeeze()
|
||||||
|
for im_name, box, _ in gallery:
|
||||||
|
im_name = str(im_name[0])
|
||||||
|
if box.size == 0:
|
||||||
|
break
|
||||||
|
box = box.squeeze().astype(np.int32)
|
||||||
|
set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1)
|
||||||
|
|
||||||
|
annotations = []
|
||||||
|
imgs = self._load_split_img_names()
|
||||||
|
for img_name in imgs:
|
||||||
|
boxes = name_to_boxes[img_name]
|
||||||
|
boxes[:, 2:] += boxes[:, :2] # (x1, y1, w, h) -> (x1, y1, x2, y2)
|
||||||
|
pids = name_to_pids[img_name]
|
||||||
|
annotations.append(
|
||||||
|
{
|
||||||
|
"img_name": img_name,
|
||||||
|
"img_path": osp.join(self.img_prefix, img_name),
|
||||||
|
"boxes": boxes,
|
||||||
|
"pids": pids,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return annotations
|
|
@ -0,0 +1,97 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import os.path as osp
|
||||||
|
import re
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.io import loadmat
|
||||||
|
|
||||||
|
from .base import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
|
class PRW(BaseDataset):
|
||||||
|
def __init__(self, root, transforms, split):
|
||||||
|
self.name = "PRW"
|
||||||
|
self.img_prefix = osp.join(root, "frames")
|
||||||
|
super(PRW, self).__init__(root, transforms, split)
|
||||||
|
|
||||||
|
def _get_cam_id(self, img_name):
|
||||||
|
match = re.search(r"c\d", img_name).group().replace("c", "")
|
||||||
|
return int(match)
|
||||||
|
|
||||||
|
def _load_queries(self):
|
||||||
|
query_info = osp.join(self.root, "query_info.txt")
|
||||||
|
with open(query_info, "rb") as f:
|
||||||
|
raw = f.readlines()
|
||||||
|
|
||||||
|
queries = []
|
||||||
|
for line in raw:
|
||||||
|
linelist = str(line, "utf-8").split(" ")
|
||||||
|
pid = int(linelist[0])
|
||||||
|
x, y, w, h = (
|
||||||
|
float(linelist[1]),
|
||||||
|
float(linelist[2]),
|
||||||
|
float(linelist[3]),
|
||||||
|
float(linelist[4]),
|
||||||
|
)
|
||||||
|
roi = np.array([x, y, x + w, y + h]).astype(np.int32)
|
||||||
|
roi = np.clip(roi, 0, None) # several coordinates are negative
|
||||||
|
img_name = linelist[5][:-2] + ".jpg"
|
||||||
|
queries.append(
|
||||||
|
{
|
||||||
|
"img_name": img_name,
|
||||||
|
"img_path": osp.join(self.img_prefix, img_name),
|
||||||
|
"boxes": roi[np.newaxis, :],
|
||||||
|
"pids": np.array([pid]),
|
||||||
|
"cam_id": self._get_cam_id(img_name),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return queries
|
||||||
|
|
||||||
|
def _load_split_img_names(self):
|
||||||
|
"""
|
||||||
|
Load the image names for the specific split.
|
||||||
|
"""
|
||||||
|
assert self.split in ("train", "gallery")
|
||||||
|
if self.split == "train":
|
||||||
|
imgs = loadmat(osp.join(self.root, "frame_train.mat"))["img_index_train"]
|
||||||
|
else:
|
||||||
|
imgs = loadmat(osp.join(self.root, "frame_test.mat"))["img_index_test"]
|
||||||
|
return [img[0][0] + ".jpg" for img in imgs]
|
||||||
|
|
||||||
|
def _load_annotations(self):
|
||||||
|
if self.split == "query":
|
||||||
|
return self._load_queries()
|
||||||
|
|
||||||
|
annotations = []
|
||||||
|
imgs = self._load_split_img_names()
|
||||||
|
for img_name in imgs:
|
||||||
|
anno_path = osp.join(self.root, "annotations", img_name)
|
||||||
|
anno = loadmat(anno_path)
|
||||||
|
box_key = "box_new"
|
||||||
|
if box_key not in anno.keys():
|
||||||
|
box_key = "anno_file"
|
||||||
|
if box_key not in anno.keys():
|
||||||
|
box_key = "anno_previous"
|
||||||
|
|
||||||
|
rois = anno[box_key][:, 1:]
|
||||||
|
ids = anno[box_key][:, 0]
|
||||||
|
rois = np.clip(rois, 0, None) # several coordinates are negative
|
||||||
|
|
||||||
|
assert len(rois) == len(ids)
|
||||||
|
|
||||||
|
rois[:, 2:] += rois[:, :2]
|
||||||
|
ids[ids == -2] = 5555 # assign pid = 5555 for unlabeled people
|
||||||
|
annotations.append(
|
||||||
|
{
|
||||||
|
"img_name": img_name,
|
||||||
|
"img_path": osp.join(self.img_prefix, img_name),
|
||||||
|
"boxes": rois.astype(np.int32),
|
||||||
|
# (training pids) 1, 2,..., 478, 480, 481, 482, 483, 932, 5555
|
||||||
|
"pids": ids.astype(np.int32),
|
||||||
|
"cam_id": self._get_cam_id(img_name),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return annotations
|
|
@ -0,0 +1,219 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
from yacs.config import CfgNode as CN
|
||||||
|
|
||||||
|
_C = CN()
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Input #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
_C.INPUT = CN()
|
||||||
|
_C.INPUT.DATASET = "CUHK-SYSU"
|
||||||
|
_C.INPUT.DATA_ROOT = "data/CUHK-SYSU"
|
||||||
|
|
||||||
|
# Size of the smallest side of the image
|
||||||
|
_C.INPUT.MIN_SIZE = 900
|
||||||
|
# Maximum size of the side of the image
|
||||||
|
_C.INPUT.MAX_SIZE = 1500
|
||||||
|
|
||||||
|
# Number of images per batch
|
||||||
|
_C.INPUT.BATCH_SIZE_TRAIN = 5
|
||||||
|
_C.INPUT.BATCH_SIZE_TEST = 1
|
||||||
|
|
||||||
|
# Number of data loading threads
|
||||||
|
_C.INPUT.NUM_WORKERS_TRAIN = 5
|
||||||
|
_C.INPUT.NUM_WORKERS_TEST = 1
|
||||||
|
|
||||||
|
# Image augmentation
|
||||||
|
_C.INPUT.IMAGE_CUTOUT = False
|
||||||
|
_C.INPUT.IMAGE_ERASE = False
|
||||||
|
_C.INPUT.IMAGE_MIXUP = False
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# GRID #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
_C.INPUT.IMAGE_GRID = False
|
||||||
|
_C.GRID = CN()
|
||||||
|
_C.GRID.ROTATE = 1
|
||||||
|
_C.GRID.OFFSET = 0
|
||||||
|
_C.GRID.RATIO = 0.5
|
||||||
|
_C.GRID.MODE = 1
|
||||||
|
_C.GRID.PROB = 0.5
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Solver #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
_C.SOLVER = CN()
|
||||||
|
_C.SOLVER.MAX_EPOCHS = 13
|
||||||
|
|
||||||
|
# Learning rate settings
|
||||||
|
_C.SOLVER.BASE_LR = 0.003
|
||||||
|
|
||||||
|
# The epoch milestones to decrease the learning rate by GAMMA
|
||||||
|
_C.SOLVER.LR_DECAY_MILESTONES = [10, 14]
|
||||||
|
_C.SOLVER.GAMMA = 0.1
|
||||||
|
|
||||||
|
_C.SOLVER.WEIGHT_DECAY = 0.0005
|
||||||
|
_C.SOLVER.SGD_MOMENTUM = 0.9
|
||||||
|
|
||||||
|
# Loss weight of RPN regression
|
||||||
|
_C.SOLVER.LW_RPN_REG = 1
|
||||||
|
# Loss weight of RPN classification
|
||||||
|
_C.SOLVER.LW_RPN_CLS = 1
|
||||||
|
|
||||||
|
# Loss weight of Cascade R-CNN and Re-ID (OIM)
|
||||||
|
_C.SOLVER.LW_RCNN_REG_1ST = 10
|
||||||
|
_C.SOLVER.LW_RCNN_CLS_1ST = 1
|
||||||
|
_C.SOLVER.LW_RCNN_REG_2ND = 10
|
||||||
|
_C.SOLVER.LW_RCNN_CLS_2ND = 1
|
||||||
|
_C.SOLVER.LW_RCNN_REG_3RD = 10
|
||||||
|
_C.SOLVER.LW_RCNN_CLS_3RD = 1
|
||||||
|
_C.SOLVER.LW_RCNN_REID_2ND = 0.5
|
||||||
|
_C.SOLVER.LW_RCNN_REID_3RD = 0.5
|
||||||
|
# Loss weight of box reid, softmax loss
|
||||||
|
_C.SOLVER.LW_RCNN_SOFTMAX_2ND = 0.5
|
||||||
|
_C.SOLVER.LW_RCNN_SOFTMAX_3RD = 0.5
|
||||||
|
|
||||||
|
# Set to negative value to disable gradient clipping
|
||||||
|
_C.SOLVER.CLIP_GRADIENTS = 10.0
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# RPN #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
_C.MODEL = CN()
|
||||||
|
_C.MODEL.RPN = CN()
|
||||||
|
# NMS threshold used on RoIs
|
||||||
|
_C.MODEL.RPN.NMS_THRESH = 0.7
|
||||||
|
# Number of anchors per image used to train RPN
|
||||||
|
_C.MODEL.RPN.BATCH_SIZE_TRAIN = 256
|
||||||
|
# Target fraction of foreground examples per RPN minibatch
|
||||||
|
_C.MODEL.RPN.POS_FRAC_TRAIN = 0.5
|
||||||
|
# Overlap threshold for an anchor to be considered foreground (if >= POS_THRESH_TRAIN)
|
||||||
|
_C.MODEL.RPN.POS_THRESH_TRAIN = 0.7
|
||||||
|
# Overlap threshold for an anchor to be considered background (if < NEG_THRESH_TRAIN)
|
||||||
|
_C.MODEL.RPN.NEG_THRESH_TRAIN = 0.3
|
||||||
|
# Number of top scoring RPN RoIs to keep before applying NMS
|
||||||
|
_C.MODEL.RPN.PRE_NMS_TOPN_TRAIN = 12000
|
||||||
|
_C.MODEL.RPN.PRE_NMS_TOPN_TEST = 6000
|
||||||
|
# Number of top scoring RPN RoIs to keep after applying NMS
|
||||||
|
_C.MODEL.RPN.POST_NMS_TOPN_TRAIN = 2000
|
||||||
|
_C.MODEL.RPN.POST_NMS_TOPN_TEST = 300
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# RoI head #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
_C.MODEL.ROI_HEAD = CN()
|
||||||
|
# Whether to use bn neck (i.e. batch normalization after linear)
|
||||||
|
_C.MODEL.ROI_HEAD.BN_NECK = True
|
||||||
|
# Number of RoIs per image used to train RoI head
|
||||||
|
_C.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN = 128
|
||||||
|
# Target fraction of foreground examples per RoI minibatch
|
||||||
|
_C.MODEL.ROI_HEAD.POS_FRAC_TRAIN = 0.25 # 0.5
|
||||||
|
|
||||||
|
_C.MODEL.ROI_HEAD.USE_DIFF_THRESH = True
|
||||||
|
# Overlap threshold for an RoI to be considered foreground (if >= POS_THRESH_TRAIN)
|
||||||
|
_C.MODEL.ROI_HEAD.POS_THRESH_TRAIN = 0.5
|
||||||
|
_C.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND = 0.6
|
||||||
|
_C.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD = 0.7
|
||||||
|
# Overlap threshold for an RoI to be considered background (if < NEG_THRESH_TRAIN)
|
||||||
|
_C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN = 0.5
|
||||||
|
_C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND = 0.6
|
||||||
|
_C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD = 0.7
|
||||||
|
# Minimum score threshold
|
||||||
|
_C.MODEL.ROI_HEAD.SCORE_THRESH_TEST = 0.5
|
||||||
|
# NMS threshold used on boxes
|
||||||
|
_C.MODEL.ROI_HEAD.NMS_THRESH_TEST = 0.4
|
||||||
|
_C.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST = 0.4
|
||||||
|
_C.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND = 0.4
|
||||||
|
_C.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD = 0.5
|
||||||
|
# Maximum number of detected objects
|
||||||
|
_C.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST = 300
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Transformer head #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
_C.MODEL.TRANSFORMER = CN()
|
||||||
|
_C.MODEL.TRANSFORMER.DIM_MODEL = 512
|
||||||
|
_C.MODEL.TRANSFORMER.ENCODER_LAYERS = 1
|
||||||
|
_C.MODEL.TRANSFORMER.N_HEAD = 8
|
||||||
|
_C.MODEL.TRANSFORMER.USE_OUTPUT_LAYER = False
|
||||||
|
_C.MODEL.TRANSFORMER.DROPOUT = 0.
|
||||||
|
_C.MODEL.TRANSFORMER.USE_LOCAL_SHORTCUT = True
|
||||||
|
_C.MODEL.TRANSFORMER.USE_GLOBAL_SHORTCUT = True
|
||||||
|
|
||||||
|
_C.MODEL.TRANSFORMER.USE_DIFF_SCALE = True
|
||||||
|
_C.MODEL.TRANSFORMER.NAMES_1ST = ['scale1','scale2']
|
||||||
|
_C.MODEL.TRANSFORMER.NAMES_2ND = ['scale1','scale2']
|
||||||
|
_C.MODEL.TRANSFORMER.NAMES_3RD = ['scale1','scale2']
|
||||||
|
_C.MODEL.TRANSFORMER.KERNEL_SIZE_1ST = [(1,1),(3,3)]
|
||||||
|
_C.MODEL.TRANSFORMER.KERNEL_SIZE_2ND = [(1,1),(3,3)]
|
||||||
|
_C.MODEL.TRANSFORMER.KERNEL_SIZE_3RD = [(1,1),(3,3)]
|
||||||
|
_C.MODEL.TRANSFORMER.USE_MASK_1ST = False
|
||||||
|
_C.MODEL.TRANSFORMER.USE_MASK_2ND = True
|
||||||
|
_C.MODEL.TRANSFORMER.USE_MASK_3RD = True
|
||||||
|
_C.MODEL.TRANSFORMER.USE_PATCH2VEC = True
|
||||||
|
|
||||||
|
####
|
||||||
|
_C.MODEL.USE_FEATURE_MASK = True
|
||||||
|
_C.MODEL.FEATURE_AUG_TYPE = 'exchange_token' # 'exchange_token', 'jigsaw_token', 'cutout_patch', 'erase_patch', 'mixup_patch', 'jigsaw_patch'
|
||||||
|
_C.MODEL.FEATURE_MASK_SIZE = 4
|
||||||
|
_C.MODEL.MASK_SHAPE = 'stripe' # 'square', 'random'
|
||||||
|
_C.MODEL.MASK_SIZE = 1
|
||||||
|
_C.MODEL.MASK_MODE = 'random_direction' # 'horizontal', 'vertical' for stripe; 'random_size' for square
|
||||||
|
_C.MODEL.MASK_PERCENT = 0.1
|
||||||
|
####
|
||||||
|
_C.MODEL.EMBEDDING_DIM = 256
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Loss #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
_C.MODEL.LOSS = CN()
|
||||||
|
# Size of the lookup table in OIM
|
||||||
|
_C.MODEL.LOSS.LUT_SIZE = 5532
|
||||||
|
# Size of the circular queue in OIM
|
||||||
|
_C.MODEL.LOSS.CQ_SIZE = 5000
|
||||||
|
_C.MODEL.LOSS.OIM_MOMENTUM = 0.5
|
||||||
|
_C.MODEL.LOSS.OIM_SCALAR = 30.0
|
||||||
|
|
||||||
|
_C.MODEL.LOSS.USE_SOFTMAX = True
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Evaluation #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# The period to evaluate the model during training
|
||||||
|
_C.EVAL_PERIOD = 1
|
||||||
|
# Evaluation with GT boxes to verify the upper bound of person search performance
|
||||||
|
_C.EVAL_USE_GT = False
|
||||||
|
# Fast evaluation with cached features
|
||||||
|
_C.EVAL_USE_CACHE = False
|
||||||
|
# Evaluation with Context Bipartite Graph Matching (CBGM) algorithm
|
||||||
|
_C.EVAL_USE_CBGM = False
|
||||||
|
# Gallery size in evaluation, only for CUHK-SYSU
|
||||||
|
_C.EVAL_GALLERY_SIZE = 100
|
||||||
|
# Feature used for evaluation
|
||||||
|
_C.EVAL_FEATURE = 'concat' # 'stage2', 'stage3'
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Miscs #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Save a checkpoint after every this number of epochs
|
||||||
|
_C.CKPT_PERIOD = 1
|
||||||
|
# The period (in terms of iterations) to display training losses
|
||||||
|
_C.DISP_PERIOD = 10
|
||||||
|
# Whether to use tensorboard for visualization
|
||||||
|
_C.TF_BOARD = True
|
||||||
|
# The device loading the model
|
||||||
|
_C.DEVICE = "cuda:0"
|
||||||
|
# Set seed to negative to fully randomize everything
|
||||||
|
_C.SEED = 1
|
||||||
|
# Directory where output files are written
|
||||||
|
_C.OUTPUT_DIR = "./output"
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_cfg():
|
||||||
|
"""
|
||||||
|
Get a copy of the default config.
|
||||||
|
"""
|
||||||
|
return _C.clone()
|
Binary file not shown.
After Width: | Height: | Size: 1.4 MiB |
|
@ -0,0 +1,179 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from eval_func import eval_detection, eval_search_cuhk, eval_search_prw
|
||||||
|
from utils.utils import MetricLogger, SmoothedValue, mkdir, reduce_dict, warmup_lr_scheduler
|
||||||
|
from utils.transforms import mixup_data
|
||||||
|
|
||||||
|
|
||||||
|
def to_device(images, targets, device):
|
||||||
|
images = [image.to(device) for image in images]
|
||||||
|
for t in targets:
|
||||||
|
t["boxes"] = t["boxes"].to(device)
|
||||||
|
t["labels"] = t["labels"].to(device)
|
||||||
|
return images, targets
|
||||||
|
|
||||||
|
|
||||||
|
def train_one_epoch(cfg, model, optimizer, data_loader, device, epoch, tfboard, softmax_criterion_s2, softmax_criterion_s3):
|
||||||
|
model.train()
|
||||||
|
metric_logger = MetricLogger(delimiter=" ")
|
||||||
|
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
|
||||||
|
header = "Epoch: [{}]".format(epoch)
|
||||||
|
|
||||||
|
# warmup learning rate in the first epoch
|
||||||
|
if epoch == 0:
|
||||||
|
warmup_factor = 1.0 / 1000
|
||||||
|
warmup_iters = len(data_loader) - 1
|
||||||
|
warmup_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
|
||||||
|
|
||||||
|
for i, (images, targets) in enumerate(
|
||||||
|
metric_logger.log_every(data_loader, cfg.DISP_PERIOD, header)
|
||||||
|
):
|
||||||
|
images, targets = to_device(images, targets, device)
|
||||||
|
|
||||||
|
# if using image based data augmentation
|
||||||
|
if cfg.INPUT.IMAGE_MIXUP:
|
||||||
|
images = mixup_data(images, alpha=0.8)
|
||||||
|
|
||||||
|
loss_dict, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd = model(images, targets)
|
||||||
|
|
||||||
|
if cfg.MODEL.LOSS.USE_SOFTMAX:
|
||||||
|
softmax_loss_2nd = cfg.SOLVER.LW_RCNN_SOFTMAX_2ND * softmax_criterion_s2(feats_reid_2nd, targets_reid_2nd)
|
||||||
|
softmax_loss_3rd = cfg.SOLVER.LW_RCNN_SOFTMAX_3RD * softmax_criterion_s3(feats_reid_3rd, targets_reid_3rd)
|
||||||
|
loss_dict.update(loss_box_softmax_2nd=softmax_loss_2nd)
|
||||||
|
loss_dict.update(loss_box_softmax_3rd=softmax_loss_3rd)
|
||||||
|
|
||||||
|
losses = sum(loss for loss in loss_dict.values())
|
||||||
|
|
||||||
|
# reduce losses over all GPUs for logging purposes
|
||||||
|
loss_dict_reduced = reduce_dict(loss_dict)
|
||||||
|
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
|
||||||
|
loss_value = losses_reduced.item()
|
||||||
|
|
||||||
|
if not math.isfinite(loss_value):
|
||||||
|
print(f"Loss is {loss_value}, stopping training")
|
||||||
|
print(loss_dict_reduced)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
losses.backward()
|
||||||
|
if cfg.SOLVER.CLIP_GRADIENTS > 0:
|
||||||
|
clip_grad_norm_(model.parameters(), cfg.SOLVER.CLIP_GRADIENTS)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
if epoch == 0:
|
||||||
|
warmup_scheduler.step()
|
||||||
|
|
||||||
|
metric_logger.update(loss=loss_value, **loss_dict_reduced)
|
||||||
|
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
|
||||||
|
if tfboard:
|
||||||
|
iter = epoch * len(data_loader) + i
|
||||||
|
for k, v in loss_dict_reduced.items():
|
||||||
|
tfboard.add_scalars("train", {k: v}, iter)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def evaluate_performance(
|
||||||
|
model, gallery_loader, query_loader, device, use_gt=False, use_cache=False, use_cbgm=False, gallery_size=100):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
use_gt (bool, optional): Whether to use GT as detection results to verify the upper
|
||||||
|
bound of person search performance. Defaults to False.
|
||||||
|
use_cache (bool, optional): Whether to use the cached features. Defaults to False.
|
||||||
|
use_cbgm (bool, optional): Whether to use Context Bipartite Graph Matching algorithm.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
if use_cache:
|
||||||
|
eval_cache = torch.load("data/eval_cache/eval_cache.pth")
|
||||||
|
gallery_dets = eval_cache["gallery_dets"]
|
||||||
|
gallery_feats = eval_cache["gallery_feats"]
|
||||||
|
query_dets = eval_cache["query_dets"]
|
||||||
|
query_feats = eval_cache["query_feats"]
|
||||||
|
query_box_feats = eval_cache["query_box_feats"]
|
||||||
|
else:
|
||||||
|
gallery_dets, gallery_feats = [], []
|
||||||
|
for images, targets in tqdm(gallery_loader, ncols=0):
|
||||||
|
images, targets = to_device(images, targets, device)
|
||||||
|
if not use_gt:
|
||||||
|
outputs = model(images)
|
||||||
|
else:
|
||||||
|
boxes = targets[0]["boxes"]
|
||||||
|
n_boxes = boxes.size(0)
|
||||||
|
embeddings = model(images, targets)
|
||||||
|
outputs = [
|
||||||
|
{
|
||||||
|
"boxes": boxes,
|
||||||
|
"embeddings": torch.cat(embeddings),
|
||||||
|
"labels": torch.ones(n_boxes).to(device),
|
||||||
|
"scores": torch.ones(n_boxes).to(device),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1)
|
||||||
|
gallery_dets.append(box_w_scores.cpu().numpy())
|
||||||
|
gallery_feats.append(output["embeddings"].cpu().numpy())
|
||||||
|
|
||||||
|
# regarding query image as gallery to detect all people
|
||||||
|
# i.e. query person + surrounding people (context information)
|
||||||
|
query_dets, query_feats = [], []
|
||||||
|
for images, targets in tqdm(query_loader, ncols=0):
|
||||||
|
images, targets = to_device(images, targets, device)
|
||||||
|
# targets will be modified in the model, so deepcopy it
|
||||||
|
outputs = model(images, deepcopy(targets), query_img_as_gallery=True)
|
||||||
|
|
||||||
|
# consistency check
|
||||||
|
gt_box = targets[0]["boxes"].squeeze()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
gt_box - outputs[0]["boxes"][0]
|
||||||
|
).sum() <= 0.001, "GT box must be the first one in the detected boxes of query image"
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1)
|
||||||
|
query_dets.append(box_w_scores.cpu().numpy())
|
||||||
|
query_feats.append(output["embeddings"].cpu().numpy())
|
||||||
|
|
||||||
|
# extract the features of query boxes
|
||||||
|
query_box_feats = []
|
||||||
|
for images, targets in tqdm(query_loader, ncols=0):
|
||||||
|
images, targets = to_device(images, targets, device)
|
||||||
|
embeddings = model(images, targets)
|
||||||
|
assert len(embeddings) == 1, "batch size in test phase should be 1"
|
||||||
|
query_box_feats.append(embeddings[0].cpu().numpy())
|
||||||
|
|
||||||
|
mkdir("data/eval_cache")
|
||||||
|
save_dict = {
|
||||||
|
"gallery_dets": gallery_dets,
|
||||||
|
"gallery_feats": gallery_feats,
|
||||||
|
"query_dets": query_dets,
|
||||||
|
"query_feats": query_feats,
|
||||||
|
"query_box_feats": query_box_feats,
|
||||||
|
}
|
||||||
|
torch.save(save_dict, "data/eval_cache/eval_cache.pth")
|
||||||
|
|
||||||
|
eval_detection(gallery_loader.dataset, gallery_dets, det_thresh=0.01)
|
||||||
|
eval_search_func = (
|
||||||
|
eval_search_cuhk if gallery_loader.dataset.name == "CUHK-SYSU" else eval_search_prw
|
||||||
|
)
|
||||||
|
eval_search_func(
|
||||||
|
gallery_loader.dataset,
|
||||||
|
query_loader.dataset,
|
||||||
|
gallery_dets,
|
||||||
|
gallery_feats,
|
||||||
|
query_box_feats,
|
||||||
|
query_dets,
|
||||||
|
query_feats,
|
||||||
|
cbgm=use_cbgm,
|
||||||
|
gallery_size=gallery_size,
|
||||||
|
)
|
|
@ -0,0 +1,488 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import os.path as osp
|
||||||
|
import numpy as np
|
||||||
|
from scipy.io import loadmat
|
||||||
|
from sklearn.metrics import average_precision_score
|
||||||
|
|
||||||
|
from utils.km import run_kuhn_munkres
|
||||||
|
from utils.utils import write_json
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_iou(a, b):
|
||||||
|
x1 = max(a[0], b[0])
|
||||||
|
y1 = max(a[1], b[1])
|
||||||
|
x2 = min(a[2], b[2])
|
||||||
|
y2 = min(a[3], b[3])
|
||||||
|
inter = max(0, x2 - x1) * max(0, y2 - y1)
|
||||||
|
union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
|
||||||
|
return inter * 1.0 / union
|
||||||
|
|
||||||
|
|
||||||
|
def eval_detection(
|
||||||
|
gallery_dataset, gallery_dets, det_thresh=0.5, iou_thresh=0.5, labeled_only=False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
gallery_det (list of ndarray): n_det x [x1, y1, x2, y2, score] per image
|
||||||
|
det_thresh (float): filter out gallery detections whose scores below this
|
||||||
|
iou_thresh (float): treat as true positive if IoU is above this threshold
|
||||||
|
labeled_only (bool): filter out unlabeled background people
|
||||||
|
"""
|
||||||
|
assert len(gallery_dataset) == len(gallery_dets)
|
||||||
|
annos = gallery_dataset.annotations
|
||||||
|
|
||||||
|
y_true, y_score = [], []
|
||||||
|
count_gt, count_tp = 0, 0
|
||||||
|
for anno, det in zip(annos, gallery_dets):
|
||||||
|
gt_boxes = anno["boxes"]
|
||||||
|
if labeled_only:
|
||||||
|
# exclude the unlabeled people (pid == 5555)
|
||||||
|
inds = np.where(anno["pids"].ravel() != 5555)[0]
|
||||||
|
if len(inds) == 0:
|
||||||
|
continue
|
||||||
|
gt_boxes = gt_boxes[inds]
|
||||||
|
num_gt = gt_boxes.shape[0]
|
||||||
|
|
||||||
|
if det != []:
|
||||||
|
det = np.asarray(det)
|
||||||
|
inds = np.where(det[:, 4].ravel() >= det_thresh)[0]
|
||||||
|
det = det[inds]
|
||||||
|
num_det = det.shape[0]
|
||||||
|
else:
|
||||||
|
num_det = 0
|
||||||
|
if num_det == 0:
|
||||||
|
count_gt += num_gt
|
||||||
|
continue
|
||||||
|
|
||||||
|
ious = np.zeros((num_gt, num_det), dtype=np.float32)
|
||||||
|
for i in range(num_gt):
|
||||||
|
for j in range(num_det):
|
||||||
|
ious[i, j] = _compute_iou(gt_boxes[i], det[j, :4])
|
||||||
|
tfmat = ious >= iou_thresh
|
||||||
|
# for each det, keep only the largest iou of all the gt
|
||||||
|
for j in range(num_det):
|
||||||
|
largest_ind = np.argmax(ious[:, j])
|
||||||
|
for i in range(num_gt):
|
||||||
|
if i != largest_ind:
|
||||||
|
tfmat[i, j] = False
|
||||||
|
# for each gt, keep only the largest iou of all the det
|
||||||
|
for i in range(num_gt):
|
||||||
|
largest_ind = np.argmax(ious[i, :])
|
||||||
|
for j in range(num_det):
|
||||||
|
if j != largest_ind:
|
||||||
|
tfmat[i, j] = False
|
||||||
|
for j in range(num_det):
|
||||||
|
y_score.append(det[j, -1])
|
||||||
|
y_true.append(tfmat[:, j].any())
|
||||||
|
count_tp += tfmat.sum()
|
||||||
|
count_gt += num_gt
|
||||||
|
|
||||||
|
det_rate = count_tp * 1.0 / count_gt
|
||||||
|
ap = average_precision_score(y_true, y_score) * det_rate
|
||||||
|
|
||||||
|
print("{} detection:".format("labeled only" if labeled_only else "all"))
|
||||||
|
print(" recall = {:.2%}".format(det_rate))
|
||||||
|
if not labeled_only:
|
||||||
|
print(" ap = {:.2%}".format(ap))
|
||||||
|
return det_rate, ap
|
||||||
|
|
||||||
|
|
||||||
|
def eval_search_cuhk(
|
||||||
|
gallery_dataset,
|
||||||
|
query_dataset,
|
||||||
|
gallery_dets,
|
||||||
|
gallery_feats,
|
||||||
|
query_box_feats,
|
||||||
|
query_dets,
|
||||||
|
query_feats,
|
||||||
|
k1=10,
|
||||||
|
k2=3,
|
||||||
|
det_thresh=0.5,
|
||||||
|
cbgm=False,
|
||||||
|
gallery_size=100,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
gallery_dataset/query_dataset: an instance of BaseDataset
|
||||||
|
gallery_det (list of ndarray): n_det x [x1, x2, y1, y2, score] per image
|
||||||
|
gallery_feat (list of ndarray): n_det x D features per image
|
||||||
|
query_feat (list of ndarray): D dimensional features per query image
|
||||||
|
det_thresh (float): filter out gallery detections whose scores below this
|
||||||
|
gallery_size (int): gallery size [-1, 50, 100, 500, 1000, 2000, 4000]
|
||||||
|
-1 for using full set
|
||||||
|
"""
|
||||||
|
assert len(gallery_dataset) == len(gallery_dets)
|
||||||
|
assert len(gallery_dataset) == len(gallery_feats)
|
||||||
|
assert len(query_dataset) == len(query_box_feats)
|
||||||
|
|
||||||
|
use_full_set = gallery_size == -1
|
||||||
|
fname = "TestG{}".format(gallery_size if not use_full_set else 50)
|
||||||
|
protoc = loadmat(osp.join(gallery_dataset.root, "annotation/test/train_test", fname + ".mat"))
|
||||||
|
protoc = protoc[fname].squeeze()
|
||||||
|
|
||||||
|
# mapping from gallery image to (det, feat)
|
||||||
|
annos = gallery_dataset.annotations
|
||||||
|
name_to_det_feat = {}
|
||||||
|
for anno, det, feat in zip(annos, gallery_dets, gallery_feats):
|
||||||
|
name = anno["img_name"]
|
||||||
|
if det != []:
|
||||||
|
scores = det[:, 4].ravel()
|
||||||
|
inds = np.where(scores >= det_thresh)[0]
|
||||||
|
if len(inds) > 0:
|
||||||
|
name_to_det_feat[name] = (det[inds], feat[inds])
|
||||||
|
|
||||||
|
aps = []
|
||||||
|
accs = []
|
||||||
|
topk = [1, 5, 10]
|
||||||
|
ret = {"image_root": gallery_dataset.img_prefix, "results": []}
|
||||||
|
for i in range(len(query_dataset)):
|
||||||
|
y_true, y_score = [], []
|
||||||
|
imgs, rois = [], []
|
||||||
|
count_gt, count_tp = 0, 0
|
||||||
|
# get L2-normalized feature vector
|
||||||
|
feat_q = query_box_feats[i].ravel()
|
||||||
|
# ignore the query image
|
||||||
|
query_imname = str(protoc["Query"][i]["imname"][0, 0][0])
|
||||||
|
query_roi = protoc["Query"][i]["idlocate"][0, 0][0].astype(np.int32)
|
||||||
|
query_roi[2:] += query_roi[:2]
|
||||||
|
query_gt = []
|
||||||
|
tested = set([query_imname])
|
||||||
|
|
||||||
|
name2sim = {}
|
||||||
|
name2gt = {}
|
||||||
|
sims = []
|
||||||
|
imgs_cbgm = []
|
||||||
|
# 1. Go through the gallery samples defined by the protocol
|
||||||
|
for item in protoc["Gallery"][i].squeeze():
|
||||||
|
gallery_imname = str(item[0][0])
|
||||||
|
# some contain the query (gt not empty), some not
|
||||||
|
gt = item[1][0].astype(np.int32)
|
||||||
|
count_gt += gt.size > 0
|
||||||
|
# compute distance between query and gallery dets
|
||||||
|
if gallery_imname not in name_to_det_feat:
|
||||||
|
continue
|
||||||
|
det, feat_g = name_to_det_feat[gallery_imname]
|
||||||
|
# no detection in this gallery, skip it
|
||||||
|
if det.shape[0] == 0:
|
||||||
|
continue
|
||||||
|
# get L2-normalized feature matrix NxD
|
||||||
|
assert feat_g.size == np.prod(feat_g.shape[:2])
|
||||||
|
feat_g = feat_g.reshape(feat_g.shape[:2])
|
||||||
|
# compute cosine similarities
|
||||||
|
sim = feat_g.dot(feat_q).ravel()
|
||||||
|
|
||||||
|
if gallery_imname in name2sim:
|
||||||
|
continue
|
||||||
|
name2sim[gallery_imname] = sim
|
||||||
|
name2gt[gallery_imname] = gt
|
||||||
|
sims.extend(list(sim))
|
||||||
|
imgs_cbgm.extend([gallery_imname] * len(sim))
|
||||||
|
# 2. Go through the remaining gallery images if using full set
|
||||||
|
if use_full_set:
|
||||||
|
for gallery_imname in gallery_dataset.imgs:
|
||||||
|
if gallery_imname in tested:
|
||||||
|
continue
|
||||||
|
if gallery_imname not in name_to_det_feat:
|
||||||
|
continue
|
||||||
|
det, feat_g = name_to_det_feat[gallery_imname]
|
||||||
|
# get L2-normalized feature matrix NxD
|
||||||
|
assert feat_g.size == np.prod(feat_g.shape[:2])
|
||||||
|
feat_g = feat_g.reshape(feat_g.shape[:2])
|
||||||
|
# compute cosine similarities
|
||||||
|
sim = feat_g.dot(feat_q).ravel()
|
||||||
|
# guaranteed no target query in these gallery images
|
||||||
|
label = np.zeros(len(sim), dtype=np.int32)
|
||||||
|
y_true.extend(list(label))
|
||||||
|
y_score.extend(list(sim))
|
||||||
|
imgs.extend([gallery_imname] * len(sim))
|
||||||
|
rois.extend(list(det))
|
||||||
|
|
||||||
|
if cbgm:
|
||||||
|
# -------- Context Bipartite Graph Matching (CBGM) ------- #
|
||||||
|
sims = np.array(sims)
|
||||||
|
imgs_cbgm = np.array(imgs_cbgm)
|
||||||
|
# only process the top-k1 gallery images for efficiency
|
||||||
|
inds = np.argsort(sims)[-k1:]
|
||||||
|
imgs_cbgm = set(imgs_cbgm[inds])
|
||||||
|
for img in imgs_cbgm:
|
||||||
|
sim = name2sim[img]
|
||||||
|
det, feat_g = name_to_det_feat[img]
|
||||||
|
# only regard the people with top-k2 detection confidence
|
||||||
|
# in the query image as context information
|
||||||
|
qboxes = query_dets[i][:k2]
|
||||||
|
qfeats = query_feats[i][:k2]
|
||||||
|
assert (
|
||||||
|
query_roi - qboxes[0][:4]
|
||||||
|
).sum() <= 0.001, "query_roi must be the first one in pboxes"
|
||||||
|
|
||||||
|
# build the bipartite graph and run Kuhn-Munkres (K-M) algorithm
|
||||||
|
# to find the best match
|
||||||
|
graph = []
|
||||||
|
for indx_i, pfeat in enumerate(qfeats):
|
||||||
|
for indx_j, gfeat in enumerate(feat_g):
|
||||||
|
graph.append((indx_i, indx_j, (pfeat * gfeat).sum()))
|
||||||
|
km_res, max_val = run_kuhn_munkres(graph)
|
||||||
|
|
||||||
|
# revise the similarity between query person and its matching
|
||||||
|
for indx_i, indx_j, _ in km_res:
|
||||||
|
# 0 denotes the query roi
|
||||||
|
if indx_i == 0:
|
||||||
|
sim[indx_j] = max_val
|
||||||
|
break
|
||||||
|
for gallery_imname, sim in name2sim.items():
|
||||||
|
gt = name2gt[gallery_imname]
|
||||||
|
det, feat_g = name_to_det_feat[gallery_imname]
|
||||||
|
# assign label for each det
|
||||||
|
label = np.zeros(len(sim), dtype=np.int32)
|
||||||
|
if gt.size > 0:
|
||||||
|
w, h = gt[2], gt[3]
|
||||||
|
gt[2:] += gt[:2]
|
||||||
|
query_gt.append({"img": str(gallery_imname), "roi": list(map(float, list(gt)))})
|
||||||
|
iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10)))
|
||||||
|
inds = np.argsort(sim)[::-1]
|
||||||
|
sim = sim[inds]
|
||||||
|
det = det[inds]
|
||||||
|
# only set the first matched det as true positive
|
||||||
|
for j, roi in enumerate(det[:, :4]):
|
||||||
|
if _compute_iou(roi, gt) >= iou_thresh:
|
||||||
|
label[j] = 1
|
||||||
|
count_tp += 1
|
||||||
|
break
|
||||||
|
y_true.extend(list(label))
|
||||||
|
y_score.extend(list(sim))
|
||||||
|
imgs.extend([gallery_imname] * len(sim))
|
||||||
|
rois.extend(list(det))
|
||||||
|
tested.add(gallery_imname)
|
||||||
|
# 3. Compute AP for this query (need to scale by recall rate)
|
||||||
|
y_score = np.asarray(y_score)
|
||||||
|
y_true = np.asarray(y_true)
|
||||||
|
assert count_tp <= count_gt
|
||||||
|
recall_rate = count_tp * 1.0 / count_gt
|
||||||
|
ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate
|
||||||
|
aps.append(ap)
|
||||||
|
inds = np.argsort(y_score)[::-1]
|
||||||
|
y_score = y_score[inds]
|
||||||
|
y_true = y_true[inds]
|
||||||
|
accs.append([min(1, sum(y_true[:k])) for k in topk])
|
||||||
|
# 4. Save result for JSON dump
|
||||||
|
new_entry = {
|
||||||
|
"query_img": str(query_imname),
|
||||||
|
"query_roi": list(map(float, list(query_roi))),
|
||||||
|
"query_gt": query_gt,
|
||||||
|
"gallery": [],
|
||||||
|
}
|
||||||
|
# only record wrong results
|
||||||
|
if int(y_true[0]):
|
||||||
|
continue
|
||||||
|
# only save top-10 predictions
|
||||||
|
for k in range(10):
|
||||||
|
new_entry["gallery"].append(
|
||||||
|
{
|
||||||
|
"img": str(imgs[inds[k]]),
|
||||||
|
"roi": list(map(float, list(rois[inds[k]]))),
|
||||||
|
"score": float(y_score[k]),
|
||||||
|
"correct": int(y_true[k]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ret["results"].append(new_entry)
|
||||||
|
|
||||||
|
print("search ranking:")
|
||||||
|
print(" mAP = {:.2%}".format(np.mean(aps)))
|
||||||
|
accs = np.mean(accs, axis=0)
|
||||||
|
for i, k in enumerate(topk):
|
||||||
|
print(" top-{:2d} = {:.2%}".format(k, accs[i]))
|
||||||
|
|
||||||
|
write_json(ret, "vis/results.json")
|
||||||
|
|
||||||
|
ret["mAP"] = np.mean(aps)
|
||||||
|
ret["accs"] = accs
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def eval_search_prw(
|
||||||
|
gallery_dataset,
|
||||||
|
query_dataset,
|
||||||
|
gallery_dets,
|
||||||
|
gallery_feats,
|
||||||
|
query_box_feats,
|
||||||
|
query_dets,
|
||||||
|
query_feats,
|
||||||
|
k1=30,
|
||||||
|
k2=4,
|
||||||
|
det_thresh=0.5,
|
||||||
|
cbgm=False,
|
||||||
|
gallery_size=None, # not used in PRW
|
||||||
|
ignore_cam_id=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
gallery_det (list of ndarray): n_det x [x1, x2, y1, y2, score] per image
|
||||||
|
gallery_feat (list of ndarray): n_det x D features per image
|
||||||
|
query_feat (list of ndarray): D dimensional features per query image
|
||||||
|
det_thresh (float): filter out gallery detections whose scores below this
|
||||||
|
gallery_size (int): -1 for using full set
|
||||||
|
ignore_cam_id (bool): Set to True acoording to CUHK-SYSU,
|
||||||
|
although it's a common practice to focus on cross-cam match only.
|
||||||
|
"""
|
||||||
|
assert len(gallery_dataset) == len(gallery_dets)
|
||||||
|
assert len(gallery_dataset) == len(gallery_feats)
|
||||||
|
assert len(query_dataset) == len(query_box_feats)
|
||||||
|
|
||||||
|
annos = gallery_dataset.annotations
|
||||||
|
name_to_det_feat = {}
|
||||||
|
for anno, det, feat in zip(annos, gallery_dets, gallery_feats):
|
||||||
|
name = anno["img_name"]
|
||||||
|
scores = det[:, 4].ravel()
|
||||||
|
inds = np.where(scores >= det_thresh)[0]
|
||||||
|
if len(inds) > 0:
|
||||||
|
name_to_det_feat[name] = (det[inds], feat[inds])
|
||||||
|
|
||||||
|
aps = []
|
||||||
|
accs = []
|
||||||
|
topk = [1, 5, 10]
|
||||||
|
ret = {"image_root": gallery_dataset.img_prefix, "results": []}
|
||||||
|
for i in range(len(query_dataset)):
|
||||||
|
y_true, y_score = [], []
|
||||||
|
imgs, rois = [], []
|
||||||
|
count_gt, count_tp = 0, 0
|
||||||
|
|
||||||
|
feat_p = query_box_feats[i].ravel()
|
||||||
|
|
||||||
|
query_imname = query_dataset.annotations[i]["img_name"]
|
||||||
|
query_roi = query_dataset.annotations[i]["boxes"]
|
||||||
|
query_pid = query_dataset.annotations[i]["pids"]
|
||||||
|
query_cam = query_dataset.annotations[i]["cam_id"]
|
||||||
|
|
||||||
|
# Find all occurence of this query
|
||||||
|
gallery_imgs = []
|
||||||
|
for x in annos:
|
||||||
|
if query_pid in x["pids"] and x["img_name"] != query_imname:
|
||||||
|
gallery_imgs.append(x)
|
||||||
|
query_gts = {}
|
||||||
|
for item in gallery_imgs:
|
||||||
|
query_gts[item["img_name"]] = item["boxes"][item["pids"] == query_pid]
|
||||||
|
|
||||||
|
# Construct gallery set for this query
|
||||||
|
if ignore_cam_id:
|
||||||
|
gallery_imgs = []
|
||||||
|
for x in annos:
|
||||||
|
if x["img_name"] != query_imname:
|
||||||
|
gallery_imgs.append(x)
|
||||||
|
else:
|
||||||
|
gallery_imgs = []
|
||||||
|
for x in annos:
|
||||||
|
if x["img_name"] != query_imname and x["cam_id"] != query_cam:
|
||||||
|
gallery_imgs.append(x)
|
||||||
|
|
||||||
|
name2sim = {}
|
||||||
|
sims = []
|
||||||
|
imgs_cbgm = []
|
||||||
|
# 1. Go through all gallery samples
|
||||||
|
for item in gallery_imgs:
|
||||||
|
gallery_imname = item["img_name"]
|
||||||
|
# some contain the query (gt not empty), some not
|
||||||
|
count_gt += gallery_imname in query_gts
|
||||||
|
# compute distance between query and gallery dets
|
||||||
|
if gallery_imname not in name_to_det_feat:
|
||||||
|
continue
|
||||||
|
det, feat_g = name_to_det_feat[gallery_imname]
|
||||||
|
# get L2-normalized feature matrix NxD
|
||||||
|
assert feat_g.size == np.prod(feat_g.shape[:2])
|
||||||
|
feat_g = feat_g.reshape(feat_g.shape[:2])
|
||||||
|
# compute cosine similarities
|
||||||
|
sim = feat_g.dot(feat_p).ravel()
|
||||||
|
|
||||||
|
if gallery_imname in name2sim:
|
||||||
|
continue
|
||||||
|
name2sim[gallery_imname] = sim
|
||||||
|
sims.extend(list(sim))
|
||||||
|
imgs_cbgm.extend([gallery_imname] * len(sim))
|
||||||
|
|
||||||
|
if cbgm:
|
||||||
|
sims = np.array(sims)
|
||||||
|
imgs_cbgm = np.array(imgs_cbgm)
|
||||||
|
inds = np.argsort(sims)[-k1:]
|
||||||
|
imgs_cbgm = set(imgs_cbgm[inds])
|
||||||
|
for img in imgs_cbgm:
|
||||||
|
sim = name2sim[img]
|
||||||
|
det, feat_g = name_to_det_feat[img]
|
||||||
|
qboxes = query_dets[i][:k2]
|
||||||
|
qfeats = query_feats[i][:k2]
|
||||||
|
# assert (
|
||||||
|
# query_roi - qboxes[0][:4]
|
||||||
|
# ).sum() <= 0.001, "query_roi must be the first one in pboxes"
|
||||||
|
|
||||||
|
graph = []
|
||||||
|
for indx_i, pfeat in enumerate(qfeats):
|
||||||
|
for indx_j, gfeat in enumerate(feat_g):
|
||||||
|
graph.append((indx_i, indx_j, (pfeat * gfeat).sum()))
|
||||||
|
km_res, max_val = run_kuhn_munkres(graph)
|
||||||
|
|
||||||
|
for indx_i, indx_j, _ in km_res:
|
||||||
|
if indx_i == 0:
|
||||||
|
sim[indx_j] = max_val
|
||||||
|
break
|
||||||
|
for gallery_imname, sim in name2sim.items():
|
||||||
|
det, feat_g = name_to_det_feat[gallery_imname]
|
||||||
|
# assign label for each det
|
||||||
|
label = np.zeros(len(sim), dtype=np.int32)
|
||||||
|
if gallery_imname in query_gts:
|
||||||
|
gt = query_gts[gallery_imname].ravel()
|
||||||
|
w, h = gt[2] - gt[0], gt[3] - gt[1]
|
||||||
|
iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10)))
|
||||||
|
inds = np.argsort(sim)[::-1]
|
||||||
|
sim = sim[inds]
|
||||||
|
det = det[inds]
|
||||||
|
# only set the first matched det as true positive
|
||||||
|
for j, roi in enumerate(det[:, :4]):
|
||||||
|
if _compute_iou(roi, gt) >= iou_thresh:
|
||||||
|
label[j] = 1
|
||||||
|
count_tp += 1
|
||||||
|
break
|
||||||
|
y_true.extend(list(label))
|
||||||
|
y_score.extend(list(sim))
|
||||||
|
imgs.extend([gallery_imname] * len(sim))
|
||||||
|
rois.extend(list(det))
|
||||||
|
|
||||||
|
# 2. Compute AP for this query (need to scale by recall rate)
|
||||||
|
y_score = np.asarray(y_score)
|
||||||
|
y_true = np.asarray(y_true)
|
||||||
|
assert count_tp <= count_gt
|
||||||
|
recall_rate = count_tp * 1.0 / count_gt
|
||||||
|
ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate
|
||||||
|
aps.append(ap)
|
||||||
|
inds = np.argsort(y_score)[::-1]
|
||||||
|
y_score = y_score[inds]
|
||||||
|
y_true = y_true[inds]
|
||||||
|
accs.append([min(1, sum(y_true[:k])) for k in topk])
|
||||||
|
# 4. Save result for JSON dump
|
||||||
|
new_entry = {
|
||||||
|
"query_img": str(query_imname),
|
||||||
|
"query_roi": list(map(float, list(query_roi.squeeze()))),
|
||||||
|
"query_gt": query_gts,
|
||||||
|
"gallery": [],
|
||||||
|
}
|
||||||
|
# only save top-10 predictions
|
||||||
|
for k in range(10):
|
||||||
|
new_entry["gallery"].append(
|
||||||
|
{
|
||||||
|
"img": str(imgs[inds[k]]),
|
||||||
|
"roi": list(map(float, list(rois[inds[k]]))),
|
||||||
|
"score": float(y_score[k]),
|
||||||
|
"correct": int(y_true[k]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ret["results"].append(new_entry)
|
||||||
|
|
||||||
|
print("search ranking:")
|
||||||
|
mAP = np.mean(aps)
|
||||||
|
print(" mAP = {:.2%}".format(mAP))
|
||||||
|
accs = np.mean(accs, axis=0)
|
||||||
|
for i, k in enumerate(topk):
|
||||||
|
print(" top-{:2d} = {:.2%}".format(k, accs[i]))
|
||||||
|
|
||||||
|
# write_json(ret, "vis/results.json")
|
||||||
|
|
||||||
|
ret["mAP"] = np.mean(aps)
|
||||||
|
ret["accs"] = accs
|
||||||
|
return ret
|
|
@ -0,0 +1,76 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import autograd, nn
|
||||||
|
|
||||||
|
class OIM(autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, inputs, targets, lut, cq, header, momentum):
|
||||||
|
ctx.save_for_backward(inputs, targets, lut, cq, header, momentum)
|
||||||
|
outputs_labeled = inputs.mm(lut.t())
|
||||||
|
outputs_unlabeled = inputs.mm(cq.t())
|
||||||
|
return torch.cat([outputs_labeled, outputs_unlabeled], dim=1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_outputs):
|
||||||
|
inputs, targets, lut, cq, header, momentum = ctx.saved_tensors
|
||||||
|
|
||||||
|
grad_inputs = None
|
||||||
|
if ctx.needs_input_grad[0]:
|
||||||
|
grad_inputs = grad_outputs.mm(torch.cat([lut, cq], dim=0))
|
||||||
|
if grad_inputs.dtype == torch.float16:
|
||||||
|
grad_inputs = grad_inputs.to(torch.float32)
|
||||||
|
|
||||||
|
for x, y in zip(inputs, targets):
|
||||||
|
if y < len(lut):
|
||||||
|
lut[y] = momentum * lut[y] + (1.0 - momentum) * x
|
||||||
|
lut[y] /= lut[y].norm()
|
||||||
|
else:
|
||||||
|
cq[header] = x
|
||||||
|
header = (header + 1) % cq.size(0)
|
||||||
|
return grad_inputs, None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def oim(inputs, targets, lut, cq, header, momentum=0.5):
|
||||||
|
return OIM.apply(inputs, targets, lut, cq, torch.tensor(header), torch.tensor(momentum))
|
||||||
|
|
||||||
|
|
||||||
|
class OIMLoss(nn.Module):
|
||||||
|
def __init__(self, num_features, num_pids, num_cq_size, oim_momentum, oim_scalar):
|
||||||
|
super(OIMLoss, self).__init__()
|
||||||
|
self.num_features = num_features
|
||||||
|
self.num_pids = num_pids
|
||||||
|
self.num_unlabeled = num_cq_size
|
||||||
|
self.momentum = oim_momentum
|
||||||
|
self.oim_scalar = oim_scalar
|
||||||
|
|
||||||
|
self.register_buffer("lut", torch.zeros(self.num_pids, self.num_features))
|
||||||
|
self.register_buffer("cq", torch.zeros(self.num_unlabeled, self.num_features))
|
||||||
|
|
||||||
|
self.header_cq = 0
|
||||||
|
|
||||||
|
def forward(self, inputs, roi_label):
|
||||||
|
# merge into one batch, background label = 0
|
||||||
|
targets = torch.cat(roi_label)
|
||||||
|
label = targets - 1 # background label = -1
|
||||||
|
|
||||||
|
inds = label >= 0
|
||||||
|
|
||||||
|
label = label[inds]
|
||||||
|
inputs = inputs[inds.unsqueeze(1).expand_as(inputs)].view(-1, self.num_features)
|
||||||
|
|
||||||
|
projected = oim(inputs, label, self.lut, self.cq, self.header_cq, momentum=self.momentum)
|
||||||
|
# projected - Tensor [M, lut+cq], e.g., [M, 482+500]=[M, 982]
|
||||||
|
|
||||||
|
projected *= self.oim_scalar
|
||||||
|
|
||||||
|
self.header_cq = (
|
||||||
|
self.header_cq + (label >= self.num_pids).long().sum().item()
|
||||||
|
) % self.num_unlabeled
|
||||||
|
|
||||||
|
loss_oim = F.cross_entropy(projected, label, ignore_index=5554)
|
||||||
|
|
||||||
|
return loss_oim, inputs, label
|
|
@ -0,0 +1,62 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class SoftmaxLoss(nn.Module):
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super(SoftmaxLoss, self).__init__()
|
||||||
|
|
||||||
|
self.feat_dim = cfg.MODEL.EMBEDDING_DIM
|
||||||
|
self.num_classes = cfg.MODEL.LOSS.LUT_SIZE
|
||||||
|
|
||||||
|
self.bottleneck = nn.BatchNorm1d(self.feat_dim)
|
||||||
|
self.bottleneck.bias.requires_grad_(False) # no shift
|
||||||
|
self.classifier = nn.Linear(self.feat_dim, self.num_classes, bias=False)
|
||||||
|
|
||||||
|
self.bottleneck.apply(weights_init_kaiming)
|
||||||
|
self.classifier.apply(weights_init_classifier)
|
||||||
|
|
||||||
|
def forward(self, inputs, labels):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
inputs: feature matrix with shape (batch_size, feat_dim).
|
||||||
|
labels: ground truth labels with shape (num_classes).
|
||||||
|
"""
|
||||||
|
assert inputs.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
|
||||||
|
|
||||||
|
target = labels.clone()
|
||||||
|
target[target >= self.num_classes] = 5554
|
||||||
|
|
||||||
|
feat = self.bottleneck(inputs)
|
||||||
|
score = self.classifier(feat)
|
||||||
|
loss = F.cross_entropy(score, target, ignore_index=5554)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def weights_init_kaiming(m):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find('Linear') != -1:
|
||||||
|
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
|
||||||
|
nn.init.constant_(m.bias, 0.0)
|
||||||
|
elif classname.find('Conv') != -1:
|
||||||
|
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0.0)
|
||||||
|
elif classname.find('BatchNorm') != -1:
|
||||||
|
if m.affine:
|
||||||
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
nn.init.constant_(m.bias, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def weights_init_classifier(m):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find('Linear') != -1:
|
||||||
|
nn.init.normal_(m.weight, std=0.001)
|
||||||
|
if m.bias:
|
||||||
|
nn.init.constant_(m.bias, 0.0)
|
||||||
|
|
|
@ -0,0 +1,765 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import init
|
||||||
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
||||||
|
from torchvision.models.detection.roi_heads import RoIHeads
|
||||||
|
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
|
||||||
|
from torchvision.models.detection.transform import GeneralizedRCNNTransform
|
||||||
|
from torchvision.ops import MultiScaleRoIAlign
|
||||||
|
from torchvision.ops import boxes as box_ops
|
||||||
|
from torchvision.models.detection import _utils as det_utils
|
||||||
|
|
||||||
|
from loss.oim import OIMLoss
|
||||||
|
from models.resnet import build_resnet
|
||||||
|
from models.transformer import TransformerHead
|
||||||
|
|
||||||
|
|
||||||
|
class COAT(nn.Module):
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super(COAT, self).__init__()
|
||||||
|
|
||||||
|
backbone, _ = build_resnet(name="resnet50", pretrained=True)
|
||||||
|
anchor_generator = AnchorGenerator(
|
||||||
|
sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
|
||||||
|
)
|
||||||
|
head = RPNHead(
|
||||||
|
in_channels=backbone.out_channels,
|
||||||
|
num_anchors=anchor_generator.num_anchors_per_location()[0],
|
||||||
|
)
|
||||||
|
pre_nms_top_n = dict(
|
||||||
|
training=cfg.MODEL.RPN.PRE_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.PRE_NMS_TOPN_TEST
|
||||||
|
)
|
||||||
|
post_nms_top_n = dict(
|
||||||
|
training=cfg.MODEL.RPN.POST_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.POST_NMS_TOPN_TEST
|
||||||
|
)
|
||||||
|
rpn = RegionProposalNetwork(
|
||||||
|
anchor_generator=anchor_generator,
|
||||||
|
head=head,
|
||||||
|
fg_iou_thresh=cfg.MODEL.RPN.POS_THRESH_TRAIN,
|
||||||
|
bg_iou_thresh=cfg.MODEL.RPN.NEG_THRESH_TRAIN,
|
||||||
|
batch_size_per_image=cfg.MODEL.RPN.BATCH_SIZE_TRAIN,
|
||||||
|
positive_fraction=cfg.MODEL.RPN.POS_FRAC_TRAIN,
|
||||||
|
pre_nms_top_n=pre_nms_top_n,
|
||||||
|
post_nms_top_n=post_nms_top_n,
|
||||||
|
nms_thresh=cfg.MODEL.RPN.NMS_THRESH,
|
||||||
|
)
|
||||||
|
|
||||||
|
box_head = TransformerHead(
|
||||||
|
cfg=cfg,
|
||||||
|
trans_names=cfg.MODEL.TRANSFORMER.NAMES_1ST,
|
||||||
|
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_1ST,
|
||||||
|
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_1ST,
|
||||||
|
)
|
||||||
|
box_head_2nd = TransformerHead(
|
||||||
|
cfg=cfg,
|
||||||
|
trans_names=cfg.MODEL.TRANSFORMER.NAMES_2ND,
|
||||||
|
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_2ND,
|
||||||
|
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_2ND,
|
||||||
|
)
|
||||||
|
box_head_3rd = TransformerHead(
|
||||||
|
cfg=cfg,
|
||||||
|
trans_names=cfg.MODEL.TRANSFORMER.NAMES_3RD,
|
||||||
|
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_3RD,
|
||||||
|
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_3RD,
|
||||||
|
)
|
||||||
|
|
||||||
|
faster_rcnn_predictor = FastRCNNPredictor(2048, 2)
|
||||||
|
box_roi_pool = MultiScaleRoIAlign(
|
||||||
|
featmap_names=["feat_res4"], output_size=14, sampling_ratio=2
|
||||||
|
)
|
||||||
|
box_predictor = BBoxRegressor(2048, num_classes=2, bn_neck=cfg.MODEL.ROI_HEAD.BN_NECK)
|
||||||
|
roi_heads = CascadedROIHeads(
|
||||||
|
cfg=cfg,
|
||||||
|
# Cascade Transformer Head
|
||||||
|
faster_rcnn_predictor=faster_rcnn_predictor,
|
||||||
|
box_head_2nd=box_head_2nd,
|
||||||
|
box_head_3rd=box_head_3rd,
|
||||||
|
# parent class
|
||||||
|
box_roi_pool=box_roi_pool,
|
||||||
|
box_head=box_head,
|
||||||
|
box_predictor=box_predictor,
|
||||||
|
fg_iou_thresh=cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN,
|
||||||
|
bg_iou_thresh=cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN,
|
||||||
|
batch_size_per_image=cfg.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN,
|
||||||
|
positive_fraction=cfg.MODEL.ROI_HEAD.POS_FRAC_TRAIN,
|
||||||
|
bbox_reg_weights=None,
|
||||||
|
score_thresh=cfg.MODEL.ROI_HEAD.SCORE_THRESH_TEST,
|
||||||
|
nms_thresh=cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST,
|
||||||
|
detections_per_img=cfg.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST,
|
||||||
|
)
|
||||||
|
|
||||||
|
transform = GeneralizedRCNNTransform(
|
||||||
|
min_size=cfg.INPUT.MIN_SIZE,
|
||||||
|
max_size=cfg.INPUT.MAX_SIZE,
|
||||||
|
image_mean=[0.485, 0.456, 0.406],
|
||||||
|
image_std=[0.229, 0.224, 0.225],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.backbone = backbone
|
||||||
|
self.rpn = rpn
|
||||||
|
self.roi_heads = roi_heads
|
||||||
|
self.transform = transform
|
||||||
|
self.eval_feat = cfg.EVAL_FEATURE
|
||||||
|
|
||||||
|
# loss weights
|
||||||
|
self.lw_rpn_reg = cfg.SOLVER.LW_RPN_REG
|
||||||
|
self.lw_rpn_cls = cfg.SOLVER.LW_RPN_CLS
|
||||||
|
self.lw_rcnn_reg_1st = cfg.SOLVER.LW_RCNN_REG_1ST
|
||||||
|
self.lw_rcnn_cls_1st = cfg.SOLVER.LW_RCNN_CLS_1ST
|
||||||
|
self.lw_rcnn_reg_2nd = cfg.SOLVER.LW_RCNN_REG_2ND
|
||||||
|
self.lw_rcnn_cls_2nd = cfg.SOLVER.LW_RCNN_CLS_2ND
|
||||||
|
self.lw_rcnn_reg_3rd = cfg.SOLVER.LW_RCNN_REG_3RD
|
||||||
|
self.lw_rcnn_cls_3rd = cfg.SOLVER.LW_RCNN_CLS_3RD
|
||||||
|
self.lw_rcnn_reid_2nd = cfg.SOLVER.LW_RCNN_REID_2ND
|
||||||
|
self.lw_rcnn_reid_3rd = cfg.SOLVER.LW_RCNN_REID_3RD
|
||||||
|
|
||||||
|
def inference(self, images, targets=None, query_img_as_gallery=False):
|
||||||
|
original_image_sizes = [img.shape[-2:] for img in images]
|
||||||
|
images, targets = self.transform(images, targets)
|
||||||
|
features = self.backbone(images.tensors)
|
||||||
|
|
||||||
|
if query_img_as_gallery:
|
||||||
|
assert targets is not None
|
||||||
|
|
||||||
|
if targets is not None and not query_img_as_gallery:
|
||||||
|
# query
|
||||||
|
boxes = [t["boxes"] for t in targets]
|
||||||
|
box_features = self.roi_heads.box_roi_pool(features, boxes, images.image_sizes)
|
||||||
|
box_features_2nd = self.roi_heads.box_head_2nd(box_features)
|
||||||
|
embeddings_2nd, _ = self.roi_heads.embedding_head_2nd(box_features_2nd)
|
||||||
|
box_features_3rd = self.roi_heads.box_head_3rd(box_features)
|
||||||
|
embeddings_3rd, _ = self.roi_heads.embedding_head_3rd(box_features_3rd)
|
||||||
|
if self.eval_feat == 'concat':
|
||||||
|
embeddings = torch.cat((embeddings_2nd, embeddings_3rd), dim=1)
|
||||||
|
elif self.eval_feat == 'stage2':
|
||||||
|
embeddings = embeddings_2nd
|
||||||
|
elif self.eval_feat == 'stage3':
|
||||||
|
embeddings = embeddings_3rd
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown evaluation feature name")
|
||||||
|
return embeddings.split(1, 0)
|
||||||
|
else:
|
||||||
|
# gallery
|
||||||
|
boxes, _ = self.rpn(images, features, targets)
|
||||||
|
detections = self.roi_heads(features, boxes, images.image_sizes, targets, query_img_as_gallery)[0]
|
||||||
|
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
|
||||||
|
return detections
|
||||||
|
|
||||||
|
def forward(self, images, targets=None, query_img_as_gallery=False):
|
||||||
|
if not self.training:
|
||||||
|
return self.inference(images, targets, query_img_as_gallery)
|
||||||
|
images, targets = self.transform(images, targets)
|
||||||
|
features = self.backbone(images.tensors)
|
||||||
|
boxes, rpn_losses = self.rpn(images, features, targets)
|
||||||
|
|
||||||
|
_, rcnn_losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd = self.roi_heads(features, boxes, images.image_sizes, targets)
|
||||||
|
|
||||||
|
# rename rpn losses to be consistent with detection losses
|
||||||
|
rpn_losses["loss_rpn_reg"] = rpn_losses.pop("loss_rpn_box_reg")
|
||||||
|
rpn_losses["loss_rpn_cls"] = rpn_losses.pop("loss_objectness")
|
||||||
|
|
||||||
|
losses = {}
|
||||||
|
losses.update(rcnn_losses)
|
||||||
|
losses.update(rpn_losses)
|
||||||
|
|
||||||
|
# apply loss weights
|
||||||
|
losses["loss_rpn_reg"] *= self.lw_rpn_reg
|
||||||
|
losses["loss_rpn_cls"] *= self.lw_rpn_cls
|
||||||
|
losses["loss_rcnn_reg_1st"] *= self.lw_rcnn_reg_1st
|
||||||
|
losses["loss_rcnn_cls_1st"] *= self.lw_rcnn_cls_1st
|
||||||
|
losses["loss_rcnn_reg_2nd"] *= self.lw_rcnn_reg_2nd
|
||||||
|
losses["loss_rcnn_cls_2nd"] *= self.lw_rcnn_cls_2nd
|
||||||
|
losses["loss_rcnn_reg_3rd"] *= self.lw_rcnn_reg_3rd
|
||||||
|
losses["loss_rcnn_cls_3rd"] *= self.lw_rcnn_cls_3rd
|
||||||
|
losses["loss_rcnn_reid_2nd"] *= self.lw_rcnn_reid_2nd
|
||||||
|
losses["loss_rcnn_reid_3rd"] *= self.lw_rcnn_reid_3rd
|
||||||
|
|
||||||
|
return losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd
|
||||||
|
|
||||||
|
class CascadedROIHeads(RoIHeads):
|
||||||
|
'''
|
||||||
|
https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py
|
||||||
|
'''
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
faster_rcnn_predictor,
|
||||||
|
box_head_2nd,
|
||||||
|
box_head_3rd,
|
||||||
|
*args,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super(CascadedROIHeads, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# ROI head
|
||||||
|
self.use_diff_thresh=cfg.MODEL.ROI_HEAD.USE_DIFF_THRESH
|
||||||
|
self.nms_thresh_1st = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST
|
||||||
|
self.nms_thresh_2nd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND
|
||||||
|
self.nms_thresh_3rd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD
|
||||||
|
self.fg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN
|
||||||
|
self.bg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN
|
||||||
|
self.fg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND
|
||||||
|
self.bg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND
|
||||||
|
self.fg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD
|
||||||
|
self.bg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD
|
||||||
|
|
||||||
|
# Regression head
|
||||||
|
self.box_predictor_1st = faster_rcnn_predictor
|
||||||
|
self.box_predictor_2nd = self.box_predictor
|
||||||
|
self.box_predictor_3rd = deepcopy(self.box_predictor)
|
||||||
|
|
||||||
|
# Transformer head
|
||||||
|
self.box_head_1st = self.box_head
|
||||||
|
self.box_head_2nd = box_head_2nd
|
||||||
|
self.box_head_3rd = box_head_3rd
|
||||||
|
|
||||||
|
# feature mask
|
||||||
|
self.use_feature_mask = cfg.MODEL.USE_FEATURE_MASK
|
||||||
|
self.feature_mask_size = cfg.MODEL.FEATURE_MASK_SIZE
|
||||||
|
|
||||||
|
# Feature embedding
|
||||||
|
embedding_dim = cfg.MODEL.EMBEDDING_DIM
|
||||||
|
self.embedding_head_2nd = NormAwareEmbedding(featmap_names=["before_trans", "after_trans"], in_channels=[1024, 2048], dim=embedding_dim)
|
||||||
|
self.embedding_head_3rd = deepcopy(self.embedding_head_2nd)
|
||||||
|
|
||||||
|
# OIM
|
||||||
|
num_pids = cfg.MODEL.LOSS.LUT_SIZE
|
||||||
|
num_cq_size = cfg.MODEL.LOSS.CQ_SIZE
|
||||||
|
oim_momentum = cfg.MODEL.LOSS.OIM_MOMENTUM
|
||||||
|
oim_scalar = cfg.MODEL.LOSS.OIM_SCALAR
|
||||||
|
self.reid_loss_2nd = OIMLoss(embedding_dim, num_pids, num_cq_size, oim_momentum, oim_scalar)
|
||||||
|
self.reid_loss_3rd = deepcopy(self.reid_loss_2nd)
|
||||||
|
|
||||||
|
# rename the method inherited from parent class
|
||||||
|
self.postprocess_proposals = self.postprocess_detections
|
||||||
|
|
||||||
|
# evaluation
|
||||||
|
self.eval_feat = cfg.EVAL_FEATURE
|
||||||
|
|
||||||
|
def forward(self, features, boxes, image_shapes, targets=None, query_img_as_gallery=False):
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
features (List[Tensor])
|
||||||
|
boxes (List[Tensor[N, 4]])
|
||||||
|
image_shapes (List[Tuple[H, W]])
|
||||||
|
targets (List[Dict])
|
||||||
|
"""
|
||||||
|
cws = True
|
||||||
|
gt_det_2nd = None
|
||||||
|
gt_det_3rd = None
|
||||||
|
feats_reid_2nd = None
|
||||||
|
feats_reid_3rd = None
|
||||||
|
targets_reid_2nd = None
|
||||||
|
targets_reid_3rd = None
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
if self.use_diff_thresh:
|
||||||
|
self.proposal_matcher = det_utils.Matcher(
|
||||||
|
self.fg_iou_thresh_1st,
|
||||||
|
self.bg_iou_thresh_1st,
|
||||||
|
allow_low_quality_matches=False)
|
||||||
|
boxes, _, box_pid_labels_1st, box_reg_targets_1st = self.select_training_samples(
|
||||||
|
boxes, targets
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------- The first stage ------------------ #
|
||||||
|
box_features_1st = self.box_roi_pool(features, boxes, image_shapes)
|
||||||
|
box_features_1st = self.box_head_1st(box_features_1st)
|
||||||
|
box_cls_scores_1st, box_regs_1st = self.box_predictor_1st(box_features_1st["after_trans"])
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
boxes = self.get_boxes(box_regs_1st, boxes, image_shapes)
|
||||||
|
boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
|
||||||
|
if self.use_diff_thresh:
|
||||||
|
self.proposal_matcher = det_utils.Matcher(
|
||||||
|
self.fg_iou_thresh_2nd,
|
||||||
|
self.bg_iou_thresh_2nd,
|
||||||
|
allow_low_quality_matches=False)
|
||||||
|
boxes, _, box_pid_labels_2nd, box_reg_targets_2nd = self.select_training_samples(boxes, targets)
|
||||||
|
else:
|
||||||
|
orig_thresh = self.nms_thresh # 0.4
|
||||||
|
self.nms_thresh = self.nms_thresh_1st
|
||||||
|
boxes, scores, _ = self.postprocess_proposals(
|
||||||
|
box_cls_scores_1st, box_regs_1st, boxes, image_shapes
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.training and query_img_as_gallery:
|
||||||
|
# When regarding the query image as gallery, GT boxes may be excluded
|
||||||
|
# from detected boxes. To avoid this, we compulsorily include GT in the
|
||||||
|
# detection results. Additionally, CWS should be disabled as the
|
||||||
|
# confidences of these people in query image are 1
|
||||||
|
cws = False
|
||||||
|
gt_box = [targets[0]["boxes"]]
|
||||||
|
gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
|
||||||
|
gt_box_features = self.box_head_2nd(gt_box_features)
|
||||||
|
embeddings, _ = self.embedding_head_2nd(gt_box_features)
|
||||||
|
gt_det_2nd = {"boxes": targets[0]["boxes"], "embeddings": embeddings}
|
||||||
|
|
||||||
|
# no detection predicted by Faster R-CNN head in test phase
|
||||||
|
if boxes[0].shape[0] == 0:
|
||||||
|
assert not self.training
|
||||||
|
boxes = gt_det_2nd["boxes"] if gt_det_2nd else torch.zeros(0, 4)
|
||||||
|
labels = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0)
|
||||||
|
scores = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0)
|
||||||
|
if self.eval_feat == 'concat':
|
||||||
|
embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_2nd["embeddings"]), dim=1) if gt_det_2nd else torch.zeros(0, 512)
|
||||||
|
elif self.eval_feat == 'stage2' or self.eval_feat == 'stage3':
|
||||||
|
embeddings = gt_det_2nd["embeddings"] if gt_det_2nd else torch.zeros(0, 256)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown evaluation feature name")
|
||||||
|
return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []
|
||||||
|
|
||||||
|
# --------------------- The second stage -------------------- #
|
||||||
|
box_features = self.box_roi_pool(features, boxes, image_shapes)
|
||||||
|
box_features = self.box_head_2nd(box_features)
|
||||||
|
box_regs_2nd = self.box_predictor_2nd(box_features["after_trans"])
|
||||||
|
box_embeddings_2nd, box_cls_scores_2nd = self.embedding_head_2nd(box_features)
|
||||||
|
if box_cls_scores_2nd.dim() == 0:
|
||||||
|
box_cls_scores_2nd = box_cls_scores_2nd.unsqueeze(0)
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
boxes = self.get_boxes(box_regs_2nd, boxes, image_shapes)
|
||||||
|
boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
|
||||||
|
if self.use_diff_thresh:
|
||||||
|
self.proposal_matcher = det_utils.Matcher(
|
||||||
|
self.fg_iou_thresh_3rd,
|
||||||
|
self.bg_iou_thresh_3rd,
|
||||||
|
allow_low_quality_matches=False)
|
||||||
|
boxes, _, box_pid_labels_3rd, box_reg_targets_3rd = self.select_training_samples(boxes, targets)
|
||||||
|
else:
|
||||||
|
self.nms_thresh = self.nms_thresh_2nd
|
||||||
|
if self.eval_feat != 'stage2':
|
||||||
|
boxes, scores, _, _ = self.postprocess_boxes(
|
||||||
|
box_cls_scores_2nd,
|
||||||
|
box_regs_2nd,
|
||||||
|
box_embeddings_2nd,
|
||||||
|
boxes,
|
||||||
|
image_shapes,
|
||||||
|
fcs=scores,
|
||||||
|
gt_det=None,
|
||||||
|
cws=cws,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.training and query_img_as_gallery and self.eval_feat != 'stage2':
|
||||||
|
cws = False
|
||||||
|
gt_box = [targets[0]["boxes"]]
|
||||||
|
gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
|
||||||
|
gt_box_features = self.box_head_3rd(gt_box_features)
|
||||||
|
embeddings, _ = self.embedding_head_3rd(gt_box_features)
|
||||||
|
gt_det_3rd = {"boxes": targets[0]["boxes"], "embeddings": embeddings}
|
||||||
|
|
||||||
|
# no detection predicted by Faster R-CNN head in test phase
|
||||||
|
if boxes[0].shape[0] == 0 and self.eval_feat != 'stage2':
|
||||||
|
assert not self.training
|
||||||
|
boxes = gt_det_3rd["boxes"] if gt_det_3rd else torch.zeros(0, 4)
|
||||||
|
labels = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0)
|
||||||
|
scores = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0)
|
||||||
|
if self.eval_feat == 'concat':
|
||||||
|
embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_3rd["embeddings"]), dim=1) if gt_det_3rd else torch.zeros(0, 512)
|
||||||
|
elif self.eval_feat == 'stage3':
|
||||||
|
embeddings = gt_det_2nd["embeddings"] if gt_det_3rd else torch.zeros(0, 256)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown evaluation feature name")
|
||||||
|
return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []
|
||||||
|
|
||||||
|
# --------------------- The third stage -------------------- #
|
||||||
|
box_features = self.box_roi_pool(features, boxes, image_shapes)
|
||||||
|
|
||||||
|
if not self.training:
|
||||||
|
box_features_2nd = self.box_head_2nd(box_features)
|
||||||
|
box_embeddings_2nd, _ = self.embedding_head_2nd(box_features_2nd)
|
||||||
|
|
||||||
|
box_features = self.box_head_3rd(box_features)
|
||||||
|
box_regs_3rd = self.box_predictor_3rd(box_features["after_trans"])
|
||||||
|
box_embeddings_3rd, box_cls_scores_3rd = self.embedding_head_3rd(box_features)
|
||||||
|
if box_cls_scores_3rd.dim() == 0:
|
||||||
|
box_cls_scores_3rd = box_cls_scores_3rd.unsqueeze(0)
|
||||||
|
|
||||||
|
result, losses = [], {}
|
||||||
|
if self.training:
|
||||||
|
box_labels_1st = [y.clamp(0, 1) for y in box_pid_labels_1st]
|
||||||
|
box_labels_2nd = [y.clamp(0, 1) for y in box_pid_labels_2nd]
|
||||||
|
box_labels_3rd = [y.clamp(0, 1) for y in box_pid_labels_3rd]
|
||||||
|
losses = detection_losses(
|
||||||
|
box_cls_scores_1st,
|
||||||
|
box_regs_1st,
|
||||||
|
box_labels_1st,
|
||||||
|
box_reg_targets_1st,
|
||||||
|
box_cls_scores_2nd,
|
||||||
|
box_regs_2nd,
|
||||||
|
box_labels_2nd,
|
||||||
|
box_reg_targets_2nd,
|
||||||
|
box_cls_scores_3rd,
|
||||||
|
box_regs_3rd,
|
||||||
|
box_labels_3rd,
|
||||||
|
box_reg_targets_3rd,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss_rcnn_reid_2nd, feats_reid_2nd, targets_reid_2nd = self.reid_loss_2nd(box_embeddings_2nd, box_pid_labels_2nd)
|
||||||
|
loss_rcnn_reid_3rd, feats_reid_3rd, targets_reid_3rd = self.reid_loss_3rd(box_embeddings_3rd, box_pid_labels_3rd)
|
||||||
|
losses.update(loss_rcnn_reid_2nd=loss_rcnn_reid_2nd)
|
||||||
|
losses.update(loss_rcnn_reid_3rd=loss_rcnn_reid_3rd)
|
||||||
|
else:
|
||||||
|
if self.eval_feat == 'stage2':
|
||||||
|
boxes, scores, embeddings_2nd, labels = self.postprocess_boxes(
|
||||||
|
box_cls_scores_2nd,
|
||||||
|
box_regs_2nd,
|
||||||
|
box_embeddings_2nd,
|
||||||
|
boxes,
|
||||||
|
image_shapes,
|
||||||
|
fcs=scores,
|
||||||
|
gt_det=gt_det_2nd,
|
||||||
|
cws=cws,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.nms_thresh = self.nms_thresh_3rd
|
||||||
|
_, _, embeddings_2nd, _ = self.postprocess_boxes(
|
||||||
|
box_cls_scores_3rd,
|
||||||
|
box_regs_3rd,
|
||||||
|
box_embeddings_2nd,
|
||||||
|
boxes,
|
||||||
|
image_shapes,
|
||||||
|
fcs=scores,
|
||||||
|
gt_det=gt_det_2nd,
|
||||||
|
cws=cws,
|
||||||
|
)
|
||||||
|
boxes, scores, embeddings_3rd, labels = self.postprocess_boxes(
|
||||||
|
box_cls_scores_3rd,
|
||||||
|
box_regs_3rd,
|
||||||
|
box_embeddings_3rd,
|
||||||
|
boxes,
|
||||||
|
image_shapes,
|
||||||
|
fcs=scores,
|
||||||
|
gt_det=gt_det_3rd,
|
||||||
|
cws=cws,
|
||||||
|
)
|
||||||
|
# set to original thresh after finishing postprocess
|
||||||
|
self.nms_thresh = orig_thresh
|
||||||
|
|
||||||
|
num_images = len(boxes)
|
||||||
|
for i in range(num_images):
|
||||||
|
if self.eval_feat == 'concat':
|
||||||
|
embeddings = torch.cat((embeddings_2nd[i],embeddings_3rd[i]), dim=1)
|
||||||
|
elif self.eval_feat == 'stage2':
|
||||||
|
embeddings = embeddings_2nd[i]
|
||||||
|
elif self.eval_feat == 'stage3':
|
||||||
|
embeddings = embeddings_3rd[i]
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown evaluation feature name")
|
||||||
|
result.append(
|
||||||
|
dict(
|
||||||
|
boxes=boxes[i],
|
||||||
|
labels=labels[i],
|
||||||
|
scores=scores[i],
|
||||||
|
embeddings=embeddings
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result, losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd
|
||||||
|
|
||||||
|
def get_boxes(self, box_regression, proposals, image_shapes):
|
||||||
|
"""
|
||||||
|
Get boxes from proposals.
|
||||||
|
"""
|
||||||
|
boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
|
||||||
|
pred_boxes = self.box_coder.decode(box_regression, proposals)
|
||||||
|
pred_boxes = pred_boxes.split(boxes_per_image, 0)
|
||||||
|
|
||||||
|
all_boxes = []
|
||||||
|
for boxes, image_shape in zip(pred_boxes, image_shapes):
|
||||||
|
boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
|
||||||
|
# remove predictions with the background label
|
||||||
|
boxes = boxes[:, 1:].reshape(-1, 4)
|
||||||
|
all_boxes.append(boxes)
|
||||||
|
|
||||||
|
return all_boxes
|
||||||
|
|
||||||
|
def postprocess_boxes(
|
||||||
|
self,
|
||||||
|
class_logits,
|
||||||
|
box_regression,
|
||||||
|
embeddings,
|
||||||
|
proposals,
|
||||||
|
image_shapes,
|
||||||
|
fcs=None,
|
||||||
|
gt_det=None,
|
||||||
|
cws=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Similar to RoIHeads.postprocess_detections, but can handle embeddings and implement
|
||||||
|
First Classification Score (FCS).
|
||||||
|
"""
|
||||||
|
device = class_logits.device
|
||||||
|
|
||||||
|
boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
|
||||||
|
pred_boxes = self.box_coder.decode(box_regression, proposals)
|
||||||
|
|
||||||
|
if fcs is not None:
|
||||||
|
# Fist Classification Score (FCS)
|
||||||
|
pred_scores = fcs[0]
|
||||||
|
else:
|
||||||
|
pred_scores = torch.sigmoid(class_logits)
|
||||||
|
if cws:
|
||||||
|
# Confidence Weighted Similarity (CWS)
|
||||||
|
embeddings = embeddings * pred_scores.view(-1, 1)
|
||||||
|
|
||||||
|
# split boxes and scores per image
|
||||||
|
pred_boxes = pred_boxes.split(boxes_per_image, 0)
|
||||||
|
pred_scores = pred_scores.split(boxes_per_image, 0)
|
||||||
|
pred_embeddings = embeddings.split(boxes_per_image, 0)
|
||||||
|
|
||||||
|
all_boxes = []
|
||||||
|
all_scores = []
|
||||||
|
all_labels = []
|
||||||
|
all_embeddings = []
|
||||||
|
for boxes, scores, embeddings, image_shape in zip(
|
||||||
|
pred_boxes, pred_scores, pred_embeddings, image_shapes
|
||||||
|
):
|
||||||
|
boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
|
||||||
|
|
||||||
|
# create labels for each prediction
|
||||||
|
labels = torch.ones(scores.size(0), device=device)
|
||||||
|
|
||||||
|
# remove predictions with the background label
|
||||||
|
boxes = boxes[:, 1:]
|
||||||
|
scores = scores.unsqueeze(1)
|
||||||
|
labels = labels.unsqueeze(1)
|
||||||
|
|
||||||
|
# batch everything, by making every class prediction be a separate instance
|
||||||
|
boxes = boxes.reshape(-1, 4)
|
||||||
|
scores = scores.flatten()
|
||||||
|
labels = labels.flatten()
|
||||||
|
embeddings = embeddings.reshape(-1, self.embedding_head_2nd.dim)
|
||||||
|
|
||||||
|
# remove low scoring boxes
|
||||||
|
inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
|
||||||
|
boxes, scores, labels, embeddings = (
|
||||||
|
boxes[inds],
|
||||||
|
scores[inds],
|
||||||
|
labels[inds],
|
||||||
|
embeddings[inds],
|
||||||
|
)
|
||||||
|
|
||||||
|
# remove empty boxes
|
||||||
|
keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
|
||||||
|
boxes, scores, labels, embeddings = (
|
||||||
|
boxes[keep],
|
||||||
|
scores[keep],
|
||||||
|
labels[keep],
|
||||||
|
embeddings[keep],
|
||||||
|
)
|
||||||
|
|
||||||
|
if gt_det is not None:
|
||||||
|
# include GT into the detection results
|
||||||
|
boxes = torch.cat((boxes, gt_det["boxes"]), dim=0)
|
||||||
|
labels = torch.cat((labels, torch.tensor([1.0]).to(device)), dim=0)
|
||||||
|
scores = torch.cat((scores, torch.tensor([1.0]).to(device)), dim=0)
|
||||||
|
embeddings = torch.cat((embeddings, gt_det["embeddings"]), dim=0)
|
||||||
|
|
||||||
|
# non-maximum suppression, independently done per class
|
||||||
|
keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
|
||||||
|
# keep only topk scoring predictions
|
||||||
|
keep = keep[: self.detections_per_img]
|
||||||
|
boxes, scores, labels, embeddings = (
|
||||||
|
boxes[keep],
|
||||||
|
scores[keep],
|
||||||
|
labels[keep],
|
||||||
|
embeddings[keep],
|
||||||
|
)
|
||||||
|
|
||||||
|
all_boxes.append(boxes)
|
||||||
|
all_scores.append(scores)
|
||||||
|
all_labels.append(labels)
|
||||||
|
all_embeddings.append(embeddings)
|
||||||
|
|
||||||
|
return all_boxes, all_scores, all_embeddings, all_labels
|
||||||
|
|
||||||
|
|
||||||
|
class NormAwareEmbedding(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements the Norm-Aware Embedding proposed in
|
||||||
|
Chen, Di, et al. "Norm-aware embedding for efficient person search." CVPR 2020.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, featmap_names=["feat_res4", "feat_res5"], in_channels=[1024, 2048], dim=256):
|
||||||
|
super(NormAwareEmbedding, self).__init__()
|
||||||
|
self.featmap_names = featmap_names
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
self.projectors = nn.ModuleDict()
|
||||||
|
indv_dims = self._split_embedding_dim()
|
||||||
|
for ftname, in_channel, indv_dim in zip(self.featmap_names, self.in_channels, indv_dims):
|
||||||
|
proj = nn.Sequential(nn.Linear(in_channel, indv_dim), nn.BatchNorm1d(indv_dim))
|
||||||
|
init.normal_(proj[0].weight, std=0.01)
|
||||||
|
init.normal_(proj[1].weight, std=0.01)
|
||||||
|
init.constant_(proj[0].bias, 0)
|
||||||
|
init.constant_(proj[1].bias, 0)
|
||||||
|
self.projectors[ftname] = proj
|
||||||
|
|
||||||
|
self.rescaler = nn.BatchNorm1d(1, affine=True)
|
||||||
|
|
||||||
|
def forward(self, featmaps):
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
featmaps: OrderedDict[Tensor], and in featmap_names you can choose which
|
||||||
|
featmaps to use
|
||||||
|
Returns:
|
||||||
|
tensor of size (BatchSize, dim), L2 normalized embeddings.
|
||||||
|
tensor of size (BatchSize, ) rescaled norm of embeddings, as class_logits.
|
||||||
|
"""
|
||||||
|
assert len(featmaps) == len(self.featmap_names)
|
||||||
|
if len(featmaps) == 1:
|
||||||
|
k, v = featmaps.items()[0]
|
||||||
|
v = self._flatten_fc_input(v)
|
||||||
|
embeddings = self.projectors[k](v)
|
||||||
|
norms = embeddings.norm(2, 1, keepdim=True)
|
||||||
|
embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
|
||||||
|
norms = self.rescaler(norms).squeeze()
|
||||||
|
return embeddings, norms
|
||||||
|
else:
|
||||||
|
outputs = []
|
||||||
|
for k, v in featmaps.items():
|
||||||
|
v = self._flatten_fc_input(v)
|
||||||
|
outputs.append(self.projectors[k](v))
|
||||||
|
embeddings = torch.cat(outputs, dim=1)
|
||||||
|
norms = embeddings.norm(2, 1, keepdim=True)
|
||||||
|
embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
|
||||||
|
norms = self.rescaler(norms).squeeze()
|
||||||
|
return embeddings, norms
|
||||||
|
|
||||||
|
def _flatten_fc_input(self, x):
|
||||||
|
if x.ndimension() == 4:
|
||||||
|
assert list(x.shape[2:]) == [1, 1]
|
||||||
|
return x.flatten(start_dim=1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _split_embedding_dim(self):
|
||||||
|
parts = len(self.in_channels)
|
||||||
|
tmp = [self.dim // parts] * parts
|
||||||
|
if sum(tmp) == self.dim:
|
||||||
|
return tmp
|
||||||
|
else:
|
||||||
|
res = self.dim % parts
|
||||||
|
for i in range(1, res + 1):
|
||||||
|
tmp[-i] += 1
|
||||||
|
assert sum(tmp) == self.dim
|
||||||
|
return tmp
|
||||||
|
|
||||||
|
|
||||||
|
class BBoxRegressor(nn.Module):
|
||||||
|
"""
|
||||||
|
Bounding box regression layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_channels, num_classes=2, bn_neck=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
in_channels (int): Input channels.
|
||||||
|
num_classes (int, optional): Defaults to 2 (background and pedestrian).
|
||||||
|
bn_neck (bool, optional): Whether to use BN after Linear. Defaults to True.
|
||||||
|
"""
|
||||||
|
super(BBoxRegressor, self).__init__()
|
||||||
|
if bn_neck:
|
||||||
|
self.bbox_pred = nn.Sequential(
|
||||||
|
nn.Linear(in_channels, 4 * num_classes), nn.BatchNorm1d(4 * num_classes)
|
||||||
|
)
|
||||||
|
init.normal_(self.bbox_pred[0].weight, std=0.01)
|
||||||
|
init.normal_(self.bbox_pred[1].weight, std=0.01)
|
||||||
|
init.constant_(self.bbox_pred[0].bias, 0)
|
||||||
|
init.constant_(self.bbox_pred[1].bias, 0)
|
||||||
|
else:
|
||||||
|
self.bbox_pred = nn.Linear(in_channels, 4 * num_classes)
|
||||||
|
init.normal_(self.bbox_pred.weight, std=0.01)
|
||||||
|
init.constant_(self.bbox_pred.bias, 0)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.ndimension() == 4:
|
||||||
|
if list(x.shape[2:]) != [1, 1]:
|
||||||
|
x = F.adaptive_avg_pool2d(x, output_size=1)
|
||||||
|
x = x.flatten(start_dim=1)
|
||||||
|
bbox_deltas = self.bbox_pred(x)
|
||||||
|
return bbox_deltas
|
||||||
|
|
||||||
|
|
||||||
|
def detection_losses(
|
||||||
|
box_cls_scores_1st,
|
||||||
|
box_regs_1st,
|
||||||
|
box_labels_1st,
|
||||||
|
box_reg_targets_1st,
|
||||||
|
box_cls_scores_2nd,
|
||||||
|
box_regs_2nd,
|
||||||
|
box_labels_2nd,
|
||||||
|
box_reg_targets_2nd,
|
||||||
|
box_cls_scores_3rd,
|
||||||
|
box_regs_3rd,
|
||||||
|
box_labels_3rd,
|
||||||
|
box_reg_targets_3rd,
|
||||||
|
):
|
||||||
|
# --------------------- The first stage -------------------- #
|
||||||
|
box_labels_1st = torch.cat(box_labels_1st, dim=0)
|
||||||
|
box_reg_targets_1st = torch.cat(box_reg_targets_1st, dim=0)
|
||||||
|
loss_rcnn_cls_1st = F.cross_entropy(box_cls_scores_1st, box_labels_1st)
|
||||||
|
|
||||||
|
# get indices that correspond to the regression targets for the
|
||||||
|
# corresponding ground truth labels, to be used with advanced indexing
|
||||||
|
sampled_pos_inds_subset = torch.nonzero(box_labels_1st > 0).squeeze(1)
|
||||||
|
labels_pos = box_labels_1st[sampled_pos_inds_subset]
|
||||||
|
N = box_cls_scores_1st.size(0)
|
||||||
|
box_regs_1st = box_regs_1st.reshape(N, -1, 4)
|
||||||
|
|
||||||
|
loss_rcnn_reg_1st = F.smooth_l1_loss(
|
||||||
|
box_regs_1st[sampled_pos_inds_subset, labels_pos],
|
||||||
|
box_reg_targets_1st[sampled_pos_inds_subset],
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
loss_rcnn_reg_1st = loss_rcnn_reg_1st / box_labels_1st.numel()
|
||||||
|
|
||||||
|
# --------------------- The second stage -------------------- #
|
||||||
|
box_labels_2nd = torch.cat(box_labels_2nd, dim=0)
|
||||||
|
box_reg_targets_2nd = torch.cat(box_reg_targets_2nd, dim=0)
|
||||||
|
loss_rcnn_cls_2nd = F.binary_cross_entropy_with_logits(box_cls_scores_2nd, box_labels_2nd.float())
|
||||||
|
|
||||||
|
sampled_pos_inds_subset = torch.nonzero(box_labels_2nd > 0).squeeze(1)
|
||||||
|
labels_pos = box_labels_2nd[sampled_pos_inds_subset]
|
||||||
|
N = box_cls_scores_2nd.size(0)
|
||||||
|
box_regs_2nd = box_regs_2nd.reshape(N, -1, 4)
|
||||||
|
|
||||||
|
loss_rcnn_reg_2nd = F.smooth_l1_loss(
|
||||||
|
box_regs_2nd[sampled_pos_inds_subset, labels_pos],
|
||||||
|
box_reg_targets_2nd[sampled_pos_inds_subset],
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
loss_rcnn_reg_2nd = loss_rcnn_reg_2nd / box_labels_2nd.numel()
|
||||||
|
|
||||||
|
# --------------------- The third stage -------------------- #
|
||||||
|
box_labels_3rd = torch.cat(box_labels_3rd, dim=0)
|
||||||
|
box_reg_targets_3rd = torch.cat(box_reg_targets_3rd, dim=0)
|
||||||
|
loss_rcnn_cls_3rd = F.binary_cross_entropy_with_logits(box_cls_scores_3rd, box_labels_3rd.float())
|
||||||
|
|
||||||
|
sampled_pos_inds_subset = torch.nonzero(box_labels_3rd > 0).squeeze(1)
|
||||||
|
labels_pos = box_labels_3rd[sampled_pos_inds_subset]
|
||||||
|
N = box_cls_scores_3rd.size(0)
|
||||||
|
box_regs_3rd = box_regs_3rd.reshape(N, -1, 4)
|
||||||
|
|
||||||
|
loss_rcnn_reg_3rd = F.smooth_l1_loss(
|
||||||
|
box_regs_3rd[sampled_pos_inds_subset, labels_pos],
|
||||||
|
box_reg_targets_3rd[sampled_pos_inds_subset],
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
loss_rcnn_reg_3rd = loss_rcnn_reg_3rd / box_labels_3rd.numel()
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
loss_rcnn_cls_1st=loss_rcnn_cls_1st,
|
||||||
|
loss_rcnn_reg_1st=loss_rcnn_reg_1st,
|
||||||
|
loss_rcnn_cls_2nd=loss_rcnn_cls_2nd,
|
||||||
|
loss_rcnn_reg_2nd=loss_rcnn_reg_2nd,
|
||||||
|
loss_rcnn_cls_3rd=loss_rcnn_cls_3rd,
|
||||||
|
loss_rcnn_reg_3rd=loss_rcnn_reg_3rd,
|
||||||
|
)
|
|
@ -0,0 +1,53 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
class Backbone(nn.Sequential):
|
||||||
|
def __init__(self, resnet):
|
||||||
|
super(Backbone, self).__init__(
|
||||||
|
OrderedDict(
|
||||||
|
[
|
||||||
|
["conv1", resnet.conv1],
|
||||||
|
["bn1", resnet.bn1],
|
||||||
|
["relu", resnet.relu],
|
||||||
|
["maxpool", resnet.maxpool],
|
||||||
|
["layer1", resnet.layer1], # res2
|
||||||
|
["layer2", resnet.layer2], # res3
|
||||||
|
["layer3", resnet.layer3], # res4
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.out_channels = 1024
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# using the forward method from nn.Sequential
|
||||||
|
feat = super(Backbone, self).forward(x)
|
||||||
|
return OrderedDict([["feat_res4", feat]])
|
||||||
|
|
||||||
|
|
||||||
|
class Res5Head(nn.Sequential):
|
||||||
|
def __init__(self, resnet):
|
||||||
|
super(Res5Head, self).__init__(OrderedDict([["layer4", resnet.layer4]])) # res5
|
||||||
|
self.out_channels = [1024, 2048]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
feat = super(Res5Head, self).forward(x)
|
||||||
|
x = F.adaptive_max_pool2d(x, 1)
|
||||||
|
feat = F.adaptive_max_pool2d(feat, 1)
|
||||||
|
return OrderedDict([["feat_res4", x], ["feat_res5", feat]])
|
||||||
|
|
||||||
|
|
||||||
|
def build_resnet(name="resnet50", pretrained=True):
|
||||||
|
resnet = torchvision.models.resnet.__dict__[name](pretrained=pretrained)
|
||||||
|
|
||||||
|
# freeze layers
|
||||||
|
resnet.conv1.weight.requires_grad_(False)
|
||||||
|
resnet.bn1.weight.requires_grad_(False)
|
||||||
|
resnet.bn1.bias.requires_grad_(False)
|
||||||
|
|
||||||
|
return Backbone(resnet), Res5Head(resnet)
|
|
@ -0,0 +1,300 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
from functools import reduce
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from utils.mask import exchange_token, exchange_patch, get_mask_box, jigsaw_token, cutout_patch, erase_patch, mixup_patch, jigsaw_patch
|
||||||
|
|
||||||
|
|
||||||
|
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
||||||
|
"""1x1 convolution"""
|
||||||
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerHead(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
trans_names,
|
||||||
|
kernel_size,
|
||||||
|
use_feature_mask,
|
||||||
|
):
|
||||||
|
super(TransformerHead, self).__init__()
|
||||||
|
d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL
|
||||||
|
|
||||||
|
# Mask parameters
|
||||||
|
self.use_feature_mask = use_feature_mask
|
||||||
|
mask_shape = cfg.MODEL.MASK_SHAPE
|
||||||
|
mask_size = cfg.MODEL.MASK_SIZE
|
||||||
|
mask_mode = cfg.MODEL.MASK_MODE
|
||||||
|
|
||||||
|
self.bypass_mask = exchange_patch(mask_shape, mask_size, mask_mode)
|
||||||
|
self.get_mask_box = get_mask_box(mask_shape, mask_size, mask_mode)
|
||||||
|
|
||||||
|
self.transformer_encoder = Transformers(
|
||||||
|
cfg=cfg,
|
||||||
|
trans_names=trans_names,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
use_feature_mask=use_feature_mask,
|
||||||
|
)
|
||||||
|
self.conv0 = conv1x1(1024, 1024)
|
||||||
|
self.conv1 = conv1x1(1024, d_model)
|
||||||
|
self.conv2 = conv1x1(d_model, 2048)
|
||||||
|
|
||||||
|
def forward(self, box_features):
|
||||||
|
mask_box = self.get_mask_box(box_features)
|
||||||
|
|
||||||
|
if self.use_feature_mask:
|
||||||
|
skip_features = self.conv0(box_features)
|
||||||
|
if self.training:
|
||||||
|
skip_features = self.bypass_mask(skip_features)
|
||||||
|
else:
|
||||||
|
skip_features = box_features
|
||||||
|
|
||||||
|
trans_features = {}
|
||||||
|
trans_features["before_trans"] = F.adaptive_max_pool2d(skip_features, 1)
|
||||||
|
box_features = self.conv1(box_features)
|
||||||
|
box_features = self.transformer_encoder((box_features,mask_box))
|
||||||
|
box_features = self.conv2(box_features)
|
||||||
|
trans_features["after_trans"] = F.adaptive_max_pool2d(box_features, 1)
|
||||||
|
|
||||||
|
return trans_features
|
||||||
|
|
||||||
|
|
||||||
|
class Transformers(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cfg,
|
||||||
|
trans_names,
|
||||||
|
kernel_size,
|
||||||
|
use_feature_mask,
|
||||||
|
):
|
||||||
|
super(Transformers, self).__init__()
|
||||||
|
d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL
|
||||||
|
self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE
|
||||||
|
self.use_feature_mask = use_feature_mask
|
||||||
|
|
||||||
|
# If no conv before transformer, we do not use scales
|
||||||
|
if not cfg.MODEL.TRANSFORMER.USE_PATCH2VEC:
|
||||||
|
trans_names = ['scale1']
|
||||||
|
kernel_size = [(1,1)]
|
||||||
|
|
||||||
|
self.trans_names = trans_names
|
||||||
|
self.scale_size = len(self.trans_names)
|
||||||
|
hidden = d_model//(2*self.scale_size)
|
||||||
|
|
||||||
|
# kernel_size: (padding, stride)
|
||||||
|
kernels = {
|
||||||
|
(1,1): [(0,0),(1,1)],
|
||||||
|
(3,3): [(1,1),(1,1)]
|
||||||
|
}
|
||||||
|
|
||||||
|
padding = []
|
||||||
|
stride = []
|
||||||
|
for ksize in kernel_size:
|
||||||
|
if ksize not in [(1,1),(3,3)]:
|
||||||
|
raise ValueError('Undefined kernel size.')
|
||||||
|
padding.append(kernels[ksize][0])
|
||||||
|
stride.append(kernels[ksize][1])
|
||||||
|
|
||||||
|
self.use_output_layer = cfg.MODEL.TRANSFORMER.USE_OUTPUT_LAYER
|
||||||
|
self.use_global_shortcut = cfg.MODEL.TRANSFORMER.USE_GLOBAL_SHORTCUT
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleDict()
|
||||||
|
for tname, ksize, psize, ssize in zip(self.trans_names, kernel_size, padding, stride):
|
||||||
|
transblock = Transformer(
|
||||||
|
cfg, d_model//self.scale_size, ksize, psize, ssize, hidden, use_feature_mask
|
||||||
|
)
|
||||||
|
self.blocks[tname] = nn.Sequential(transblock)
|
||||||
|
|
||||||
|
self.output_linear = nn.Sequential(
|
||||||
|
nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
|
||||||
|
nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
)
|
||||||
|
self.mask_para = [cfg.MODEL.MASK_SHAPE, cfg.MODEL.MASK_SIZE, cfg.MODEL.MASK_MODE]
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
trans_feat = []
|
||||||
|
enc_feat, mask_box = inputs
|
||||||
|
|
||||||
|
if self.training and self.use_feature_mask and self.feature_aug_type == 'exchange_patch':
|
||||||
|
feature_mask = exchange_patch(self.mask_para[0], self.mask_para[1], self.mask_para[2])
|
||||||
|
enc_feat = feature_mask(enc_feat)
|
||||||
|
|
||||||
|
for tname, feat in zip(self.trans_names, torch.chunk(enc_feat, len(self.trans_names), dim=1)):
|
||||||
|
feat = self.blocks[tname]((feat, mask_box))
|
||||||
|
trans_feat.append(feat)
|
||||||
|
|
||||||
|
trans_feat = torch.cat(trans_feat, 1)
|
||||||
|
if self.use_output_layer:
|
||||||
|
trans_feat = self.output_linear(trans_feat)
|
||||||
|
if self.use_global_shortcut:
|
||||||
|
trans_feat = enc_feat + trans_feat
|
||||||
|
return trans_feat
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(self, cfg, channel, kernel_size, padding, stride, hidden, use_feature_mask
|
||||||
|
):
|
||||||
|
super(Transformer, self).__init__()
|
||||||
|
self.k = kernel_size[0]
|
||||||
|
stack_num = cfg.MODEL.TRANSFORMER.ENCODER_LAYERS
|
||||||
|
num_head = cfg.MODEL.TRANSFORMER.N_HEAD
|
||||||
|
dropout = cfg.MODEL.TRANSFORMER.DROPOUT
|
||||||
|
output_size = (14,14)
|
||||||
|
token_size = tuple(map(lambda x,y:x//y, output_size, stride))
|
||||||
|
blocks = []
|
||||||
|
self.transblock = TransformerBlock(token_size, hidden=hidden, num_head=num_head, dropout=dropout)
|
||||||
|
for _ in range(stack_num):
|
||||||
|
blocks.append(self.transblock)
|
||||||
|
self.transformer = nn.Sequential(*blocks)
|
||||||
|
self.patch2vec = nn.Conv2d(channel, hidden, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||||
|
self.vec2patch = Vec2Patch(channel, hidden, output_size, kernel_size, stride, padding)
|
||||||
|
self.use_local_shortcut = cfg.MODEL.TRANSFORMER.USE_LOCAL_SHORTCUT
|
||||||
|
self.use_feature_mask = use_feature_mask
|
||||||
|
self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE
|
||||||
|
self.use_patch2vec = cfg.MODEL.TRANSFORMER.USE_PATCH2VEC
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
enc_feat, mask_box = inputs
|
||||||
|
b, c, h, w = enc_feat.size()
|
||||||
|
|
||||||
|
trans_feat = self.patch2vec(enc_feat)
|
||||||
|
|
||||||
|
_, c, h, w = trans_feat.size()
|
||||||
|
trans_feat = trans_feat.view(b, c, -1).permute(0, 2, 1)
|
||||||
|
|
||||||
|
# For 1x1 & 3x3 kernels, exchange tokens
|
||||||
|
if self.training and self.use_feature_mask:
|
||||||
|
if self.feature_aug_type == 'exchange_token':
|
||||||
|
feature_mask = exchange_token()
|
||||||
|
trans_feat = feature_mask(trans_feat, mask_box)
|
||||||
|
elif self.feature_aug_type == 'cutout_patch':
|
||||||
|
feature_mask = cutout_patch()
|
||||||
|
trans_feat = feature_mask(trans_feat)
|
||||||
|
elif self.feature_aug_type == 'erase_patch':
|
||||||
|
feature_mask = erase_patch()
|
||||||
|
trans_feat = feature_mask(trans_feat)
|
||||||
|
elif self.feature_aug_type == 'mixup_patch':
|
||||||
|
feature_mask = mixup_patch()
|
||||||
|
trans_feat = feature_mask(trans_feat)
|
||||||
|
|
||||||
|
if self.use_feature_mask:
|
||||||
|
if self.feature_aug_type == 'jigsaw_patch':
|
||||||
|
feature_mask = jigsaw_patch()
|
||||||
|
trans_feat = feature_mask(trans_feat)
|
||||||
|
elif self.feature_aug_type == 'jigsaw_token':
|
||||||
|
feature_mask = jigsaw_token()
|
||||||
|
trans_feat = feature_mask(trans_feat)
|
||||||
|
|
||||||
|
trans_feat = self.transformer(trans_feat)
|
||||||
|
trans_feat = self.vec2patch(trans_feat)
|
||||||
|
if self.use_local_shortcut:
|
||||||
|
trans_feat = enc_feat + trans_feat
|
||||||
|
|
||||||
|
return trans_feat
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
|
||||||
|
"""
|
||||||
|
def __init__(self, tokensize, hidden=128, num_head=4, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = MultiHeadedAttention(tokensize, d_model=hidden, head=num_head, p=dropout)
|
||||||
|
self.ffn = FeedForward(hidden, p=dropout)
|
||||||
|
self.norm1 = nn.LayerNorm(hidden)
|
||||||
|
self.norm2 = nn.LayerNorm(hidden)
|
||||||
|
self.dropout = nn.Dropout(p=dropout)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = x + self.dropout(self.attention(x))
|
||||||
|
y = self.norm2(x)
|
||||||
|
x = x + self.ffn(y)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""
|
||||||
|
Compute 'Scaled Dot Product Attention
|
||||||
|
"""
|
||||||
|
def __init__(self, p=0.1):
|
||||||
|
super(Attention, self).__init__()
|
||||||
|
self.dropout = nn.Dropout(p=p)
|
||||||
|
|
||||||
|
def forward(self, query, key, value):
|
||||||
|
scores = torch.matmul(query, key.transpose(-2, -1)
|
||||||
|
) / math.sqrt(query.size(-1))
|
||||||
|
p_attn = F.softmax(scores, dim=-1)
|
||||||
|
p_attn = self.dropout(p_attn)
|
||||||
|
p_val = torch.matmul(p_attn, value)
|
||||||
|
return p_val, p_attn
|
||||||
|
|
||||||
|
|
||||||
|
class Vec2Patch(nn.Module):
|
||||||
|
def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
|
||||||
|
super(Vec2Patch, self).__init__()
|
||||||
|
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
||||||
|
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
|
||||||
|
self.embedding = nn.Linear(hidden, c_out)
|
||||||
|
self.to_patch = torch.nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||||
|
h, w = output_size
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
feat = self.embedding(x)
|
||||||
|
b, n, c = feat.size()
|
||||||
|
feat = feat.permute(0, 2, 1)
|
||||||
|
feat = self.to_patch(feat)
|
||||||
|
|
||||||
|
return feat
|
||||||
|
|
||||||
|
class MultiHeadedAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Take in model size and number of heads.
|
||||||
|
"""
|
||||||
|
def __init__(self, tokensize, d_model, head, p=0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.query_embedding = nn.Linear(d_model, d_model)
|
||||||
|
self.value_embedding = nn.Linear(d_model, d_model)
|
||||||
|
self.key_embedding = nn.Linear(d_model, d_model)
|
||||||
|
self.output_linear = nn.Linear(d_model, d_model)
|
||||||
|
self.attention = Attention(p=p)
|
||||||
|
self.head = head
|
||||||
|
self.h, self.w = tokensize
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, n, c = x.size()
|
||||||
|
c_h = c // self.head
|
||||||
|
key = self.key_embedding(x)
|
||||||
|
query = self.query_embedding(x)
|
||||||
|
value = self.value_embedding(x)
|
||||||
|
key = key.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
|
||||||
|
query = query.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
|
||||||
|
value = value.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
|
||||||
|
att, _ = self.attention(query, key, value)
|
||||||
|
att = att.permute(0, 2, 1, 3).contiguous().view(b, n, c)
|
||||||
|
output = self.output_linear(att)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, d_model, p=0.1):
|
||||||
|
super(FeedForward, self).__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Linear(d_model, d_model * 4),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Dropout(p=p),
|
||||||
|
nn.Linear(d_model * 4, d_model),
|
||||||
|
nn.Dropout(p=p))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
|
@ -0,0 +1,154 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os.path as osp
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
|
||||||
|
from datasets import build_test_loader, build_train_loader
|
||||||
|
from defaults import get_default_cfg
|
||||||
|
from engine import evaluate_performance, train_one_epoch
|
||||||
|
from models.coat import COAT
|
||||||
|
from utils.utils import mkdir, resume_from_ckpt, save_on_master, set_random_seed
|
||||||
|
|
||||||
|
from loss.softmax_loss import SoftmaxLoss
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
cfg = get_default_cfg()
|
||||||
|
if args.cfg_file:
|
||||||
|
cfg.merge_from_file(args.cfg_file)
|
||||||
|
cfg.merge_from_list(args.opts)
|
||||||
|
cfg.freeze()
|
||||||
|
|
||||||
|
device = torch.device(cfg.DEVICE)
|
||||||
|
if cfg.SEED >= 0:
|
||||||
|
set_random_seed(cfg.SEED)
|
||||||
|
|
||||||
|
print("Creating model...")
|
||||||
|
model = COAT(cfg)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
print("Loading data...")
|
||||||
|
train_loader = build_train_loader(cfg)
|
||||||
|
gallery_loader, query_loader = build_test_loader(cfg)
|
||||||
|
|
||||||
|
softmax_criterion_s2 = None
|
||||||
|
softmax_criterion_s3 = None
|
||||||
|
if cfg.MODEL.LOSS.USE_SOFTMAX:
|
||||||
|
softmax_criterion_s2 = SoftmaxLoss(cfg)
|
||||||
|
softmax_criterion_s3 = SoftmaxLoss(cfg)
|
||||||
|
softmax_criterion_s2.to(device)
|
||||||
|
softmax_criterion_s3.to(device)
|
||||||
|
|
||||||
|
if args.eval:
|
||||||
|
assert args.ckpt, "--ckpt must be specified when --eval enabled"
|
||||||
|
resume_from_ckpt(args.ckpt, model)
|
||||||
|
evaluate_performance(
|
||||||
|
model,
|
||||||
|
gallery_loader,
|
||||||
|
query_loader,
|
||||||
|
device,
|
||||||
|
use_gt=cfg.EVAL_USE_GT,
|
||||||
|
use_cache=cfg.EVAL_USE_CACHE,
|
||||||
|
use_cbgm=cfg.EVAL_USE_CBGM,
|
||||||
|
gallery_size=cfg.EVAL_GALLERY_SIZE,
|
||||||
|
)
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
params = [p for p in model.parameters() if p.requires_grad]
|
||||||
|
if cfg.MODEL.LOSS.USE_SOFTMAX:
|
||||||
|
params_softmax_s2 = [p for p in softmax_criterion_s2.parameters() if p.requires_grad]
|
||||||
|
params_softmax_s3 = [p for p in softmax_criterion_s3.parameters() if p.requires_grad]
|
||||||
|
params.extend(params_softmax_s2)
|
||||||
|
params.extend(params_softmax_s3)
|
||||||
|
|
||||||
|
optimizer = torch.optim.SGD(
|
||||||
|
params,
|
||||||
|
lr=cfg.SOLVER.BASE_LR,
|
||||||
|
momentum=cfg.SOLVER.SGD_MOMENTUM,
|
||||||
|
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||||
|
optimizer, milestones=cfg.SOLVER.LR_DECAY_MILESTONES, gamma=cfg.SOLVER.GAMMA
|
||||||
|
)
|
||||||
|
|
||||||
|
start_epoch = 0
|
||||||
|
if args.resume:
|
||||||
|
assert args.ckpt, "--ckpt must be specified when --resume enabled"
|
||||||
|
start_epoch = resume_from_ckpt(args.ckpt, model, optimizer, lr_scheduler) + 1
|
||||||
|
|
||||||
|
print("Creating output folder...")
|
||||||
|
output_dir = cfg.OUTPUT_DIR
|
||||||
|
mkdir(output_dir)
|
||||||
|
path = osp.join(output_dir, "config.yaml")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(cfg.dump())
|
||||||
|
print(f"Full config is saved to {path}")
|
||||||
|
tfboard = None
|
||||||
|
if cfg.TF_BOARD:
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
tf_log_path = osp.join(output_dir, "tf_log")
|
||||||
|
mkdir(tf_log_path)
|
||||||
|
tfboard = SummaryWriter(log_dir=tf_log_path)
|
||||||
|
print(f"TensorBoard files are saved to {tf_log_path}")
|
||||||
|
|
||||||
|
print("Start training...")
|
||||||
|
start_time = time.time()
|
||||||
|
for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS):
|
||||||
|
train_one_epoch(cfg, model, optimizer, train_loader, device, epoch, tfboard, softmax_criterion_s2, softmax_criterion_s3)
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
# only save the last three checkpoints
|
||||||
|
if epoch >= cfg.SOLVER.MAX_EPOCHS - 3:
|
||||||
|
save_on_master(
|
||||||
|
{
|
||||||
|
"model": model.state_dict(),
|
||||||
|
"optimizer": optimizer.state_dict(),
|
||||||
|
"lr_scheduler": lr_scheduler.state_dict(),
|
||||||
|
"epoch": epoch,
|
||||||
|
},
|
||||||
|
osp.join(output_dir, f"epoch_{epoch}.pth"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# evaluate the current checkpoint
|
||||||
|
evaluate_performance(
|
||||||
|
model,
|
||||||
|
gallery_loader,
|
||||||
|
query_loader,
|
||||||
|
device,
|
||||||
|
use_gt=cfg.EVAL_USE_GT,
|
||||||
|
use_cache=cfg.EVAL_USE_CACHE,
|
||||||
|
use_cbgm=cfg.EVAL_USE_CBGM,
|
||||||
|
gallery_size=cfg.EVAL_GALLERY_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tfboard:
|
||||||
|
tfboard.close()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print(f"Total training time {total_time_str}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Train a person search network.")
|
||||||
|
parser.add_argument("--cfg", dest="cfg_file", help="Path to configuration file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval", action="store_true", help="Evaluate the performance of a given checkpoint."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume", action="store_true", help="Resume from the specified checkpoint."
|
||||||
|
)
|
||||||
|
parser.add_argument("--ckpt", help="Path to checkpoint to resume or evaluate.")
|
||||||
|
parser.add_argument(
|
||||||
|
"opts", nargs=argparse.REMAINDER, help="Modify config options using the command-line"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
|
@ -0,0 +1,150 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
zero_threshold = 0.00000001
|
||||||
|
|
||||||
|
class KMNode(object):
|
||||||
|
def __init__(self, id, exception=0, match=None, visit=False):
|
||||||
|
self.id = id
|
||||||
|
self.exception = exception
|
||||||
|
self.match = match
|
||||||
|
self.visit = visit
|
||||||
|
|
||||||
|
|
||||||
|
class KuhnMunkres(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.matrix = None
|
||||||
|
self.x_nodes = []
|
||||||
|
self.y_nodes = []
|
||||||
|
self.minz = float("inf")
|
||||||
|
self.x_length = 0
|
||||||
|
self.y_length = 0
|
||||||
|
self.index_x = 0
|
||||||
|
self.index_y = 1
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_matrix(self, x_y_values):
|
||||||
|
xs = set()
|
||||||
|
ys = set()
|
||||||
|
for x, y, value in x_y_values:
|
||||||
|
xs.add(x)
|
||||||
|
ys.add(y)
|
||||||
|
|
||||||
|
if len(xs) < len(ys):
|
||||||
|
self.index_x = 0
|
||||||
|
self.index_y = 1
|
||||||
|
else:
|
||||||
|
self.index_x = 1
|
||||||
|
self.index_y = 0
|
||||||
|
xs, ys = ys, xs
|
||||||
|
|
||||||
|
x_dic = {x: i for i, x in enumerate(xs)}
|
||||||
|
y_dic = {y: j for j, y in enumerate(ys)}
|
||||||
|
self.x_nodes = [KMNode(x) for x in xs]
|
||||||
|
self.y_nodes = [KMNode(y) for y in ys]
|
||||||
|
self.x_length = len(xs)
|
||||||
|
self.y_length = len(ys)
|
||||||
|
|
||||||
|
self.matrix = np.zeros((self.x_length, self.y_length))
|
||||||
|
for row in x_y_values:
|
||||||
|
x = row[self.index_x]
|
||||||
|
y = row[self.index_y]
|
||||||
|
value = row[2]
|
||||||
|
x_index = x_dic[x]
|
||||||
|
y_index = y_dic[y]
|
||||||
|
self.matrix[x_index, y_index] = value
|
||||||
|
|
||||||
|
for i in range(self.x_length):
|
||||||
|
self.x_nodes[i].exception = max(self.matrix[i, :])
|
||||||
|
|
||||||
|
def km(self):
|
||||||
|
for i in range(self.x_length):
|
||||||
|
while True:
|
||||||
|
self.minz = float("inf")
|
||||||
|
self.set_false(self.x_nodes)
|
||||||
|
self.set_false(self.y_nodes)
|
||||||
|
|
||||||
|
if self.dfs(i):
|
||||||
|
break
|
||||||
|
|
||||||
|
self.change_exception(self.x_nodes, -self.minz)
|
||||||
|
self.change_exception(self.y_nodes, self.minz)
|
||||||
|
|
||||||
|
def dfs(self, i):
|
||||||
|
x_node = self.x_nodes[i]
|
||||||
|
x_node.visit = True
|
||||||
|
for j in range(self.y_length):
|
||||||
|
y_node = self.y_nodes[j]
|
||||||
|
if not y_node.visit:
|
||||||
|
t = x_node.exception + y_node.exception - self.matrix[i][j]
|
||||||
|
if abs(t) < zero_threshold:
|
||||||
|
y_node.visit = True
|
||||||
|
if y_node.match is None or self.dfs(y_node.match):
|
||||||
|
x_node.match = j
|
||||||
|
y_node.match = i
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if t >= zero_threshold:
|
||||||
|
self.minz = min(self.minz, t)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def set_false(self, nodes):
|
||||||
|
for node in nodes:
|
||||||
|
node.visit = False
|
||||||
|
|
||||||
|
def change_exception(self, nodes, change):
|
||||||
|
for node in nodes:
|
||||||
|
if node.visit:
|
||||||
|
node.exception += change
|
||||||
|
|
||||||
|
def get_connect_result(self):
|
||||||
|
ret = []
|
||||||
|
for i in range(self.x_length):
|
||||||
|
x_node = self.x_nodes[i]
|
||||||
|
j = x_node.match
|
||||||
|
y_node = self.y_nodes[j]
|
||||||
|
x_id = x_node.id
|
||||||
|
y_id = y_node.id
|
||||||
|
value = self.matrix[i][j]
|
||||||
|
|
||||||
|
if self.index_x == 1 and self.index_y == 0:
|
||||||
|
x_id, y_id = y_id, x_id
|
||||||
|
ret.append((x_id, y_id, value))
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def get_max_value_result(self):
|
||||||
|
ret = -100
|
||||||
|
for i in range(self.x_length):
|
||||||
|
j = self.x_nodes[i].match
|
||||||
|
ret = max(ret, self.matrix[i][j])
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def run_kuhn_munkres(x_y_values):
|
||||||
|
process = KuhnMunkres()
|
||||||
|
process.set_matrix(x_y_values)
|
||||||
|
process.km()
|
||||||
|
return process.get_connect_result(), process.get_max_value_result()
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
values = []
|
||||||
|
random.seed(0)
|
||||||
|
for i in range(500):
|
||||||
|
for j in range(1000):
|
||||||
|
value = random.random()
|
||||||
|
values.append((i, j, value))
|
||||||
|
|
||||||
|
return run_kuhn_munkres(values)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
values = [(1, 1, 3), (1, 3, 4), (2, 1, 2), (2, 2, 1), (2, 3, 3), (3, 2, 4), (3, 3, 5)]
|
||||||
|
print(run_kuhn_munkres(values))
|
|
@ -0,0 +1,325 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class exchange_token:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, features, mask_box):
|
||||||
|
b, hw, c = features.size()
|
||||||
|
assert hw == 14*14
|
||||||
|
new_idx, mask_x1, mask_x2, mask_y1, mask_y2 = mask_box
|
||||||
|
features = features.view(b, 14, 14, c)
|
||||||
|
features[:, mask_x1 : mask_x2, mask_y1 : mask_y2, :] = features[new_idx, mask_x1 : mask_x2, mask_y1 : mask_y2, :]
|
||||||
|
features = features.view(b, hw, c)
|
||||||
|
return features
|
||||||
|
|
||||||
|
class jigsaw_token:
|
||||||
|
def __init__(self, shift=5, group=2, begin=1):
|
||||||
|
self.shift = shift
|
||||||
|
self.group = group
|
||||||
|
self.begin = begin
|
||||||
|
|
||||||
|
def __call__(self, features):
|
||||||
|
batchsize = features.size(0)
|
||||||
|
dim = features.size(2)
|
||||||
|
|
||||||
|
num_tokens = features.size(1)
|
||||||
|
if num_tokens == 196:
|
||||||
|
self.group = 2
|
||||||
|
elif num_tokens == 25:
|
||||||
|
self.group = 5
|
||||||
|
else:
|
||||||
|
raise Exception("Jigsaw - Unwanted number of tokens")
|
||||||
|
|
||||||
|
# Shift Operation
|
||||||
|
feature_random = torch.cat([features[:, self.begin-1+self.shift:, :], features[:, self.begin-1:self.begin-1+self.shift, :]], dim=1)
|
||||||
|
x = feature_random
|
||||||
|
|
||||||
|
# Patch Shuffle Operation
|
||||||
|
try:
|
||||||
|
x = x.view(batchsize, self.group, -1, dim)
|
||||||
|
except:
|
||||||
|
raise Exception("Jigsaw - Unwanted number of groups")
|
||||||
|
|
||||||
|
x = torch.transpose(x, 1, 2).contiguous()
|
||||||
|
x = x.view(batchsize, -1, dim)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class get_mask_box:
|
||||||
|
def __init__(self, shape='stripe', mask_size=2, mode='random_direct'):
|
||||||
|
self.shape = shape
|
||||||
|
self.mask_size = mask_size
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def __call__(self, features):
|
||||||
|
# Stripe mask
|
||||||
|
if self.shape == 'stripe':
|
||||||
|
if self.mode == 'horizontal':
|
||||||
|
mask_box = self.hstripe(features, self.mask_size)
|
||||||
|
elif self.mode == 'vertical':
|
||||||
|
mask_box = self.vstripe(features, self.mask_size)
|
||||||
|
elif self.mode == 'random_direction':
|
||||||
|
if random.random() < 0.5:
|
||||||
|
mask_box = self.hstripe(features, self.mask_size)
|
||||||
|
else:
|
||||||
|
mask_box = self.vstripe(features, self.mask_size)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown stripe mask mode name")
|
||||||
|
# Square mask
|
||||||
|
elif self.shape == 'square':
|
||||||
|
if self.mode == 'random_size':
|
||||||
|
self.mask_size = 4 if random.random() < 0.5 else 5
|
||||||
|
mask_box = self.square(features, self.mask_size)
|
||||||
|
# Random stripe/square mask
|
||||||
|
elif self.shape == 'random':
|
||||||
|
random_num = random.random()
|
||||||
|
if random_num < 0.25:
|
||||||
|
mask_box = self.hstripe(features, 2)
|
||||||
|
elif random_num < 0.5 and random_num >= 0.25:
|
||||||
|
mask_box = self.vstripe(features, 2)
|
||||||
|
elif random_num < 0.75 and random_num >= 0.5:
|
||||||
|
mask_box = self.square(features, 4)
|
||||||
|
else:
|
||||||
|
mask_box = self.square(features, 5)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown mask shape name")
|
||||||
|
return mask_box
|
||||||
|
|
||||||
|
def hstripe(self, features, mask_size):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# horizontal stripe
|
||||||
|
mask_x1 = 0
|
||||||
|
mask_x2 = features.shape[2]
|
||||||
|
y1_max = features.shape[3] - mask_size
|
||||||
|
mask_y1 = torch.randint(y1_max, (1,))
|
||||||
|
mask_y2 = mask_y1 + mask_size
|
||||||
|
new_idx = torch.randperm(features.shape[0])
|
||||||
|
mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)
|
||||||
|
return mask_box
|
||||||
|
|
||||||
|
def vstripe(self, features, mask_size):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# vertical stripe
|
||||||
|
mask_y1 = 0
|
||||||
|
mask_y2 = features.shape[3]
|
||||||
|
x1_max = features.shape[2] - mask_size
|
||||||
|
mask_x1 = torch.randint(x1_max, (1,))
|
||||||
|
mask_x2 = mask_x1 + mask_size
|
||||||
|
new_idx = torch.randperm(features.shape[0])
|
||||||
|
mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)
|
||||||
|
return mask_box
|
||||||
|
|
||||||
|
def square(self, features, mask_size):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# square
|
||||||
|
x1_max = features.shape[2] - mask_size
|
||||||
|
y1_max = features.shape[3] - mask_size
|
||||||
|
mask_x1 = torch.randint(x1_max, (1,))
|
||||||
|
mask_y1 = torch.randint(y1_max, (1,))
|
||||||
|
mask_x2 = mask_x1 + mask_size
|
||||||
|
mask_y2 = mask_y1 + mask_size
|
||||||
|
new_idx = torch.randperm(features.shape[0])
|
||||||
|
mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)
|
||||||
|
return mask_box
|
||||||
|
|
||||||
|
|
||||||
|
class exchange_patch:
|
||||||
|
def __init__(self, shape='stripe', mask_size=2, mode='random_direct'):
|
||||||
|
self.shape = shape
|
||||||
|
self.mask_size = mask_size
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def __call__(self, features):
|
||||||
|
# Stripe mask
|
||||||
|
if self.shape == 'stripe':
|
||||||
|
if self.mode == 'horizontal':
|
||||||
|
features = self.xpatch_hstripe(features, self.mask_size)
|
||||||
|
elif self.mode == 'vertical':
|
||||||
|
features = self.xpatch_vstripe(features, self.mask_size)
|
||||||
|
elif self.mode == 'random_direction':
|
||||||
|
if random.random() < 0.5:
|
||||||
|
features = self.xpatch_hstripe(features, self.mask_size)
|
||||||
|
else:
|
||||||
|
features = self.xpatch_vstripe(features, self.mask_size)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown stripe mask mode name")
|
||||||
|
# Square mask
|
||||||
|
elif self.shape == 'square':
|
||||||
|
if self.mode == 'random_size':
|
||||||
|
self.mask_size = 4 if random.random() < 0.5 else 5
|
||||||
|
features = self.xpatch_square(features, self.mask_size)
|
||||||
|
# Random stripe/square mask
|
||||||
|
elif self.shape == 'random':
|
||||||
|
random_num = random.random()
|
||||||
|
if random_num < 0.25:
|
||||||
|
features = self.xpatch_hstripe(features, 2)
|
||||||
|
elif random_num < 0.5 and random_num >= 0.25:
|
||||||
|
features = self.xpatch_vstripe(features, 2)
|
||||||
|
elif random_num < 0.75 and random_num >= 0.5:
|
||||||
|
features = self.xpatch_square(features, 4)
|
||||||
|
else:
|
||||||
|
features = self.xpatch_square(features, 5)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown mask shape name")
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
def xpatch_hstripe(self, features, mask_size):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# horizontal stripe
|
||||||
|
y1_max = features.shape[3] - mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_y1 = torch.randint(y1_max, (1,))
|
||||||
|
mask_y2 = mask_y1 + mask_size
|
||||||
|
new_idx = torch.randperm(features.shape[0])
|
||||||
|
features[:, :, :, mask_y1 : mask_y2] = features[new_idx, :, :, mask_y1 : mask_y2]
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def xpatch_vstripe(self, features, mask_size):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# vertical stripe
|
||||||
|
x1_max = features.shape[2] - mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_x1 = torch.randint(x1_max, (1,))
|
||||||
|
mask_x2 = mask_x1 + mask_size
|
||||||
|
new_idx = torch.randperm(features.shape[0])
|
||||||
|
features[:, :, mask_x1 : mask_x2, :] = features[new_idx, :, mask_x1 : mask_x2, :]
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def xpatch_square(self, features, mask_size):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
# square
|
||||||
|
x1_max = features.shape[2] - mask_size
|
||||||
|
y1_max = features.shape[3] - mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_x1 = torch.randint(x1_max, (1,))
|
||||||
|
mask_y1 = torch.randint(y1_max, (1,))
|
||||||
|
mask_x2 = mask_x1 + mask_size
|
||||||
|
mask_y2 = mask_y1 + mask_size
|
||||||
|
new_idx = torch.randperm(features.shape[0])
|
||||||
|
features[:, :, mask_x1 : mask_x2, mask_y1 : mask_y2] = features[new_idx, :, mask_x1 : mask_x2, mask_y1 : mask_y2]
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class cutout_patch:
|
||||||
|
def __init__(self, mask_size=2):
|
||||||
|
self.mask_size = mask_size
|
||||||
|
|
||||||
|
def __call__(self, features):
|
||||||
|
if random.random() < 0.5:
|
||||||
|
y1_max = features.shape[3] - self.mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_y1 = torch.randint(y1_max, (features.shape[0],))
|
||||||
|
mask_y2 = mask_y1 + self.mask_size
|
||||||
|
for k in range(features.shape[0]):
|
||||||
|
features[k, :, :, mask_y1[k] : mask_y2[k]] = 0
|
||||||
|
else:
|
||||||
|
x1_max = features.shape[3] - self.mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_x1 = torch.randint(x1_max, (features.shape[0],))
|
||||||
|
mask_x2 = mask_x1 + self.mask_size
|
||||||
|
for k in range(features.shape[0]):
|
||||||
|
features[k, :, mask_x1[k] : mask_x2[k], :] = 0
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class erase_patch:
|
||||||
|
def __init__(self, mask_size=2):
|
||||||
|
self.mask_size = mask_size
|
||||||
|
|
||||||
|
def __call__(self, features):
|
||||||
|
std, mean = torch.std_mean(features.detach())
|
||||||
|
dim = features.shape[1]
|
||||||
|
if random.random() < 0.5:
|
||||||
|
y1_max = features.shape[3] - self.mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_y1 = torch.randint(y1_max, (features.shape[0],))
|
||||||
|
mask_y2 = mask_y1 + self.mask_size
|
||||||
|
for k in range(features.shape[0]):
|
||||||
|
features[k, :, :, mask_y1[k] : mask_y2[k]] = torch.normal(mean.repeat(dim,14,2), std.repeat(dim,14,2))
|
||||||
|
else:
|
||||||
|
x1_max = features.shape[3] - self.mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_x1 = torch.randint(x1_max, (features.shape[0],))
|
||||||
|
mask_x2 = mask_x1 + self.mask_size
|
||||||
|
for k in range(features.shape[0]):
|
||||||
|
features[k, :, mask_x1[k] : mask_x2[k], :] = torch.normal(mean.repeat(dim,2,14), std.repeat(dim,2,14))
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
class mixup_patch:
|
||||||
|
def __init__(self, mask_size=2):
|
||||||
|
self.mask_size = mask_size
|
||||||
|
|
||||||
|
def __call__(self, features):
|
||||||
|
lam = random.uniform(0, 1)
|
||||||
|
if random.random() < 0.5:
|
||||||
|
y1_max = features.shape[3] - self.mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_y1 = torch.randint(y1_max, (1,))
|
||||||
|
mask_y2 = mask_y1 + self.mask_size
|
||||||
|
new_idx = torch.randperm(features.shape[0])
|
||||||
|
features[:, :, :, mask_y1 : mask_y2] = lam*features[:, :, :, mask_y1 : mask_y2] + (1-lam)*features[new_idx, :, :, mask_y1 : mask_y2]
|
||||||
|
else:
|
||||||
|
x1_max = features.shape[2] - self.mask_size
|
||||||
|
num_masks = 1
|
||||||
|
for i in range(num_masks):
|
||||||
|
mask_x1 = torch.randint(x1_max, (1,))
|
||||||
|
mask_x2 = mask_x1 + self.mask_size
|
||||||
|
new_idx = torch.randperm(features.shape[0])
|
||||||
|
features[:, :, mask_x1 : mask_x2, :] = lam*features[:, :, mask_x1 : mask_x2, :] + (1-lam)*features[new_idx, :, mask_x1 : mask_x2, :]
|
||||||
|
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
class jigsaw_patch:
|
||||||
|
def __init__(self, shift=5, group=2):
|
||||||
|
self.shift = shift
|
||||||
|
self.group = group
|
||||||
|
|
||||||
|
def __call__(self, features):
|
||||||
|
batchsize = features.size(0)
|
||||||
|
dim = features.size(1)
|
||||||
|
features = features.view(batchsize, dim, -1)
|
||||||
|
|
||||||
|
# Shift Operation
|
||||||
|
feature_random = torch.cat([features[:, :, self.shift:], features[:, :, :self.shift]], dim=2)
|
||||||
|
x = feature_random
|
||||||
|
|
||||||
|
# Patch Shuffle Operation
|
||||||
|
try:
|
||||||
|
x = x.view(batchsize, dim, self.group, -1)
|
||||||
|
except:
|
||||||
|
x = torch.cat([x, x[:, -2:-1, :]], dim=1)
|
||||||
|
x = x.view(batchsize, self.group, -1, dim)
|
||||||
|
|
||||||
|
x = torch.transpose(x, 2, 3).contiguous()
|
||||||
|
|
||||||
|
x = x.view(batchsize, dim, -1)
|
||||||
|
x = x.view(batchsize, dim, 14, 14)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
|
from torchvision.transforms import functional as F
|
||||||
|
|
||||||
|
def mixup_data(images, alpha=0.8):
|
||||||
|
if alpha > 0. and alpha < 1.:
|
||||||
|
lam = random.uniform(alpha, 1)
|
||||||
|
else:
|
||||||
|
lam = 1.
|
||||||
|
|
||||||
|
batch_size = len(images)
|
||||||
|
min_x = 9999
|
||||||
|
min_y = 9999
|
||||||
|
for i in range(batch_size):
|
||||||
|
min_x = min(min_x, images[i].shape[1])
|
||||||
|
min_y = min(min_y, images[i].shape[2])
|
||||||
|
|
||||||
|
shuffle_images = deepcopy(images)
|
||||||
|
random.shuffle(shuffle_images)
|
||||||
|
mixed_images = deepcopy(images)
|
||||||
|
for i in range(batch_size):
|
||||||
|
mixed_images[i][:, :min_x, :min_y] = lam * images[i][:, :min_x, :min_y] + (1 - lam) * shuffle_images[i][:, :min_x, :min_y]
|
||||||
|
|
||||||
|
return mixed_images
|
||||||
|
|
||||||
|
class Compose:
|
||||||
|
def __init__(self, transforms):
|
||||||
|
self.transforms = transforms
|
||||||
|
|
||||||
|
def __call__(self, image, target):
|
||||||
|
for t in self.transforms:
|
||||||
|
image, target = t(image, target)
|
||||||
|
return image, target
|
||||||
|
|
||||||
|
|
||||||
|
class RandomHorizontalFlip:
|
||||||
|
def __init__(self, prob=0.5):
|
||||||
|
self.prob = prob
|
||||||
|
|
||||||
|
def __call__(self, image, target):
|
||||||
|
if random.random() < self.prob:
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
image = image.flip(-1)
|
||||||
|
bbox = target["boxes"]
|
||||||
|
bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
|
||||||
|
target["boxes"] = bbox
|
||||||
|
return image, target
|
||||||
|
|
||||||
|
class Cutout(object):
|
||||||
|
"""Randomly mask out one or more patches from an image.
|
||||||
|
https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
|
||||||
|
Args:
|
||||||
|
n_holes (int): Number of patches to cut out of each image.
|
||||||
|
length (int): The length (in pixels) of each square patch.
|
||||||
|
"""
|
||||||
|
def __init__(self, n_holes=2, length=100):
|
||||||
|
self.n_holes = n_holes
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
img (Tensor): Tensor image of size (C, H, W).
|
||||||
|
Returns:
|
||||||
|
Tensor: Image with n_holes of dimension length x length cut out of it.
|
||||||
|
"""
|
||||||
|
h = img.size(1)
|
||||||
|
w = img.size(2)
|
||||||
|
mask = np.ones((h, w), np.float32)
|
||||||
|
|
||||||
|
for n in range(self.n_holes):
|
||||||
|
y = np.random.randint(h)
|
||||||
|
x = np.random.randint(w)
|
||||||
|
y1 = np.clip(y - self.length // 2, 0, h)
|
||||||
|
y2 = np.clip(y + self.length // 2, 0, h)
|
||||||
|
x1 = np.clip(x - self.length // 2, 0, w)
|
||||||
|
x2 = np.clip(x + self.length // 2, 0, w)
|
||||||
|
mask[y1: y2, x1: x2] = 0.
|
||||||
|
|
||||||
|
mask = torch.from_numpy(mask)
|
||||||
|
mask = mask.expand_as(img)
|
||||||
|
img = img * mask
|
||||||
|
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
|
||||||
|
class RandomErasing(object):
|
||||||
|
'''
|
||||||
|
https://github.com/zhunzhong07/CamStyle/blob/master/reid/utils/data/transforms.py
|
||||||
|
'''
|
||||||
|
def __init__(self, EPSILON=0.5, mean=[0.485, 0.456, 0.406]):
|
||||||
|
self.EPSILON = EPSILON
|
||||||
|
self.mean = mean
|
||||||
|
|
||||||
|
def __call__(self, img, target):
|
||||||
|
if random.uniform(0, 1) > self.EPSILON:
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
for attempt in range(100):
|
||||||
|
area = img.size()[1] * img.size()[2]
|
||||||
|
|
||||||
|
target_area = random.uniform(0.02, 0.2) * area
|
||||||
|
aspect_ratio = random.uniform(0.3, 3)
|
||||||
|
|
||||||
|
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||||
|
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||||
|
|
||||||
|
if w <= img.size()[2] and h <= img.size()[1]:
|
||||||
|
x1 = random.randint(0, img.size()[1] - h)
|
||||||
|
y1 = random.randint(0, img.size()[2] - w)
|
||||||
|
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
|
||||||
|
img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
|
||||||
|
img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
|
||||||
|
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
|
||||||
|
class ToTensor:
|
||||||
|
def __call__(self, image, target):
|
||||||
|
# convert [0, 255] to [0, 1]
|
||||||
|
image = F.to_tensor(image)
|
||||||
|
return image, target
|
||||||
|
|
||||||
|
|
||||||
|
def build_transforms(cfg, is_train):
|
||||||
|
transforms = []
|
||||||
|
transforms.append(ToTensor())
|
||||||
|
if is_train:
|
||||||
|
transforms.append(RandomHorizontalFlip())
|
||||||
|
if cfg.INPUT.IMAGE_CUTOUT:
|
||||||
|
transforms.append(Cutout())
|
||||||
|
if cfg.INPUT.IMAGE_ERASE:
|
||||||
|
transforms.append(RandomErasing())
|
||||||
|
|
||||||
|
return Compose(transforms)
|
|
@ -0,0 +1,436 @@
|
||||||
|
# This file is part of COAT, and is distributed under the
|
||||||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||||||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import errno
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import pickle
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from tabulate import tabulate
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Logger #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
class SmoothedValue(object):
|
||||||
|
"""
|
||||||
|
Track a series of values and provide access to smoothed values over a
|
||||||
|
window or the global series average.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size=20, fmt=None):
|
||||||
|
if fmt is None:
|
||||||
|
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||||
|
self.deque = deque(maxlen=window_size)
|
||||||
|
self.total = 0.0
|
||||||
|
self.count = 0
|
||||||
|
self.fmt = fmt
|
||||||
|
|
||||||
|
def update(self, value, n=1):
|
||||||
|
self.deque.append(value)
|
||||||
|
self.count += n
|
||||||
|
self.total += value * n
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
"""
|
||||||
|
Warning: does not synchronize the deque!
|
||||||
|
"""
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return
|
||||||
|
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
|
||||||
|
dist.barrier()
|
||||||
|
dist.all_reduce(t)
|
||||||
|
t = t.tolist()
|
||||||
|
self.count = int(t[0])
|
||||||
|
self.total = t[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def median(self):
|
||||||
|
d = torch.tensor(list(self.deque))
|
||||||
|
return d.median().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg(self):
|
||||||
|
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||||
|
return d.mean().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_avg(self):
|
||||||
|
return self.total / self.count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self):
|
||||||
|
return max(self.deque)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
return self.deque[-1]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.fmt.format(
|
||||||
|
median=self.median,
|
||||||
|
avg=self.avg,
|
||||||
|
global_avg=self.global_avg,
|
||||||
|
max=self.max,
|
||||||
|
value=self.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger(object):
|
||||||
|
def __init__(self, delimiter="\t"):
|
||||||
|
self.meters = defaultdict(SmoothedValue)
|
||||||
|
self.delimiter = delimiter
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.item()
|
||||||
|
assert isinstance(v, (float, int))
|
||||||
|
self.meters[k].update(v)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr in self.meters:
|
||||||
|
return self.meters[attr]
|
||||||
|
if attr in self.__dict__:
|
||||||
|
return self.__dict__[attr]
|
||||||
|
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
loss_str = []
|
||||||
|
for name, meter in self.meters.items():
|
||||||
|
loss_str.append("{}: {}".format(name, str(meter)))
|
||||||
|
return self.delimiter.join(loss_str)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for meter in self.meters.values():
|
||||||
|
meter.synchronize_between_processes()
|
||||||
|
|
||||||
|
def add_meter(self, name, meter):
|
||||||
|
self.meters[name] = meter
|
||||||
|
|
||||||
|
def log_every(self, iterable, print_freq, header=None):
|
||||||
|
i = 0
|
||||||
|
if not header:
|
||||||
|
header = ""
|
||||||
|
start_time = time.time()
|
||||||
|
end = time.time()
|
||||||
|
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[
|
||||||
|
header,
|
||||||
|
"[{0" + space_fmt + "}/{1}]",
|
||||||
|
"eta: {eta}",
|
||||||
|
"{meters}",
|
||||||
|
"time: {time}",
|
||||||
|
"data: {data}",
|
||||||
|
"max mem: {memory:.0f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[
|
||||||
|
header,
|
||||||
|
"[{0" + space_fmt + "}/{1}]",
|
||||||
|
"eta: {eta}",
|
||||||
|
"{meters}",
|
||||||
|
"time: {time}",
|
||||||
|
"data: {data}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
MB = 1024.0 * 1024.0
|
||||||
|
for obj in iterable:
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
yield obj
|
||||||
|
iter_time.update(time.time() - end)
|
||||||
|
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||||
|
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||||
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i,
|
||||||
|
len(iterable),
|
||||||
|
eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time),
|
||||||
|
data=str(data_time),
|
||||||
|
memory=torch.cuda.max_memory_allocated() / MB,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i,
|
||||||
|
len(iterable),
|
||||||
|
eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time),
|
||||||
|
data=str(data_time),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
i += 1
|
||||||
|
end = time.time()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print(
|
||||||
|
"{} Total time: {} ({:.4f} s / it)".format(
|
||||||
|
header, total_time_str, total_time / len(iterable)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Distributed training #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
def all_gather(data):
|
||||||
|
"""
|
||||||
|
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: any picklable object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[data]: list of data gathered from each rank
|
||||||
|
"""
|
||||||
|
world_size = get_world_size()
|
||||||
|
if world_size == 1:
|
||||||
|
return [data]
|
||||||
|
|
||||||
|
# serialized to a Tensor
|
||||||
|
buffer = pickle.dumps(data)
|
||||||
|
storage = torch.ByteStorage.from_buffer(buffer)
|
||||||
|
tensor = torch.ByteTensor(storage).to("cuda")
|
||||||
|
|
||||||
|
# obtain Tensor size of each rank
|
||||||
|
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||||
|
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||||
|
dist.all_gather(size_list, local_size)
|
||||||
|
size_list = [int(size.item()) for size in size_list]
|
||||||
|
max_size = max(size_list)
|
||||||
|
|
||||||
|
# receiving Tensor from all ranks
|
||||||
|
# we pad the tensor because torch all_gather does not support
|
||||||
|
# gathering tensors of different shapes
|
||||||
|
tensor_list = []
|
||||||
|
for _ in size_list:
|
||||||
|
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||||
|
if local_size != max_size:
|
||||||
|
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||||
|
tensor = torch.cat((tensor, padding), dim=0)
|
||||||
|
dist.all_gather(tensor_list, tensor)
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
for size, tensor in zip(size_list, tensor_list):
|
||||||
|
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||||
|
data_list.append(pickle.loads(buffer))
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_dict(input_dict, average=True):
|
||||||
|
"""
|
||||||
|
Reduce the values in the dictionary from all processes so that all processes
|
||||||
|
have the averaged results. Returns a dict with the same fields as
|
||||||
|
input_dict, after reduction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dict (dict): all the values will be reduced
|
||||||
|
average (bool): whether to do average or sum
|
||||||
|
"""
|
||||||
|
world_size = get_world_size()
|
||||||
|
if world_size < 2:
|
||||||
|
return input_dict
|
||||||
|
with torch.no_grad():
|
||||||
|
names = []
|
||||||
|
values = []
|
||||||
|
# sort the keys so that they are consistent across processes
|
||||||
|
for k in sorted(input_dict.keys()):
|
||||||
|
names.append(k)
|
||||||
|
values.append(input_dict[k])
|
||||||
|
values = torch.stack(values, dim=0)
|
||||||
|
dist.all_reduce(values)
|
||||||
|
if average:
|
||||||
|
values /= world_size
|
||||||
|
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||||
|
return reduced_dict
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_distributed(is_master):
|
||||||
|
"""
|
||||||
|
This function disables printing when not in master process
|
||||||
|
"""
|
||||||
|
import builtins as __builtin__
|
||||||
|
|
||||||
|
builtin_print = __builtin__.print
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
force = kwargs.pop("force", False)
|
||||||
|
if is_master or force:
|
||||||
|
builtin_print(*args, **kwargs)
|
||||||
|
|
||||||
|
__builtin__.print = print
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def save_on_master(*args, **kwargs):
|
||||||
|
if is_main_process():
|
||||||
|
torch.save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode(args):
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
args.rank = int(os.environ["RANK"])
|
||||||
|
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||||
|
elif "SLURM_PROCID" in os.environ:
|
||||||
|
args.rank = int(os.environ["SLURM_PROCID"])
|
||||||
|
args.gpu = args.rank % torch.cuda.device_count()
|
||||||
|
else:
|
||||||
|
print("Not using distributed mode")
|
||||||
|
args.distributed = False
|
||||||
|
return
|
||||||
|
|
||||||
|
args.distributed = True
|
||||||
|
|
||||||
|
torch.cuda.set_device(args.gpu)
|
||||||
|
args.dist_backend = "nccl"
|
||||||
|
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend=args.dist_backend,
|
||||||
|
init_method=args.dist_url,
|
||||||
|
world_size=args.world_size,
|
||||||
|
rank=args.rank,
|
||||||
|
)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
setup_for_distributed(args.rank == 0)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# File operation #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
def filename(path):
|
||||||
|
return osp.splitext(osp.basename(path))[0]
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir(path):
|
||||||
|
try:
|
||||||
|
os.makedirs(path)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != errno.EEXIST:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def read_json(fpath):
|
||||||
|
with open(fpath, "r") as f:
|
||||||
|
obj = json.load(f)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def write_json(obj, fpath):
|
||||||
|
mkdir(osp.dirname(fpath))
|
||||||
|
_obj = obj.copy()
|
||||||
|
for k, v in _obj.items():
|
||||||
|
if isinstance(v, np.ndarray):
|
||||||
|
_obj.pop(k)
|
||||||
|
with open(fpath, "w") as f:
|
||||||
|
json.dump(_obj, f, indent=4, separators=(",", ": "))
|
||||||
|
|
||||||
|
|
||||||
|
def symlink(src, dst, overwrite=True, **kwargs):
|
||||||
|
if os.path.lexists(dst) and overwrite:
|
||||||
|
os.remove(dst)
|
||||||
|
os.symlink(src, dst, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
# Misc #
|
||||||
|
# -------------------------------------------------------- #
|
||||||
|
def create_small_table(small_dict):
|
||||||
|
"""
|
||||||
|
Create a small table using the keys of small_dict as headers. This is only
|
||||||
|
suitable for small dictionaries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
small_dict (dict): a result dictionary of only a few items.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: the table as a string.
|
||||||
|
"""
|
||||||
|
keys, values = tuple(zip(*small_dict.items()))
|
||||||
|
table = tabulate(
|
||||||
|
[values],
|
||||||
|
headers=keys,
|
||||||
|
tablefmt="pipe",
|
||||||
|
floatfmt=".3f",
|
||||||
|
stralign="center",
|
||||||
|
numalign="center",
|
||||||
|
)
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
|
||||||
|
def f(x):
|
||||||
|
if x >= warmup_iters:
|
||||||
|
return 1
|
||||||
|
alpha = float(x) / warmup_iters
|
||||||
|
return warmup_factor * (1 - alpha) + alpha
|
||||||
|
|
||||||
|
return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
|
||||||
|
|
||||||
|
|
||||||
|
def resume_from_ckpt(ckpt_path, model, optimizer=None, lr_scheduler=None):
|
||||||
|
ckpt = torch.load(ckpt_path)
|
||||||
|
model.load_state_dict(ckpt["model"], strict=False)
|
||||||
|
if optimizer is not None:
|
||||||
|
optimizer.load_state_dict(ckpt["optimizer"])
|
||||||
|
if lr_scheduler is not None:
|
||||||
|
lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
|
||||||
|
print(f"loaded checkpoint {ckpt_path}")
|
||||||
|
print(f"model was trained for {ckpt['epoch']} epochs")
|
||||||
|
return ckpt["epoch"]
|
||||||
|
|
||||||
|
|
||||||
|
def set_random_seed(seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
os.environ["PYTHONHASHSEED"] = str(seed)
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue