更新pvtv2-b2的分支

master
詹力 2023-10-10 21:52:30 +08:00
parent bbe2e5694b
commit b44c1bfb16
27 changed files with 29070 additions and 0 deletions

3
coat-pvtv2-b2/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
**/__pycache__
*.pth
**/logs

View File

@ -0,0 +1,26 @@
name: coat
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- cudatoolkit=11.0
- numpy=1.19.2
- pillow=8.2.0
- pip=21.0.1
- python=3.8.8
- pytorch=1.7.1
- scipy=1.6.2
- torchvision=0.8.2
- tqdm=4.60.0
- scikit-learn=0.24.1
- black=21.5b0
- flake8=3.9.0
- isort=5.8.0
- tabulate=0.8.9
- future=0.18.2
- tensorboard=2.4.1
- tensorboardx=2.2
- pip:
- ipython==7.5.0
- yacs==0.1.8

150
coat-pvtv2-b2/README.md Normal file
View File

@ -0,0 +1,150 @@
# **COAT代码使用说明**
这个存储库托管了论文的源代码:[[CVPR 2022] Cascade Transformers for End-to-End Person Search](https://arxiv.org/abs/2203.09642)。在这项工作中我们开发了一种新颖的级联遮挡感知TransformerCOAT模型用于端到端的人物搜索。COAT模型在PRW基准数据集上以显著的优势胜过了最先进的方法并在CUHK-SYSU数据集上取得了最先进的性能。
| 数据集(Datasets) | mAP | Top-1 | Model |
| ---------------- | ---- | ----- | ------------------------------------------------------------ |
| CUHK-SYSU | 94.2 | 94.7 | [model](https://drive.google.com/file/d/1LkEwXYaJg93yk4Kfhyk3m6j8v3i9s1B7/view?usp=sharing) |
| PRW | 53.3 | 87.4 | [model](https://drive.google.com/file/d/1vEd_zzFN88RgxbRMG5-WfJZgD3vmP0Xg/view?usp=sharing) |
**Abstract**: The goal of person search is to localize a target person from a gallery set of scene images, which is extremely challenging due to large scale variations, pose/viewpoint changes, and occlusions. In this paper, we propose the Cascade Occluded Attention Transformer (COAT) for end-to-end person search. Specifically, our three-stage cascade design focuses on detecting people at the first stage, then progressively refines the representation for person detection and re-identification simultaneously at the following stages. The occluded attention transformer at each stage applies tighter intersection over union thresholds, forcing the network to learn coarse-to-fine pose/scale invariant features. Meanwhile, we calculate the occluded attention across instances in a mini-batch to differentiate tokens from other people or the background. In this way, we simulate the effect of other objects occluding a person of interest at the token-level. Through comprehensive experiments, we demonstrate the benefits of our method by achieving state-of-the-art performance on two benchmark datasets.
![COAT](doc/framework.png)
## Installation
1. Download the datasets in your path `$DATA_DIR`. Change the dataset paths in L4 in [cuhk_sysu.yaml](configs/cuhk_sysu.yaml) and [prw.yaml](configs/prw.yaml).
**PRW**:
```
cd $DATA_DIR
pip install gdown
gdown https://drive.google.com/uc?id=0B6tjyrV1YrHeYnlhNnhEYTh5MUU
unzip PRW-v16.04.20.zip
mv PRW-v16.04.20 PRW
```
**CUHK-SYSU**:
```
cd $DATA_DIR
gdown https://drive.google.com/uc?id=1z3LsFrJTUeEX3-XjSEJMOBrslxD2T5af
tar -xzvf cuhk_sysu.tar.gz
mv cuhk_sysu CUHK-SYSU
```
2. Our method is tested with PyTorch 1.7.1. You can install the required packages by anaconda/miniconda with the following commands:
```
cd COAT
conda env create -f COAT_pt171.yml
conda activate coat
```
If you want to install another version of PyTorch, you can modify the versions in `coat_pt171.yml`. Just make sure the dependencies have the appropriate version.
## CUHK-SYSU数据集实验
**训练**: 目前代码只支持单GPU. The default training script for CUHK-SYSU is as follows:
**在本地GTX4090训练**
``` bash
cd COAT
# 说明4090显存较小所以batchsize只能设置为2, 实测可以运行
python train.py --cfg configs/cuhk_sysu-local.yaml INPUT.BATCH_SIZE_TRAIN 2 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 14 SOLVER.LR_DECAY_MILESTONES [11] MODEL.LOSS.USE_SOFTMAX True SOLVER.LW_RCNN_SOFTMAX_2ND 0.1 SOLVER.LW_RCNN_SOFTMAX_3RD 0.1 OUTPUT_DIR ./logs/cuhk-sysu
```
**在本地UESTC训练**
```bash
cd COAT
# 说明RTX8000显存48G所以batchsize只能设置为3
python train.py --cfg configs/cuhk_sysu.yaml INPUT.BATCH_SIZE_TRAIN 2 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 14 SOLVER.LR_DECAY_MILESTONES [11] MODEL.LOSS.USE_SOFTMAX True SOLVER.LW_RCNN_SOFTMAX_2ND 0.1 SOLVER.LW_RCNN_SOFTMAX_3RD 0.1 OUTPUT_DIR ./logs/cuhk-sysu
```
Note that the dataset-specific parameters are defined in `configs/cuhk_sysu.yaml`. When the batch size (`INPUT.BATCH_SIZE_TRAIN`) is 3, the training will take about 23GB GPU memory, being suitable for GPUs like RTX6000. When the batch size is 5, the training will take about 38GB GPU memory, being able to run on A100 GPU. The larger batch size usually results in better performance on CUHK-SYSU.
For the CUHK-SYSU dataset, we use a relative low weight for softmax loss (`SOLVER.LW_RCNN_SOFTMAX_2ND` 0.1 and `SOLVER.LW_RCNN_SOFTMAX_3RD` 0.1). The trained models and TF logs will be saved in the folder `OUTPUT_DIR`. Other important training parameters can be found in the file `COAT/defaults.py`. For example, `CKPT_PERIOD` is the frequency of saving a checkpoint model.
**Testing**: The test script is very simple. You just need to add the flag `--eval` and provide the folder `--ckpt` where the [model](https://drive.google.com/file/d/1LkEwXYaJg93yk4Kfhyk3m6j8v3i9s1B7/view?usp=sharing) was saved.
测试这个测试脚本非常简单你只需要添加flag --eval以及对应提供--ckpt当模型已经保存的时候
```
python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth
```
**Testing with CBGM**: Context Bipartite Graph Matching ([CBGM](https://github.com/serend1p1ty/SeqNet)) is an optimized matching algorithm in test phase. The detail can be found in the paper [[AAAI 2021] Sequential End-to-end Network for Efficient Person Search](https://arxiv.org/abs/2103.10148). We can use CBGM to further improve the person search accuracy. In test script, we just set the flag `EVAL_USE_CBGM` to True (default is False).
```
python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth EVAL_USE_CB GM True
```
**Testing with different gallery sizes on CUHK-SYSU**: The default gallery size for evaluating CUHK-SYSU is 100. If you want to test with other pre-defined gallery sizes (50, 100, 500, 1000, 2000, 4000) for drawing the CUHK-SYSU gallery size curve, please set the parameter `EVAL_GALLERY_SIZE` with a gallery size.
```
python train.py --cfg ./configs/cuhk-sysu/config.yaml --eval --ckpt ./logs/cuhk-sysu/cuhk_COAT.pth EVAL_GALLER Y_SIZE 500
```
## Experiments on PRW
**Training**: The script is similar to CUHK-SYSU. The code currently only supports single GPU. The default training script for PRW is as follows:
**在本地GTX4090训练**
```bash
cd COAT
# PRW数据集较小可以RTX4090的bs可以设置为3
python train.py --cfg ./configs/prw-local.yaml INPUT.BATCH_SIZE_TRAIN 3 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 13 MODEL.LOSS.USE_SOFTMAX True OUTPUT_DIR ./logs/prw
```
**在本地UESTC训练**
```bash
cd COAT
# PRW数据集较小可以RTX4090的bs可以设置为3
python train.py --cfg ./configs/prw.yaml INPUT.BATCH_SIZE_TRAIN 3 SOLVER.BASE_LR 0.003 SOLVER.MAX_EPOCHS 13 MODEL.LOSS.USE_SOFTMAX True OUTPUT_DIR ./logs/prw
```
The dataset-specific parameters are defined in `configs/prw.yaml`. When the batch size (`INPUT.BATCH_SIZE_TRAIN`) is 3, the training will take about 19GB GPU memory, being suitable for GPUs like RTX6000. The larger batch size does not necessarily result in better accuracy on the PRW dataset.
Softmax loss is effective on PRW. The default weights of softmax loss at Stage 2 and Stage 3 (`SOLVER.LW_RCNN_SOFTMAX_2ND` and `SOLVER.LW_RCNN_SOFTMAX_3RD`) are 0.5, which can be found in the file `COAT/defaults.py`. If you want to run a model without Softmax loss for comparison, just set `MODEL.LOSS.USE_SOFTMAX` to False in the script.
**Testing**: The test script is similar to CUHK-SYSU. Make sure the path of pre-trained model [model](https://drive.google.com/file/d/1vEd_zzFN88RgxbRMG5-WfJZgD3vmP0Xg/view?usp=sharing) is correct.
```
python train.py --cfg ./logs/prw/config.yaml --eval --ckpt ./logs/prw/prw_COAT.pth
```
**Testing with CBGM**: Similar to CUHK-SYSU, set the flag `EVAL_USE_CBGM` to True (default is False).
```
python train.py --cfg ./logs/prw/config.yaml --eval --ckpt ./logs/prw/prw_COAT.pth EVAL_USE_CBGM True
```
## Acknowledgement
This code borrows from [SeqNet](https://github.com/serend1p1ty/SeqNet), [TransReID](https://github.com/damo-cv/TransReID), and [DSTT](https://github.com/ruiliu-ai/DSTT).
## Citation
If you use this code in your research, please cite this project as follows:
```
@inproceedings{yu2022coat,
title = {Cascade Transformers for End-to-End Person Search},
author = {Rui Yu and
Dawei Du and
Rodney LaLonde and
Daniel Davila and
Christopher Funk and
Anthony Hoogs and
Brian Clipp},
booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition},
year = {2022}
}
```
## License
This work is distributed under the OSI-approved BSD 3-Clause [License](https://github.com/Kitware/COAT/blob/master/LICENSE).

View File

@ -0,0 +1,15 @@
OUTPUT_DIR: "./logs/cuhk_coat"
INPUT:
DATASET: "CUHK-SYSU"
DATA_ROOT: "E:/DeepLearning/PersonSearch/COAT/datasets/CUHK-SYSU"
BATCH_SIZE_TRAIN: 3
SOLVER:
MAX_EPOCHS: 14
BASE_LR: 0.003
LW_RCNN_SOFTMAX_2ND: 0.1
LW_RCNN_SOFTMAX_3RD: 0.1
MODEL:
LOSS:
LUT_SIZE: 5532
CQ_SIZE: 5000
DISP_PERIOD: 100

View File

@ -0,0 +1,15 @@
OUTPUT_DIR: "./logs/cuhk_coat"
INPUT:
DATASET: "CUHK-SYSU"
DATA_ROOT: "/home/logzhan/datasets/CUHK-SYSU"
BATCH_SIZE_TRAIN: 4
SOLVER:
MAX_EPOCHS: 14
BASE_LR: 0.003
LW_RCNN_SOFTMAX_2ND: 0.1
LW_RCNN_SOFTMAX_3RD: 0.1
MODEL:
LOSS:
LUT_SIZE: 5532
CQ_SIZE: 5000
DISP_PERIOD: 100

View File

@ -0,0 +1,13 @@
OUTPUT_DIR: "./logs/prw_coat"
INPUT:
DATASET: "PRW"
DATA_ROOT: "E:/DeepLearning/PersonSearch/COAT/datasets/PRW"
BATCH_SIZE_TRAIN: 3
SOLVER:
MAX_EPOCHS: 13
BASE_LR: 0.003
MODEL:
LOSS:
LUT_SIZE: 482
CQ_SIZE: 500
DISP_PERIOD: 100

View File

@ -0,0 +1,13 @@
OUTPUT_DIR: "./logs/prw_coat"
INPUT:
DATASET: "PRW"
DATA_ROOT: "../../datasets/PRW"
BATCH_SIZE_TRAIN: 3
SOLVER:
MAX_EPOCHS: 13
BASE_LR: 0.003
MODEL:
LOSS:
LUT_SIZE: 482
CQ_SIZE: 500
DISP_PERIOD: 100

View File

@ -0,0 +1,5 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from .build import build_test_loader, build_train_loader

View File

@ -0,0 +1,42 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import torch
from PIL import Image
class BaseDataset:
"""
Base class of person search dataset.
"""
def __init__(self, root, transforms, split):
self.root = root
self.transforms = transforms
self.split = split
assert self.split in ("train", "gallery", "query")
self.annotations = self._load_annotations()
def _load_annotations(self):
"""
For each image, load its annotation that is a dictionary with the following keys:
img_name (str): image name
img_path (str): image path
boxes (np.array[N, 4]): ground-truth boxes in (x1, y1, x2, y2) format
pids (np.array[N]): person IDs corresponding to these boxes
cam_id (int): camera ID (only for PRW dataset)
"""
raise NotImplementedError
def __getitem__(self, index):
anno = self.annotations[index]
img = Image.open(anno["img_path"]).convert("RGB")
boxes = torch.as_tensor(anno["boxes"], dtype=torch.float32)
labels = torch.as_tensor(anno["pids"], dtype=torch.int64)
target = {"img_name": anno["img_name"], "boxes": boxes, "labels": labels}
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
def __len__(self):
return len(self.annotations)

View File

@ -0,0 +1,104 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import torch
from utils.transforms import build_transforms
from utils.utils import create_small_table
from .cuhk_sysu import CUHKSYSU
from .prw import PRW
def print_statistics(dataset):
"""
Print dataset statistics.
"""
num_imgs = len(dataset.annotations)
num_boxes = 0
pid_set = set()
for anno in dataset.annotations:
num_boxes += anno["boxes"].shape[0]
for pid in anno["pids"]:
pid_set.add(pid)
statistics = {
"dataset": dataset.name,
"split": dataset.split,
"num_images": num_imgs,
"num_boxes": num_boxes,
}
if dataset.name != "CUHK-SYSU" or dataset.split != "query":
pid_list = sorted(list(pid_set))
if dataset.split == "query":
num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list)
statistics.update(
{
"num_labeled_pids": num_pids,
"min_labeled_pid": int(min_pid),
"max_labeled_pid": int(max_pid),
}
)
else:
unlabeled_pid = pid_list[-1]
pid_list = pid_list[:-1] # remove unlabeled pid
num_pids, min_pid, max_pid = len(pid_list), min(pid_list), max(pid_list)
statistics.update(
{
"num_labeled_pids": num_pids,
"min_labeled_pid": int(min_pid),
"max_labeled_pid": int(max_pid),
"unlabeled_pid": int(unlabeled_pid),
}
)
print(f"=> {dataset.name}-{dataset.split} loaded:\n" + create_small_table(statistics))
def build_dataset(dataset_name, root, transforms, split, verbose=True):
if dataset_name == "CUHK-SYSU":
dataset = CUHKSYSU(root, transforms, split)
elif dataset_name == "PRW":
dataset = PRW(root, transforms, split)
else:
raise NotImplementedError(f"Unknow dataset: {dataset_name}")
if verbose:
print_statistics(dataset)
return dataset
def collate_fn(batch):
return tuple(zip(*batch))
def build_train_loader(cfg):
transforms = build_transforms(cfg, is_train=True)
dataset = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "train")
return torch.utils.data.DataLoader(
dataset,
batch_size=cfg.INPUT.BATCH_SIZE_TRAIN,
shuffle=True,
num_workers=cfg.INPUT.NUM_WORKERS_TRAIN,
pin_memory=True,
drop_last=True,
collate_fn=collate_fn,
)
def build_test_loader(cfg):
transforms = build_transforms(cfg, is_train=False)
gallery_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "gallery")
query_set = build_dataset(cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, transforms, "query")
gallery_loader = torch.utils.data.DataLoader(
gallery_set,
batch_size=cfg.INPUT.BATCH_SIZE_TEST,
shuffle=False,
num_workers=cfg.INPUT.NUM_WORKERS_TEST,
pin_memory=True,
collate_fn=collate_fn,
)
query_loader = torch.utils.data.DataLoader(
query_set,
batch_size=cfg.INPUT.BATCH_SIZE_TEST,
shuffle=False,
num_workers=cfg.INPUT.NUM_WORKERS_TEST,
pin_memory=True,
collate_fn=collate_fn,
)
return gallery_loader, query_loader

View File

@ -0,0 +1,121 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import os.path as osp
import numpy as np
from scipy.io import loadmat
from .base import BaseDataset
class CUHKSYSU(BaseDataset):
def __init__(self, root, transforms, split):
self.name = "CUHK-SYSU"
self.img_prefix = osp.join(root, "Image", "SSM")
super(CUHKSYSU, self).__init__(root, transforms, split)
def _load_queries(self):
# TestG50: a test protocol, 50 gallery images per query
protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat"))
protoc = protoc["TestG50"].squeeze()
queries = []
for item in protoc["Query"]:
img_name = str(item["imname"][0, 0][0])
roi = item["idlocate"][0, 0][0].astype(np.int32)
roi[2:] += roi[:2]
queries.append(
{
"img_name": img_name,
"img_path": osp.join(self.img_prefix, img_name),
"boxes": roi[np.newaxis, :],
"pids": np.array([-100]), # dummy pid
}
)
return queries
def _load_split_img_names(self):
"""
Load the image names for the specific split.
"""
assert self.split in ("train", "gallery")
# gallery images
gallery_imgs = loadmat(osp.join(self.root, "annotation", "pool.mat"))
gallery_imgs = gallery_imgs["pool"].squeeze()
gallery_imgs = [str(a[0]) for a in gallery_imgs]
if self.split == "gallery":
return gallery_imgs
# all images
all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat"))
all_imgs = all_imgs["Img"].squeeze()
all_imgs = [str(a[0][0]) for a in all_imgs]
# training images = all images - gallery images
training_imgs = sorted(list(set(all_imgs) - set(gallery_imgs)))
return training_imgs
def _load_annotations(self):
if self.split == "query":
return self._load_queries()
# load all images and build a dict from image to boxes
all_imgs = loadmat(osp.join(self.root, "annotation", "Images.mat"))
all_imgs = all_imgs["Img"].squeeze()
name_to_boxes = {}
name_to_pids = {}
unlabeled_pid = 5555 # default pid for unlabeled people
for img_name, _, boxes in all_imgs:
img_name = str(img_name[0])
boxes = np.asarray([b[0] for b in boxes[0]])
boxes = boxes.reshape(boxes.shape[0], 4) # (x1, y1, w, h)
valid_index = np.where((boxes[:, 2] > 0) & (boxes[:, 3] > 0))[0]
assert valid_index.size > 0, "Warning: {} has no valid boxes.".format(img_name)
boxes = boxes[valid_index]
name_to_boxes[img_name] = boxes.astype(np.int32)
name_to_pids[img_name] = unlabeled_pid * np.ones(boxes.shape[0], dtype=np.int32)
def set_box_pid(boxes, box, pids, pid):
for i in range(boxes.shape[0]):
if np.all(boxes[i] == box):
pids[i] = pid
return
# assign a unique pid from 1 to N for each identity
if self.split == "train":
train = loadmat(osp.join(self.root, "annotation/test/train_test/Train.mat"))
train = train["Train"].squeeze()
for index, item in enumerate(train):
scenes = item[0, 0][2].squeeze()
for img_name, box, _ in scenes:
img_name = str(img_name[0])
box = box.squeeze().astype(np.int32)
set_box_pid(name_to_boxes[img_name], box, name_to_pids[img_name], index + 1)
else:
protoc = loadmat(osp.join(self.root, "annotation/test/train_test/TestG50.mat"))
protoc = protoc["TestG50"].squeeze()
for index, item in enumerate(protoc):
# query
im_name = str(item["Query"][0, 0][0][0])
box = item["Query"][0, 0][1].squeeze().astype(np.int32)
set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1)
# gallery
gallery = item["Gallery"].squeeze()
for im_name, box, _ in gallery:
im_name = str(im_name[0])
if box.size == 0:
break
box = box.squeeze().astype(np.int32)
set_box_pid(name_to_boxes[im_name], box, name_to_pids[im_name], index + 1)
annotations = []
imgs = self._load_split_img_names()
for img_name in imgs:
boxes = name_to_boxes[img_name]
boxes[:, 2:] += boxes[:, :2] # (x1, y1, w, h) -> (x1, y1, x2, y2)
pids = name_to_pids[img_name]
annotations.append(
{
"img_name": img_name,
"img_path": osp.join(self.img_prefix, img_name),
"boxes": boxes,
"pids": pids,
}
)
return annotations

View File

@ -0,0 +1,97 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import os.path as osp
import re
import numpy as np
from scipy.io import loadmat
from .base import BaseDataset
class PRW(BaseDataset):
def __init__(self, root, transforms, split):
self.name = "PRW"
self.img_prefix = osp.join(root, "frames")
super(PRW, self).__init__(root, transforms, split)
def _get_cam_id(self, img_name):
match = re.search(r"c\d", img_name).group().replace("c", "")
return int(match)
def _load_queries(self):
query_info = osp.join(self.root, "query_info.txt")
with open(query_info, "rb") as f:
raw = f.readlines()
queries = []
for line in raw:
linelist = str(line, "utf-8").split(" ")
pid = int(linelist[0])
x, y, w, h = (
float(linelist[1]),
float(linelist[2]),
float(linelist[3]),
float(linelist[4]),
)
roi = np.array([x, y, x + w, y + h]).astype(np.int32)
roi = np.clip(roi, 0, None) # several coordinates are negative
img_name = linelist[5][:-2] + ".jpg"
queries.append(
{
"img_name": img_name,
"img_path": osp.join(self.img_prefix, img_name),
"boxes": roi[np.newaxis, :],
"pids": np.array([pid]),
"cam_id": self._get_cam_id(img_name),
}
)
return queries
def _load_split_img_names(self):
"""
Load the image names for the specific split.
"""
assert self.split in ("train", "gallery")
if self.split == "train":
imgs = loadmat(osp.join(self.root, "frame_train.mat"))["img_index_train"]
else:
imgs = loadmat(osp.join(self.root, "frame_test.mat"))["img_index_test"]
return [img[0][0] + ".jpg" for img in imgs]
def _load_annotations(self):
if self.split == "query":
return self._load_queries()
annotations = []
imgs = self._load_split_img_names()
for img_name in imgs:
anno_path = osp.join(self.root, "annotations", img_name)
anno = loadmat(anno_path)
box_key = "box_new"
if box_key not in anno.keys():
box_key = "anno_file"
if box_key not in anno.keys():
box_key = "anno_previous"
rois = anno[box_key][:, 1:]
ids = anno[box_key][:, 0]
rois = np.clip(rois, 0, None) # several coordinates are negative
assert len(rois) == len(ids)
rois[:, 2:] += rois[:, :2]
ids[ids == -2] = 5555 # assign pid = 5555 for unlabeled people
annotations.append(
{
"img_name": img_name,
"img_path": osp.join(self.img_prefix, img_name),
"boxes": rois.astype(np.int32),
# (training pids) 1, 2,..., 478, 480, 481, 482, 483, 932, 5555
"pids": ids.astype(np.int32),
"cam_id": self._get_cam_id(img_name),
}
)
return annotations

219
coat-pvtv2-b2/defaults.py Normal file
View File

@ -0,0 +1,219 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from yacs.config import CfgNode as CN
_C = CN()
# -------------------------------------------------------- #
# Input #
# -------------------------------------------------------- #
_C.INPUT = CN()
_C.INPUT.DATASET = "CUHK-SYSU"
_C.INPUT.DATA_ROOT = "data/CUHK-SYSU"
# Size of the smallest side of the image
_C.INPUT.MIN_SIZE = 900
# Maximum size of the side of the image
_C.INPUT.MAX_SIZE = 1500
# Number of images per batch
_C.INPUT.BATCH_SIZE_TRAIN = 5
_C.INPUT.BATCH_SIZE_TEST = 1
# Number of data loading threads
_C.INPUT.NUM_WORKERS_TRAIN = 5
_C.INPUT.NUM_WORKERS_TEST = 1
# Image augmentation
_C.INPUT.IMAGE_CUTOUT = False
_C.INPUT.IMAGE_ERASE = False
_C.INPUT.IMAGE_MIXUP = False
# -------------------------------------------------------- #
# GRID #
# -------------------------------------------------------- #
_C.INPUT.IMAGE_GRID = False
_C.GRID = CN()
_C.GRID.ROTATE = 1
_C.GRID.OFFSET = 0
_C.GRID.RATIO = 0.5
_C.GRID.MODE = 1
_C.GRID.PROB = 0.5
# -------------------------------------------------------- #
# Solver #
# -------------------------------------------------------- #
_C.SOLVER = CN()
_C.SOLVER.MAX_EPOCHS = 13
# Learning rate settings
_C.SOLVER.BASE_LR = 0.003
# The epoch milestones to decrease the learning rate by GAMMA
_C.SOLVER.LR_DECAY_MILESTONES = [10, 14]
_C.SOLVER.GAMMA = 0.1
_C.SOLVER.WEIGHT_DECAY = 0.0005
_C.SOLVER.SGD_MOMENTUM = 0.9
# Loss weight of RPN regression
_C.SOLVER.LW_RPN_REG = 1
# Loss weight of RPN classification
_C.SOLVER.LW_RPN_CLS = 1
# Loss weight of Cascade R-CNN and Re-ID (OIM)
_C.SOLVER.LW_RCNN_REG_1ST = 10
_C.SOLVER.LW_RCNN_CLS_1ST = 1
_C.SOLVER.LW_RCNN_REG_2ND = 10
_C.SOLVER.LW_RCNN_CLS_2ND = 1
_C.SOLVER.LW_RCNN_REG_3RD = 10
_C.SOLVER.LW_RCNN_CLS_3RD = 1
_C.SOLVER.LW_RCNN_REID_2ND = 0.5
_C.SOLVER.LW_RCNN_REID_3RD = 0.5
# Loss weight of box reid, softmax loss
_C.SOLVER.LW_RCNN_SOFTMAX_2ND = 0.5
_C.SOLVER.LW_RCNN_SOFTMAX_3RD = 0.5
# Set to negative value to disable gradient clipping
_C.SOLVER.CLIP_GRADIENTS = 10.0
# -------------------------------------------------------- #
# RPN #
# -------------------------------------------------------- #
_C.MODEL = CN()
_C.MODEL.RPN = CN()
# NMS threshold used on RoIs
_C.MODEL.RPN.NMS_THRESH = 0.7
# Number of anchors per image used to train RPN
_C.MODEL.RPN.BATCH_SIZE_TRAIN = 256
# Target fraction of foreground examples per RPN minibatch
_C.MODEL.RPN.POS_FRAC_TRAIN = 0.5
# Overlap threshold for an anchor to be considered foreground (if >= POS_THRESH_TRAIN)
_C.MODEL.RPN.POS_THRESH_TRAIN = 0.7
# Overlap threshold for an anchor to be considered background (if < NEG_THRESH_TRAIN)
_C.MODEL.RPN.NEG_THRESH_TRAIN = 0.3
# Number of top scoring RPN RoIs to keep before applying NMS
_C.MODEL.RPN.PRE_NMS_TOPN_TRAIN = 12000
_C.MODEL.RPN.PRE_NMS_TOPN_TEST = 6000
# Number of top scoring RPN RoIs to keep after applying NMS
_C.MODEL.RPN.POST_NMS_TOPN_TRAIN = 2000
_C.MODEL.RPN.POST_NMS_TOPN_TEST = 300
# -------------------------------------------------------- #
# RoI head #
# -------------------------------------------------------- #
_C.MODEL.ROI_HEAD = CN()
# Whether to use bn neck (i.e. batch normalization after linear)
_C.MODEL.ROI_HEAD.BN_NECK = True
# Number of RoIs per image used to train RoI head
_C.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN = 128
# Target fraction of foreground examples per RoI minibatch
_C.MODEL.ROI_HEAD.POS_FRAC_TRAIN = 0.25 # 0.5
_C.MODEL.ROI_HEAD.USE_DIFF_THRESH = True
# Overlap threshold for an RoI to be considered foreground (if >= POS_THRESH_TRAIN)
_C.MODEL.ROI_HEAD.POS_THRESH_TRAIN = 0.5
_C.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND = 0.6
_C.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD = 0.7
# Overlap threshold for an RoI to be considered background (if < NEG_THRESH_TRAIN)
_C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN = 0.5
_C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND = 0.6
_C.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD = 0.7
# Minimum score threshold
_C.MODEL.ROI_HEAD.SCORE_THRESH_TEST = 0.5
# NMS threshold used on boxes
_C.MODEL.ROI_HEAD.NMS_THRESH_TEST = 0.4
_C.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST = 0.4
_C.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND = 0.4
_C.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD = 0.5
# Maximum number of detected objects
_C.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST = 300
# -------------------------------------------------------- #
# Transformer head #
# -------------------------------------------------------- #
_C.MODEL.TRANSFORMER = CN()
_C.MODEL.TRANSFORMER.DIM_MODEL = 512
_C.MODEL.TRANSFORMER.ENCODER_LAYERS = 1
_C.MODEL.TRANSFORMER.N_HEAD = 8
_C.MODEL.TRANSFORMER.USE_OUTPUT_LAYER = False
_C.MODEL.TRANSFORMER.DROPOUT = 0.
_C.MODEL.TRANSFORMER.USE_LOCAL_SHORTCUT = True
_C.MODEL.TRANSFORMER.USE_GLOBAL_SHORTCUT = True
_C.MODEL.TRANSFORMER.USE_DIFF_SCALE = True
_C.MODEL.TRANSFORMER.NAMES_1ST = ['scale1','scale2']
_C.MODEL.TRANSFORMER.NAMES_2ND = ['scale1','scale2']
_C.MODEL.TRANSFORMER.NAMES_3RD = ['scale1','scale2']
_C.MODEL.TRANSFORMER.KERNEL_SIZE_1ST = [(1,1),(3,3)]
_C.MODEL.TRANSFORMER.KERNEL_SIZE_2ND = [(1,1),(3,3)]
_C.MODEL.TRANSFORMER.KERNEL_SIZE_3RD = [(1,1),(3,3)]
_C.MODEL.TRANSFORMER.USE_MASK_1ST = False
_C.MODEL.TRANSFORMER.USE_MASK_2ND = True
_C.MODEL.TRANSFORMER.USE_MASK_3RD = True
_C.MODEL.TRANSFORMER.USE_PATCH2VEC = True
####
_C.MODEL.USE_FEATURE_MASK = True
_C.MODEL.FEATURE_AUG_TYPE = 'exchange_token' # 'exchange_token', 'jigsaw_token', 'cutout_patch', 'erase_patch', 'mixup_patch', 'jigsaw_patch'
_C.MODEL.FEATURE_MASK_SIZE = 4
_C.MODEL.MASK_SHAPE = 'stripe' # 'square', 'random'
_C.MODEL.MASK_SIZE = 1
_C.MODEL.MASK_MODE = 'random_direction' # 'horizontal', 'vertical' for stripe; 'random_size' for square
_C.MODEL.MASK_PERCENT = 0.1
####
_C.MODEL.EMBEDDING_DIM = 256
# -------------------------------------------------------- #
# Loss #
# -------------------------------------------------------- #
_C.MODEL.LOSS = CN()
# Size of the lookup table in OIM
_C.MODEL.LOSS.LUT_SIZE = 5532
# Size of the circular queue in OIM
_C.MODEL.LOSS.CQ_SIZE = 5000
_C.MODEL.LOSS.OIM_MOMENTUM = 0.5
_C.MODEL.LOSS.OIM_SCALAR = 30.0
_C.MODEL.LOSS.USE_SOFTMAX = True
# -------------------------------------------------------- #
# Evaluation #
# -------------------------------------------------------- #
# The period to evaluate the model during training
_C.EVAL_PERIOD = 1
# Evaluation with GT boxes to verify the upper bound of person search performance
_C.EVAL_USE_GT = False
# Fast evaluation with cached features
_C.EVAL_USE_CACHE = False
# Evaluation with Context Bipartite Graph Matching (CBGM) algorithm
_C.EVAL_USE_CBGM = False
# Gallery size in evaluation, only for CUHK-SYSU
_C.EVAL_GALLERY_SIZE = 100
# Feature used for evaluation
_C.EVAL_FEATURE = 'concat' # 'stage2', 'stage3'
# -------------------------------------------------------- #
# Miscs #
# -------------------------------------------------------- #
# Save a checkpoint after every this number of epochs
_C.CKPT_PERIOD = 1
# The period (in terms of iterations) to display training losses
_C.DISP_PERIOD = 10
# Whether to use tensorboard for visualization
_C.TF_BOARD = True
# The device loading the model
_C.DEVICE = "cuda:0"
# Set seed to negative to fully randomize everything
_C.SEED = 1
# Directory where output files are written
_C.OUTPUT_DIR = "./output"
def get_default_cfg():
"""
Get a copy of the default config.
"""
return _C.clone()

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

179
coat-pvtv2-b2/engine.py Normal file
View File

@ -0,0 +1,179 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import math
import sys
from copy import deepcopy
import torch
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from eval_func import eval_detection, eval_search_cuhk, eval_search_prw
from utils.utils import MetricLogger, SmoothedValue, mkdir, reduce_dict, warmup_lr_scheduler
from utils.transforms import mixup_data
def to_device(images, targets, device):
images = [image.to(device) for image in images]
for t in targets:
t["boxes"] = t["boxes"].to(device)
t["labels"] = t["labels"].to(device)
return images, targets
def train_one_epoch(cfg, model, optimizer, data_loader, device, epoch, tfboard, softmax_criterion_s2, softmax_criterion_s3):
model.train()
metric_logger = MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
header = "Epoch: [{}]".format(epoch)
# warmup learning rate in the first epoch
if epoch == 0:
warmup_factor = 1.0 / 1000
warmup_iters = len(data_loader) - 1
warmup_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)
for i, (images, targets) in enumerate(
metric_logger.log_every(data_loader, cfg.DISP_PERIOD, header)
):
images, targets = to_device(images, targets, device)
# if using image based data augmentation
if cfg.INPUT.IMAGE_MIXUP:
images = mixup_data(images, alpha=0.8)
loss_dict, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd = model(images, targets)
if cfg.MODEL.LOSS.USE_SOFTMAX:
softmax_loss_2nd = cfg.SOLVER.LW_RCNN_SOFTMAX_2ND * softmax_criterion_s2(feats_reid_2nd, targets_reid_2nd)
softmax_loss_3rd = cfg.SOLVER.LW_RCNN_SOFTMAX_3RD * softmax_criterion_s3(feats_reid_3rd, targets_reid_3rd)
loss_dict.update(loss_box_softmax_2nd=softmax_loss_2nd)
loss_dict.update(loss_box_softmax_3rd=softmax_loss_3rd)
losses = sum(loss for loss in loss_dict.values())
# reduce losses over all GPUs for logging purposes
loss_dict_reduced = reduce_dict(loss_dict)
losses_reduced = sum(loss for loss in loss_dict_reduced.values())
loss_value = losses_reduced.item()
if not math.isfinite(loss_value):
print(f"Loss is {loss_value}, stopping training")
print(loss_dict_reduced)
sys.exit(1)
optimizer.zero_grad()
losses.backward()
if cfg.SOLVER.CLIP_GRADIENTS > 0:
clip_grad_norm_(model.parameters(), cfg.SOLVER.CLIP_GRADIENTS)
optimizer.step()
if epoch == 0:
warmup_scheduler.step()
metric_logger.update(loss=loss_value, **loss_dict_reduced)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
if tfboard:
iter = epoch * len(data_loader) + i
for k, v in loss_dict_reduced.items():
tfboard.add_scalars("train", {k: v}, iter)
@torch.no_grad()
def evaluate_performance(
model, gallery_loader, query_loader, device, use_gt=False, use_cache=False, use_cbgm=False, gallery_size=100):
"""
Args:
use_gt (bool, optional): Whether to use GT as detection results to verify the upper
bound of person search performance. Defaults to False.
use_cache (bool, optional): Whether to use the cached features. Defaults to False.
use_cbgm (bool, optional): Whether to use Context Bipartite Graph Matching algorithm.
Defaults to False.
"""
model.eval()
if use_cache:
eval_cache = torch.load("data/eval_cache/eval_cache.pth")
gallery_dets = eval_cache["gallery_dets"]
gallery_feats = eval_cache["gallery_feats"]
query_dets = eval_cache["query_dets"]
query_feats = eval_cache["query_feats"]
query_box_feats = eval_cache["query_box_feats"]
else:
gallery_dets, gallery_feats = [], []
for images, targets in tqdm(gallery_loader, ncols=0):
images, targets = to_device(images, targets, device)
if not use_gt:
outputs = model(images)
else:
boxes = targets[0]["boxes"]
n_boxes = boxes.size(0)
embeddings = model(images, targets)
outputs = [
{
"boxes": boxes,
"embeddings": torch.cat(embeddings),
"labels": torch.ones(n_boxes).to(device),
"scores": torch.ones(n_boxes).to(device),
}
]
for output in outputs:
box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1)
gallery_dets.append(box_w_scores.cpu().numpy())
gallery_feats.append(output["embeddings"].cpu().numpy())
# regarding query image as gallery to detect all people
# i.e. query person + surrounding people (context information)
query_dets, query_feats = [], []
for images, targets in tqdm(query_loader, ncols=0):
images, targets = to_device(images, targets, device)
# targets will be modified in the model, so deepcopy it
outputs = model(images, deepcopy(targets), query_img_as_gallery=True)
# consistency check
gt_box = targets[0]["boxes"].squeeze()
assert (
gt_box - outputs[0]["boxes"][0]
).sum() <= 0.001, "GT box must be the first one in the detected boxes of query image"
for output in outputs:
box_w_scores = torch.cat([output["boxes"], output["scores"].unsqueeze(1)], dim=1)
query_dets.append(box_w_scores.cpu().numpy())
query_feats.append(output["embeddings"].cpu().numpy())
# extract the features of query boxes
query_box_feats = []
for images, targets in tqdm(query_loader, ncols=0):
images, targets = to_device(images, targets, device)
embeddings = model(images, targets)
assert len(embeddings) == 1, "batch size in test phase should be 1"
query_box_feats.append(embeddings[0].cpu().numpy())
mkdir("data/eval_cache")
save_dict = {
"gallery_dets": gallery_dets,
"gallery_feats": gallery_feats,
"query_dets": query_dets,
"query_feats": query_feats,
"query_box_feats": query_box_feats,
}
torch.save(save_dict, "data/eval_cache/eval_cache.pth")
eval_detection(gallery_loader.dataset, gallery_dets, det_thresh=0.01)
eval_search_func = (
eval_search_cuhk if gallery_loader.dataset.name == "CUHK-SYSU" else eval_search_prw
)
eval_search_func(
gallery_loader.dataset,
query_loader.dataset,
gallery_dets,
gallery_feats,
query_box_feats,
query_dets,
query_feats,
cbgm=use_cbgm,
gallery_size=gallery_size,
)

488
coat-pvtv2-b2/eval_func.py Normal file
View File

@ -0,0 +1,488 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import os.path as osp
import numpy as np
from scipy.io import loadmat
from sklearn.metrics import average_precision_score
from utils.km import run_kuhn_munkres
from utils.utils import write_json
def _compute_iou(a, b):
x1 = max(a[0], b[0])
y1 = max(a[1], b[1])
x2 = min(a[2], b[2])
y2 = min(a[3], b[3])
inter = max(0, x2 - x1) * max(0, y2 - y1)
union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
return inter * 1.0 / union
def eval_detection(
gallery_dataset, gallery_dets, det_thresh=0.5, iou_thresh=0.5, labeled_only=False
):
"""
gallery_det (list of ndarray): n_det x [x1, y1, x2, y2, score] per image
det_thresh (float): filter out gallery detections whose scores below this
iou_thresh (float): treat as true positive if IoU is above this threshold
labeled_only (bool): filter out unlabeled background people
"""
assert len(gallery_dataset) == len(gallery_dets)
annos = gallery_dataset.annotations
y_true, y_score = [], []
count_gt, count_tp = 0, 0
for anno, det in zip(annos, gallery_dets):
gt_boxes = anno["boxes"]
if labeled_only:
# exclude the unlabeled people (pid == 5555)
inds = np.where(anno["pids"].ravel() != 5555)[0]
if len(inds) == 0:
continue
gt_boxes = gt_boxes[inds]
num_gt = gt_boxes.shape[0]
if det != []:
det = np.asarray(det)
inds = np.where(det[:, 4].ravel() >= det_thresh)[0]
det = det[inds]
num_det = det.shape[0]
else:
num_det = 0
if num_det == 0:
count_gt += num_gt
continue
ious = np.zeros((num_gt, num_det), dtype=np.float32)
for i in range(num_gt):
for j in range(num_det):
ious[i, j] = _compute_iou(gt_boxes[i], det[j, :4])
tfmat = ious >= iou_thresh
# for each det, keep only the largest iou of all the gt
for j in range(num_det):
largest_ind = np.argmax(ious[:, j])
for i in range(num_gt):
if i != largest_ind:
tfmat[i, j] = False
# for each gt, keep only the largest iou of all the det
for i in range(num_gt):
largest_ind = np.argmax(ious[i, :])
for j in range(num_det):
if j != largest_ind:
tfmat[i, j] = False
for j in range(num_det):
y_score.append(det[j, -1])
y_true.append(tfmat[:, j].any())
count_tp += tfmat.sum()
count_gt += num_gt
det_rate = count_tp * 1.0 / count_gt
ap = average_precision_score(y_true, y_score) * det_rate
print("{} detection:".format("labeled only" if labeled_only else "all"))
print(" recall = {:.2%}".format(det_rate))
if not labeled_only:
print(" ap = {:.2%}".format(ap))
return det_rate, ap
def eval_search_cuhk(
gallery_dataset,
query_dataset,
gallery_dets,
gallery_feats,
query_box_feats,
query_dets,
query_feats,
k1=10,
k2=3,
det_thresh=0.5,
cbgm=False,
gallery_size=100,
):
"""
gallery_dataset/query_dataset: an instance of BaseDataset
gallery_det (list of ndarray): n_det x [x1, x2, y1, y2, score] per image
gallery_feat (list of ndarray): n_det x D features per image
query_feat (list of ndarray): D dimensional features per query image
det_thresh (float): filter out gallery detections whose scores below this
gallery_size (int): gallery size [-1, 50, 100, 500, 1000, 2000, 4000]
-1 for using full set
"""
assert len(gallery_dataset) == len(gallery_dets)
assert len(gallery_dataset) == len(gallery_feats)
assert len(query_dataset) == len(query_box_feats)
use_full_set = gallery_size == -1
fname = "TestG{}".format(gallery_size if not use_full_set else 50)
protoc = loadmat(osp.join(gallery_dataset.root, "annotation/test/train_test", fname + ".mat"))
protoc = protoc[fname].squeeze()
# mapping from gallery image to (det, feat)
annos = gallery_dataset.annotations
name_to_det_feat = {}
for anno, det, feat in zip(annos, gallery_dets, gallery_feats):
name = anno["img_name"]
if det != []:
scores = det[:, 4].ravel()
inds = np.where(scores >= det_thresh)[0]
if len(inds) > 0:
name_to_det_feat[name] = (det[inds], feat[inds])
aps = []
accs = []
topk = [1, 5, 10]
ret = {"image_root": gallery_dataset.img_prefix, "results": []}
for i in range(len(query_dataset)):
y_true, y_score = [], []
imgs, rois = [], []
count_gt, count_tp = 0, 0
# get L2-normalized feature vector
feat_q = query_box_feats[i].ravel()
# ignore the query image
query_imname = str(protoc["Query"][i]["imname"][0, 0][0])
query_roi = protoc["Query"][i]["idlocate"][0, 0][0].astype(np.int32)
query_roi[2:] += query_roi[:2]
query_gt = []
tested = set([query_imname])
name2sim = {}
name2gt = {}
sims = []
imgs_cbgm = []
# 1. Go through the gallery samples defined by the protocol
for item in protoc["Gallery"][i].squeeze():
gallery_imname = str(item[0][0])
# some contain the query (gt not empty), some not
gt = item[1][0].astype(np.int32)
count_gt += gt.size > 0
# compute distance between query and gallery dets
if gallery_imname not in name_to_det_feat:
continue
det, feat_g = name_to_det_feat[gallery_imname]
# no detection in this gallery, skip it
if det.shape[0] == 0:
continue
# get L2-normalized feature matrix NxD
assert feat_g.size == np.prod(feat_g.shape[:2])
feat_g = feat_g.reshape(feat_g.shape[:2])
# compute cosine similarities
sim = feat_g.dot(feat_q).ravel()
if gallery_imname in name2sim:
continue
name2sim[gallery_imname] = sim
name2gt[gallery_imname] = gt
sims.extend(list(sim))
imgs_cbgm.extend([gallery_imname] * len(sim))
# 2. Go through the remaining gallery images if using full set
if use_full_set:
for gallery_imname in gallery_dataset.imgs:
if gallery_imname in tested:
continue
if gallery_imname not in name_to_det_feat:
continue
det, feat_g = name_to_det_feat[gallery_imname]
# get L2-normalized feature matrix NxD
assert feat_g.size == np.prod(feat_g.shape[:2])
feat_g = feat_g.reshape(feat_g.shape[:2])
# compute cosine similarities
sim = feat_g.dot(feat_q).ravel()
# guaranteed no target query in these gallery images
label = np.zeros(len(sim), dtype=np.int32)
y_true.extend(list(label))
y_score.extend(list(sim))
imgs.extend([gallery_imname] * len(sim))
rois.extend(list(det))
if cbgm:
# -------- Context Bipartite Graph Matching (CBGM) ------- #
sims = np.array(sims)
imgs_cbgm = np.array(imgs_cbgm)
# only process the top-k1 gallery images for efficiency
inds = np.argsort(sims)[-k1:]
imgs_cbgm = set(imgs_cbgm[inds])
for img in imgs_cbgm:
sim = name2sim[img]
det, feat_g = name_to_det_feat[img]
# only regard the people with top-k2 detection confidence
# in the query image as context information
qboxes = query_dets[i][:k2]
qfeats = query_feats[i][:k2]
assert (
query_roi - qboxes[0][:4]
).sum() <= 0.001, "query_roi must be the first one in pboxes"
# build the bipartite graph and run Kuhn-Munkres (K-M) algorithm
# to find the best match
graph = []
for indx_i, pfeat in enumerate(qfeats):
for indx_j, gfeat in enumerate(feat_g):
graph.append((indx_i, indx_j, (pfeat * gfeat).sum()))
km_res, max_val = run_kuhn_munkres(graph)
# revise the similarity between query person and its matching
for indx_i, indx_j, _ in km_res:
# 0 denotes the query roi
if indx_i == 0:
sim[indx_j] = max_val
break
for gallery_imname, sim in name2sim.items():
gt = name2gt[gallery_imname]
det, feat_g = name_to_det_feat[gallery_imname]
# assign label for each det
label = np.zeros(len(sim), dtype=np.int32)
if gt.size > 0:
w, h = gt[2], gt[3]
gt[2:] += gt[:2]
query_gt.append({"img": str(gallery_imname), "roi": list(map(float, list(gt)))})
iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10)))
inds = np.argsort(sim)[::-1]
sim = sim[inds]
det = det[inds]
# only set the first matched det as true positive
for j, roi in enumerate(det[:, :4]):
if _compute_iou(roi, gt) >= iou_thresh:
label[j] = 1
count_tp += 1
break
y_true.extend(list(label))
y_score.extend(list(sim))
imgs.extend([gallery_imname] * len(sim))
rois.extend(list(det))
tested.add(gallery_imname)
# 3. Compute AP for this query (need to scale by recall rate)
y_score = np.asarray(y_score)
y_true = np.asarray(y_true)
assert count_tp <= count_gt
recall_rate = count_tp * 1.0 / count_gt
ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate
aps.append(ap)
inds = np.argsort(y_score)[::-1]
y_score = y_score[inds]
y_true = y_true[inds]
accs.append([min(1, sum(y_true[:k])) for k in topk])
# 4. Save result for JSON dump
new_entry = {
"query_img": str(query_imname),
"query_roi": list(map(float, list(query_roi))),
"query_gt": query_gt,
"gallery": [],
}
# only record wrong results
if int(y_true[0]):
continue
# only save top-10 predictions
for k in range(10):
new_entry["gallery"].append(
{
"img": str(imgs[inds[k]]),
"roi": list(map(float, list(rois[inds[k]]))),
"score": float(y_score[k]),
"correct": int(y_true[k]),
}
)
ret["results"].append(new_entry)
print("search ranking:")
print(" mAP = {:.2%}".format(np.mean(aps)))
accs = np.mean(accs, axis=0)
for i, k in enumerate(topk):
print(" top-{:2d} = {:.2%}".format(k, accs[i]))
write_json(ret, "vis/results.json")
ret["mAP"] = np.mean(aps)
ret["accs"] = accs
return ret
def eval_search_prw(
gallery_dataset,
query_dataset,
gallery_dets,
gallery_feats,
query_box_feats,
query_dets,
query_feats,
k1=30,
k2=4,
det_thresh=0.5,
cbgm=False,
gallery_size=None, # not used in PRW
ignore_cam_id=True,
):
"""
gallery_det (list of ndarray): n_det x [x1, x2, y1, y2, score] per image
gallery_feat (list of ndarray): n_det x D features per image
query_feat (list of ndarray): D dimensional features per query image
det_thresh (float): filter out gallery detections whose scores below this
gallery_size (int): -1 for using full set
ignore_cam_id (bool): Set to True acoording to CUHK-SYSU,
although it's a common practice to focus on cross-cam match only.
"""
assert len(gallery_dataset) == len(gallery_dets)
assert len(gallery_dataset) == len(gallery_feats)
assert len(query_dataset) == len(query_box_feats)
annos = gallery_dataset.annotations
name_to_det_feat = {}
for anno, det, feat in zip(annos, gallery_dets, gallery_feats):
name = anno["img_name"]
scores = det[:, 4].ravel()
inds = np.where(scores >= det_thresh)[0]
if len(inds) > 0:
name_to_det_feat[name] = (det[inds], feat[inds])
aps = []
accs = []
topk = [1, 5, 10]
ret = {"image_root": gallery_dataset.img_prefix, "results": []}
for i in range(len(query_dataset)):
y_true, y_score = [], []
imgs, rois = [], []
count_gt, count_tp = 0, 0
feat_p = query_box_feats[i].ravel()
query_imname = query_dataset.annotations[i]["img_name"]
query_roi = query_dataset.annotations[i]["boxes"]
query_pid = query_dataset.annotations[i]["pids"]
query_cam = query_dataset.annotations[i]["cam_id"]
# Find all occurence of this query
gallery_imgs = []
for x in annos:
if query_pid in x["pids"] and x["img_name"] != query_imname:
gallery_imgs.append(x)
query_gts = {}
for item in gallery_imgs:
query_gts[item["img_name"]] = item["boxes"][item["pids"] == query_pid]
# Construct gallery set for this query
if ignore_cam_id:
gallery_imgs = []
for x in annos:
if x["img_name"] != query_imname:
gallery_imgs.append(x)
else:
gallery_imgs = []
for x in annos:
if x["img_name"] != query_imname and x["cam_id"] != query_cam:
gallery_imgs.append(x)
name2sim = {}
sims = []
imgs_cbgm = []
# 1. Go through all gallery samples
for item in gallery_imgs:
gallery_imname = item["img_name"]
# some contain the query (gt not empty), some not
count_gt += gallery_imname in query_gts
# compute distance between query and gallery dets
if gallery_imname not in name_to_det_feat:
continue
det, feat_g = name_to_det_feat[gallery_imname]
# get L2-normalized feature matrix NxD
assert feat_g.size == np.prod(feat_g.shape[:2])
feat_g = feat_g.reshape(feat_g.shape[:2])
# compute cosine similarities
sim = feat_g.dot(feat_p).ravel()
if gallery_imname in name2sim:
continue
name2sim[gallery_imname] = sim
sims.extend(list(sim))
imgs_cbgm.extend([gallery_imname] * len(sim))
if cbgm:
sims = np.array(sims)
imgs_cbgm = np.array(imgs_cbgm)
inds = np.argsort(sims)[-k1:]
imgs_cbgm = set(imgs_cbgm[inds])
for img in imgs_cbgm:
sim = name2sim[img]
det, feat_g = name_to_det_feat[img]
qboxes = query_dets[i][:k2]
qfeats = query_feats[i][:k2]
# assert (
# query_roi - qboxes[0][:4]
# ).sum() <= 0.001, "query_roi must be the first one in pboxes"
graph = []
for indx_i, pfeat in enumerate(qfeats):
for indx_j, gfeat in enumerate(feat_g):
graph.append((indx_i, indx_j, (pfeat * gfeat).sum()))
km_res, max_val = run_kuhn_munkres(graph)
for indx_i, indx_j, _ in km_res:
if indx_i == 0:
sim[indx_j] = max_val
break
for gallery_imname, sim in name2sim.items():
det, feat_g = name_to_det_feat[gallery_imname]
# assign label for each det
label = np.zeros(len(sim), dtype=np.int32)
if gallery_imname in query_gts:
gt = query_gts[gallery_imname].ravel()
w, h = gt[2] - gt[0], gt[3] - gt[1]
iou_thresh = min(0.5, (w * h * 1.0) / ((w + 10) * (h + 10)))
inds = np.argsort(sim)[::-1]
sim = sim[inds]
det = det[inds]
# only set the first matched det as true positive
for j, roi in enumerate(det[:, :4]):
if _compute_iou(roi, gt) >= iou_thresh:
label[j] = 1
count_tp += 1
break
y_true.extend(list(label))
y_score.extend(list(sim))
imgs.extend([gallery_imname] * len(sim))
rois.extend(list(det))
# 2. Compute AP for this query (need to scale by recall rate)
y_score = np.asarray(y_score)
y_true = np.asarray(y_true)
assert count_tp <= count_gt
recall_rate = count_tp * 1.0 / count_gt
ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate
aps.append(ap)
inds = np.argsort(y_score)[::-1]
y_score = y_score[inds]
y_true = y_true[inds]
accs.append([min(1, sum(y_true[:k])) for k in topk])
# 4. Save result for JSON dump
new_entry = {
"query_img": str(query_imname),
"query_roi": list(map(float, list(query_roi.squeeze()))),
"query_gt": query_gts,
"gallery": [],
}
# only save top-10 predictions
for k in range(10):
new_entry["gallery"].append(
{
"img": str(imgs[inds[k]]),
"roi": list(map(float, list(rois[inds[k]]))),
"score": float(y_score[k]),
"correct": int(y_true[k]),
}
)
ret["results"].append(new_entry)
print("search ranking:")
mAP = np.mean(aps)
print(" mAP = {:.2%}".format(mAP))
accs = np.mean(accs, axis=0)
for i, k in enumerate(topk):
print(" top-{:2d} = {:.2%}".format(k, accs[i]))
# write_json(ret, "vis/results.json")
ret["mAP"] = np.mean(aps)
ret["accs"] = accs
return ret

76
coat-pvtv2-b2/loss/oim.py Normal file
View File

@ -0,0 +1,76 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import torch
import torch.nn.functional as F
from torch import autograd, nn
class OIM(autograd.Function):
@staticmethod
def forward(ctx, inputs, targets, lut, cq, header, momentum):
ctx.save_for_backward(inputs, targets, lut, cq, header, momentum)
outputs_labeled = inputs.mm(lut.t())
outputs_unlabeled = inputs.mm(cq.t())
return torch.cat([outputs_labeled, outputs_unlabeled], dim=1)
@staticmethod
def backward(ctx, grad_outputs):
inputs, targets, lut, cq, header, momentum = ctx.saved_tensors
grad_inputs = None
if ctx.needs_input_grad[0]:
grad_inputs = grad_outputs.mm(torch.cat([lut, cq], dim=0))
if grad_inputs.dtype == torch.float16:
grad_inputs = grad_inputs.to(torch.float32)
for x, y in zip(inputs, targets):
if y < len(lut):
lut[y] = momentum * lut[y] + (1.0 - momentum) * x
lut[y] /= lut[y].norm()
else:
cq[header] = x
header = (header + 1) % cq.size(0)
return grad_inputs, None, None, None, None, None
def oim(inputs, targets, lut, cq, header, momentum=0.5):
return OIM.apply(inputs, targets, lut, cq, torch.tensor(header), torch.tensor(momentum))
class OIMLoss(nn.Module):
def __init__(self, num_features, num_pids, num_cq_size, oim_momentum, oim_scalar):
super(OIMLoss, self).__init__()
self.num_features = num_features
self.num_pids = num_pids
self.num_unlabeled = num_cq_size
self.momentum = oim_momentum
self.oim_scalar = oim_scalar
self.register_buffer("lut", torch.zeros(self.num_pids, self.num_features))
self.register_buffer("cq", torch.zeros(self.num_unlabeled, self.num_features))
self.header_cq = 0
def forward(self, inputs, roi_label):
# merge into one batch, background label = 0
targets = torch.cat(roi_label)
label = targets - 1 # background label = -1
inds = label >= 0
label = label[inds]
inputs = inputs[inds.unsqueeze(1).expand_as(inputs)].view(-1, self.num_features)
projected = oim(inputs, label, self.lut, self.cq, self.header_cq, momentum=self.momentum)
# projected - Tensor [M, lut+cq], e.g., [M, 482+500]=[M, 982]
projected *= self.oim_scalar
self.header_cq = (
self.header_cq + (label >= self.num_pids).long().sum().item()
) % self.num_unlabeled
loss_oim = F.cross_entropy(projected, label, ignore_index=5554)
return loss_oim, inputs, label

View File

@ -0,0 +1,62 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import torch
from torch import nn
import torch.nn.functional as F
class SoftmaxLoss(nn.Module):
def __init__(self, cfg):
super(SoftmaxLoss, self).__init__()
self.feat_dim = cfg.MODEL.EMBEDDING_DIM
self.num_classes = cfg.MODEL.LOSS.LUT_SIZE
self.bottleneck = nn.BatchNorm1d(self.feat_dim)
self.bottleneck.bias.requires_grad_(False) # no shift
self.classifier = nn.Linear(self.feat_dim, self.num_classes, bias=False)
self.bottleneck.apply(weights_init_kaiming)
self.classifier.apply(weights_init_classifier)
def forward(self, inputs, labels):
"""
Args:
inputs: feature matrix with shape (batch_size, feat_dim).
labels: ground truth labels with shape (num_classes).
"""
assert inputs.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"
target = labels.clone()
target[target >= self.num_classes] = 5554
feat = self.bottleneck(inputs)
score = self.classifier(feat)
loss = F.cross_entropy(score, target, ignore_index=5554)
return loss
def weights_init_kaiming(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
nn.init.constant_(m.bias, 0.0)
elif classname.find('Conv') != -1:
nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
if m.bias is not None:
nn.init.constant_(m.bias, 0.0)
elif classname.find('BatchNorm') != -1:
if m.affine:
nn.init.constant_(m.weight, 1.0)
nn.init.constant_(m.bias, 0.0)
def weights_init_classifier(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
nn.init.normal_(m.weight, std=0.001)
if m.bias:
nn.init.constant_(m.bias, 0.0)

View File

@ -0,0 +1,765 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import boxes as box_ops
from torchvision.models.detection import _utils as det_utils
from loss.oim import OIMLoss
from models.resnet import build_resnet
from models.transformer import TransformerHead
class COAT(nn.Module):
def __init__(self, cfg):
super(COAT, self).__init__()
backbone, _ = build_resnet(name="resnet50", pretrained=True)
anchor_generator = AnchorGenerator(
sizes=((32, 64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)
)
head = RPNHead(
in_channels=backbone.out_channels,
num_anchors=anchor_generator.num_anchors_per_location()[0],
)
pre_nms_top_n = dict(
training=cfg.MODEL.RPN.PRE_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.PRE_NMS_TOPN_TEST
)
post_nms_top_n = dict(
training=cfg.MODEL.RPN.POST_NMS_TOPN_TRAIN, testing=cfg.MODEL.RPN.POST_NMS_TOPN_TEST
)
rpn = RegionProposalNetwork(
anchor_generator=anchor_generator,
head=head,
fg_iou_thresh=cfg.MODEL.RPN.POS_THRESH_TRAIN,
bg_iou_thresh=cfg.MODEL.RPN.NEG_THRESH_TRAIN,
batch_size_per_image=cfg.MODEL.RPN.BATCH_SIZE_TRAIN,
positive_fraction=cfg.MODEL.RPN.POS_FRAC_TRAIN,
pre_nms_top_n=pre_nms_top_n,
post_nms_top_n=post_nms_top_n,
nms_thresh=cfg.MODEL.RPN.NMS_THRESH,
)
box_head = TransformerHead(
cfg=cfg,
trans_names=cfg.MODEL.TRANSFORMER.NAMES_1ST,
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_1ST,
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_1ST,
)
box_head_2nd = TransformerHead(
cfg=cfg,
trans_names=cfg.MODEL.TRANSFORMER.NAMES_2ND,
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_2ND,
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_2ND,
)
box_head_3rd = TransformerHead(
cfg=cfg,
trans_names=cfg.MODEL.TRANSFORMER.NAMES_3RD,
kernel_size=cfg.MODEL.TRANSFORMER.KERNEL_SIZE_3RD,
use_feature_mask=cfg.MODEL.TRANSFORMER.USE_MASK_3RD,
)
faster_rcnn_predictor = FastRCNNPredictor(2048, 2)
box_roi_pool = MultiScaleRoIAlign(
featmap_names=["feat_res4"], output_size=14, sampling_ratio=2
)
box_predictor = BBoxRegressor(2048, num_classes=2, bn_neck=cfg.MODEL.ROI_HEAD.BN_NECK)
roi_heads = CascadedROIHeads(
cfg=cfg,
# Cascade Transformer Head
faster_rcnn_predictor=faster_rcnn_predictor,
box_head_2nd=box_head_2nd,
box_head_3rd=box_head_3rd,
# parent class
box_roi_pool=box_roi_pool,
box_head=box_head,
box_predictor=box_predictor,
fg_iou_thresh=cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN,
bg_iou_thresh=cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN,
batch_size_per_image=cfg.MODEL.ROI_HEAD.BATCH_SIZE_TRAIN,
positive_fraction=cfg.MODEL.ROI_HEAD.POS_FRAC_TRAIN,
bbox_reg_weights=None,
score_thresh=cfg.MODEL.ROI_HEAD.SCORE_THRESH_TEST,
nms_thresh=cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST,
detections_per_img=cfg.MODEL.ROI_HEAD.DETECTIONS_PER_IMAGE_TEST,
)
transform = GeneralizedRCNNTransform(
min_size=cfg.INPUT.MIN_SIZE,
max_size=cfg.INPUT.MAX_SIZE,
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
)
self.backbone = backbone
self.rpn = rpn
self.roi_heads = roi_heads
self.transform = transform
self.eval_feat = cfg.EVAL_FEATURE
# loss weights
self.lw_rpn_reg = cfg.SOLVER.LW_RPN_REG
self.lw_rpn_cls = cfg.SOLVER.LW_RPN_CLS
self.lw_rcnn_reg_1st = cfg.SOLVER.LW_RCNN_REG_1ST
self.lw_rcnn_cls_1st = cfg.SOLVER.LW_RCNN_CLS_1ST
self.lw_rcnn_reg_2nd = cfg.SOLVER.LW_RCNN_REG_2ND
self.lw_rcnn_cls_2nd = cfg.SOLVER.LW_RCNN_CLS_2ND
self.lw_rcnn_reg_3rd = cfg.SOLVER.LW_RCNN_REG_3RD
self.lw_rcnn_cls_3rd = cfg.SOLVER.LW_RCNN_CLS_3RD
self.lw_rcnn_reid_2nd = cfg.SOLVER.LW_RCNN_REID_2ND
self.lw_rcnn_reid_3rd = cfg.SOLVER.LW_RCNN_REID_3RD
def inference(self, images, targets=None, query_img_as_gallery=False):
original_image_sizes = [img.shape[-2:] for img in images]
images, targets = self.transform(images, targets)
features = self.backbone(images.tensors)
if query_img_as_gallery:
assert targets is not None
if targets is not None and not query_img_as_gallery:
# query
boxes = [t["boxes"] for t in targets]
box_features = self.roi_heads.box_roi_pool(features, boxes, images.image_sizes)
box_features_2nd = self.roi_heads.box_head_2nd(box_features)
embeddings_2nd, _ = self.roi_heads.embedding_head_2nd(box_features_2nd)
box_features_3rd = self.roi_heads.box_head_3rd(box_features)
embeddings_3rd, _ = self.roi_heads.embedding_head_3rd(box_features_3rd)
if self.eval_feat == 'concat':
embeddings = torch.cat((embeddings_2nd, embeddings_3rd), dim=1)
elif self.eval_feat == 'stage2':
embeddings = embeddings_2nd
elif self.eval_feat == 'stage3':
embeddings = embeddings_3rd
else:
raise Exception("Unknown evaluation feature name")
return embeddings.split(1, 0)
else:
# gallery
boxes, _ = self.rpn(images, features, targets)
detections = self.roi_heads(features, boxes, images.image_sizes, targets, query_img_as_gallery)[0]
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
return detections
def forward(self, images, targets=None, query_img_as_gallery=False):
if not self.training:
return self.inference(images, targets, query_img_as_gallery)
images, targets = self.transform(images, targets)
features = self.backbone(images.tensors)
boxes, rpn_losses = self.rpn(images, features, targets)
_, rcnn_losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd = self.roi_heads(features, boxes, images.image_sizes, targets)
# rename rpn losses to be consistent with detection losses
rpn_losses["loss_rpn_reg"] = rpn_losses.pop("loss_rpn_box_reg")
rpn_losses["loss_rpn_cls"] = rpn_losses.pop("loss_objectness")
losses = {}
losses.update(rcnn_losses)
losses.update(rpn_losses)
# apply loss weights
losses["loss_rpn_reg"] *= self.lw_rpn_reg
losses["loss_rpn_cls"] *= self.lw_rpn_cls
losses["loss_rcnn_reg_1st"] *= self.lw_rcnn_reg_1st
losses["loss_rcnn_cls_1st"] *= self.lw_rcnn_cls_1st
losses["loss_rcnn_reg_2nd"] *= self.lw_rcnn_reg_2nd
losses["loss_rcnn_cls_2nd"] *= self.lw_rcnn_cls_2nd
losses["loss_rcnn_reg_3rd"] *= self.lw_rcnn_reg_3rd
losses["loss_rcnn_cls_3rd"] *= self.lw_rcnn_cls_3rd
losses["loss_rcnn_reid_2nd"] *= self.lw_rcnn_reid_2nd
losses["loss_rcnn_reid_3rd"] *= self.lw_rcnn_reid_3rd
return losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd
class CascadedROIHeads(RoIHeads):
'''
https://github.com/pytorch/vision/blob/master/torchvision/models/detection/roi_heads.py
'''
def __init__(
self,
cfg,
faster_rcnn_predictor,
box_head_2nd,
box_head_3rd,
*args,
**kwargs
):
super(CascadedROIHeads, self).__init__(*args, **kwargs)
# ROI head
self.use_diff_thresh=cfg.MODEL.ROI_HEAD.USE_DIFF_THRESH
self.nms_thresh_1st = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_1ST
self.nms_thresh_2nd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_2ND
self.nms_thresh_3rd = cfg.MODEL.ROI_HEAD.NMS_THRESH_TEST_3RD
self.fg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN
self.bg_iou_thresh_1st = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN
self.fg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_2ND
self.bg_iou_thresh_2nd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_2ND
self.fg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.POS_THRESH_TRAIN_3RD
self.bg_iou_thresh_3rd = cfg.MODEL.ROI_HEAD.NEG_THRESH_TRAIN_3RD
# Regression head
self.box_predictor_1st = faster_rcnn_predictor
self.box_predictor_2nd = self.box_predictor
self.box_predictor_3rd = deepcopy(self.box_predictor)
# Transformer head
self.box_head_1st = self.box_head
self.box_head_2nd = box_head_2nd
self.box_head_3rd = box_head_3rd
# feature mask
self.use_feature_mask = cfg.MODEL.USE_FEATURE_MASK
self.feature_mask_size = cfg.MODEL.FEATURE_MASK_SIZE
# Feature embedding
embedding_dim = cfg.MODEL.EMBEDDING_DIM
self.embedding_head_2nd = NormAwareEmbedding(featmap_names=["before_trans", "after_trans"], in_channels=[1024, 2048], dim=embedding_dim)
self.embedding_head_3rd = deepcopy(self.embedding_head_2nd)
# OIM
num_pids = cfg.MODEL.LOSS.LUT_SIZE
num_cq_size = cfg.MODEL.LOSS.CQ_SIZE
oim_momentum = cfg.MODEL.LOSS.OIM_MOMENTUM
oim_scalar = cfg.MODEL.LOSS.OIM_SCALAR
self.reid_loss_2nd = OIMLoss(embedding_dim, num_pids, num_cq_size, oim_momentum, oim_scalar)
self.reid_loss_3rd = deepcopy(self.reid_loss_2nd)
# rename the method inherited from parent class
self.postprocess_proposals = self.postprocess_detections
# evaluation
self.eval_feat = cfg.EVAL_FEATURE
def forward(self, features, boxes, image_shapes, targets=None, query_img_as_gallery=False):
"""
Arguments:
features (List[Tensor])
boxes (List[Tensor[N, 4]])
image_shapes (List[Tuple[H, W]])
targets (List[Dict])
"""
cws = True
gt_det_2nd = None
gt_det_3rd = None
feats_reid_2nd = None
feats_reid_3rd = None
targets_reid_2nd = None
targets_reid_3rd = None
if self.training:
if self.use_diff_thresh:
self.proposal_matcher = det_utils.Matcher(
self.fg_iou_thresh_1st,
self.bg_iou_thresh_1st,
allow_low_quality_matches=False)
boxes, _, box_pid_labels_1st, box_reg_targets_1st = self.select_training_samples(
boxes, targets
)
# ------------------- The first stage ------------------ #
box_features_1st = self.box_roi_pool(features, boxes, image_shapes)
box_features_1st = self.box_head_1st(box_features_1st)
box_cls_scores_1st, box_regs_1st = self.box_predictor_1st(box_features_1st["after_trans"])
if self.training:
boxes = self.get_boxes(box_regs_1st, boxes, image_shapes)
boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
if self.use_diff_thresh:
self.proposal_matcher = det_utils.Matcher(
self.fg_iou_thresh_2nd,
self.bg_iou_thresh_2nd,
allow_low_quality_matches=False)
boxes, _, box_pid_labels_2nd, box_reg_targets_2nd = self.select_training_samples(boxes, targets)
else:
orig_thresh = self.nms_thresh # 0.4
self.nms_thresh = self.nms_thresh_1st
boxes, scores, _ = self.postprocess_proposals(
box_cls_scores_1st, box_regs_1st, boxes, image_shapes
)
if not self.training and query_img_as_gallery:
# When regarding the query image as gallery, GT boxes may be excluded
# from detected boxes. To avoid this, we compulsorily include GT in the
# detection results. Additionally, CWS should be disabled as the
# confidences of these people in query image are 1
cws = False
gt_box = [targets[0]["boxes"]]
gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
gt_box_features = self.box_head_2nd(gt_box_features)
embeddings, _ = self.embedding_head_2nd(gt_box_features)
gt_det_2nd = {"boxes": targets[0]["boxes"], "embeddings": embeddings}
# no detection predicted by Faster R-CNN head in test phase
if boxes[0].shape[0] == 0:
assert not self.training
boxes = gt_det_2nd["boxes"] if gt_det_2nd else torch.zeros(0, 4)
labels = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0)
scores = torch.ones(1).type_as(boxes) if gt_det_2nd else torch.zeros(0)
if self.eval_feat == 'concat':
embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_2nd["embeddings"]), dim=1) if gt_det_2nd else torch.zeros(0, 512)
elif self.eval_feat == 'stage2' or self.eval_feat == 'stage3':
embeddings = gt_det_2nd["embeddings"] if gt_det_2nd else torch.zeros(0, 256)
else:
raise Exception("Unknown evaluation feature name")
return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []
# --------------------- The second stage -------------------- #
box_features = self.box_roi_pool(features, boxes, image_shapes)
box_features = self.box_head_2nd(box_features)
box_regs_2nd = self.box_predictor_2nd(box_features["after_trans"])
box_embeddings_2nd, box_cls_scores_2nd = self.embedding_head_2nd(box_features)
if box_cls_scores_2nd.dim() == 0:
box_cls_scores_2nd = box_cls_scores_2nd.unsqueeze(0)
if self.training:
boxes = self.get_boxes(box_regs_2nd, boxes, image_shapes)
boxes = [boxes_per_image.detach() for boxes_per_image in boxes]
if self.use_diff_thresh:
self.proposal_matcher = det_utils.Matcher(
self.fg_iou_thresh_3rd,
self.bg_iou_thresh_3rd,
allow_low_quality_matches=False)
boxes, _, box_pid_labels_3rd, box_reg_targets_3rd = self.select_training_samples(boxes, targets)
else:
self.nms_thresh = self.nms_thresh_2nd
if self.eval_feat != 'stage2':
boxes, scores, _, _ = self.postprocess_boxes(
box_cls_scores_2nd,
box_regs_2nd,
box_embeddings_2nd,
boxes,
image_shapes,
fcs=scores,
gt_det=None,
cws=cws,
)
if not self.training and query_img_as_gallery and self.eval_feat != 'stage2':
cws = False
gt_box = [targets[0]["boxes"]]
gt_box_features = self.box_roi_pool(features, gt_box, image_shapes)
gt_box_features = self.box_head_3rd(gt_box_features)
embeddings, _ = self.embedding_head_3rd(gt_box_features)
gt_det_3rd = {"boxes": targets[0]["boxes"], "embeddings": embeddings}
# no detection predicted by Faster R-CNN head in test phase
if boxes[0].shape[0] == 0 and self.eval_feat != 'stage2':
assert not self.training
boxes = gt_det_3rd["boxes"] if gt_det_3rd else torch.zeros(0, 4)
labels = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0)
scores = torch.ones(1).type_as(boxes) if gt_det_3rd else torch.zeros(0)
if self.eval_feat == 'concat':
embeddings = torch.cat((gt_det_2nd["embeddings"], gt_det_3rd["embeddings"]), dim=1) if gt_det_3rd else torch.zeros(0, 512)
elif self.eval_feat == 'stage3':
embeddings = gt_det_2nd["embeddings"] if gt_det_3rd else torch.zeros(0, 256)
else:
raise Exception("Unknown evaluation feature name")
return [dict(boxes=boxes, labels=labels, scores=scores, embeddings=embeddings)], []
# --------------------- The third stage -------------------- #
box_features = self.box_roi_pool(features, boxes, image_shapes)
if not self.training:
box_features_2nd = self.box_head_2nd(box_features)
box_embeddings_2nd, _ = self.embedding_head_2nd(box_features_2nd)
box_features = self.box_head_3rd(box_features)
box_regs_3rd = self.box_predictor_3rd(box_features["after_trans"])
box_embeddings_3rd, box_cls_scores_3rd = self.embedding_head_3rd(box_features)
if box_cls_scores_3rd.dim() == 0:
box_cls_scores_3rd = box_cls_scores_3rd.unsqueeze(0)
result, losses = [], {}
if self.training:
box_labels_1st = [y.clamp(0, 1) for y in box_pid_labels_1st]
box_labels_2nd = [y.clamp(0, 1) for y in box_pid_labels_2nd]
box_labels_3rd = [y.clamp(0, 1) for y in box_pid_labels_3rd]
losses = detection_losses(
box_cls_scores_1st,
box_regs_1st,
box_labels_1st,
box_reg_targets_1st,
box_cls_scores_2nd,
box_regs_2nd,
box_labels_2nd,
box_reg_targets_2nd,
box_cls_scores_3rd,
box_regs_3rd,
box_labels_3rd,
box_reg_targets_3rd,
)
loss_rcnn_reid_2nd, feats_reid_2nd, targets_reid_2nd = self.reid_loss_2nd(box_embeddings_2nd, box_pid_labels_2nd)
loss_rcnn_reid_3rd, feats_reid_3rd, targets_reid_3rd = self.reid_loss_3rd(box_embeddings_3rd, box_pid_labels_3rd)
losses.update(loss_rcnn_reid_2nd=loss_rcnn_reid_2nd)
losses.update(loss_rcnn_reid_3rd=loss_rcnn_reid_3rd)
else:
if self.eval_feat == 'stage2':
boxes, scores, embeddings_2nd, labels = self.postprocess_boxes(
box_cls_scores_2nd,
box_regs_2nd,
box_embeddings_2nd,
boxes,
image_shapes,
fcs=scores,
gt_det=gt_det_2nd,
cws=cws,
)
else:
self.nms_thresh = self.nms_thresh_3rd
_, _, embeddings_2nd, _ = self.postprocess_boxes(
box_cls_scores_3rd,
box_regs_3rd,
box_embeddings_2nd,
boxes,
image_shapes,
fcs=scores,
gt_det=gt_det_2nd,
cws=cws,
)
boxes, scores, embeddings_3rd, labels = self.postprocess_boxes(
box_cls_scores_3rd,
box_regs_3rd,
box_embeddings_3rd,
boxes,
image_shapes,
fcs=scores,
gt_det=gt_det_3rd,
cws=cws,
)
# set to original thresh after finishing postprocess
self.nms_thresh = orig_thresh
num_images = len(boxes)
for i in range(num_images):
if self.eval_feat == 'concat':
embeddings = torch.cat((embeddings_2nd[i],embeddings_3rd[i]), dim=1)
elif self.eval_feat == 'stage2':
embeddings = embeddings_2nd[i]
elif self.eval_feat == 'stage3':
embeddings = embeddings_3rd[i]
else:
raise Exception("Unknown evaluation feature name")
result.append(
dict(
boxes=boxes[i],
labels=labels[i],
scores=scores[i],
embeddings=embeddings
)
)
return result, losses, feats_reid_2nd, targets_reid_2nd, feats_reid_3rd, targets_reid_3rd
def get_boxes(self, box_regression, proposals, image_shapes):
"""
Get boxes from proposals.
"""
boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
pred_boxes = self.box_coder.decode(box_regression, proposals)
pred_boxes = pred_boxes.split(boxes_per_image, 0)
all_boxes = []
for boxes, image_shape in zip(pred_boxes, image_shapes):
boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
# remove predictions with the background label
boxes = boxes[:, 1:].reshape(-1, 4)
all_boxes.append(boxes)
return all_boxes
def postprocess_boxes(
self,
class_logits,
box_regression,
embeddings,
proposals,
image_shapes,
fcs=None,
gt_det=None,
cws=True,
):
"""
Similar to RoIHeads.postprocess_detections, but can handle embeddings and implement
First Classification Score (FCS).
"""
device = class_logits.device
boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
pred_boxes = self.box_coder.decode(box_regression, proposals)
if fcs is not None:
# Fist Classification Score (FCS)
pred_scores = fcs[0]
else:
pred_scores = torch.sigmoid(class_logits)
if cws:
# Confidence Weighted Similarity (CWS)
embeddings = embeddings * pred_scores.view(-1, 1)
# split boxes and scores per image
pred_boxes = pred_boxes.split(boxes_per_image, 0)
pred_scores = pred_scores.split(boxes_per_image, 0)
pred_embeddings = embeddings.split(boxes_per_image, 0)
all_boxes = []
all_scores = []
all_labels = []
all_embeddings = []
for boxes, scores, embeddings, image_shape in zip(
pred_boxes, pred_scores, pred_embeddings, image_shapes
):
boxes = box_ops.clip_boxes_to_image(boxes, image_shape)
# create labels for each prediction
labels = torch.ones(scores.size(0), device=device)
# remove predictions with the background label
boxes = boxes[:, 1:]
scores = scores.unsqueeze(1)
labels = labels.unsqueeze(1)
# batch everything, by making every class prediction be a separate instance
boxes = boxes.reshape(-1, 4)
scores = scores.flatten()
labels = labels.flatten()
embeddings = embeddings.reshape(-1, self.embedding_head_2nd.dim)
# remove low scoring boxes
inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
boxes, scores, labels, embeddings = (
boxes[inds],
scores[inds],
labels[inds],
embeddings[inds],
)
# remove empty boxes
keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
boxes, scores, labels, embeddings = (
boxes[keep],
scores[keep],
labels[keep],
embeddings[keep],
)
if gt_det is not None:
# include GT into the detection results
boxes = torch.cat((boxes, gt_det["boxes"]), dim=0)
labels = torch.cat((labels, torch.tensor([1.0]).to(device)), dim=0)
scores = torch.cat((scores, torch.tensor([1.0]).to(device)), dim=0)
embeddings = torch.cat((embeddings, gt_det["embeddings"]), dim=0)
# non-maximum suppression, independently done per class
keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
# keep only topk scoring predictions
keep = keep[: self.detections_per_img]
boxes, scores, labels, embeddings = (
boxes[keep],
scores[keep],
labels[keep],
embeddings[keep],
)
all_boxes.append(boxes)
all_scores.append(scores)
all_labels.append(labels)
all_embeddings.append(embeddings)
return all_boxes, all_scores, all_embeddings, all_labels
class NormAwareEmbedding(nn.Module):
"""
Implements the Norm-Aware Embedding proposed in
Chen, Di, et al. "Norm-aware embedding for efficient person search." CVPR 2020.
"""
def __init__(self, featmap_names=["feat_res4", "feat_res5"], in_channels=[1024, 2048], dim=256):
super(NormAwareEmbedding, self).__init__()
self.featmap_names = featmap_names
self.in_channels = in_channels
self.dim = dim
self.projectors = nn.ModuleDict()
indv_dims = self._split_embedding_dim()
for ftname, in_channel, indv_dim in zip(self.featmap_names, self.in_channels, indv_dims):
proj = nn.Sequential(nn.Linear(in_channel, indv_dim), nn.BatchNorm1d(indv_dim))
init.normal_(proj[0].weight, std=0.01)
init.normal_(proj[1].weight, std=0.01)
init.constant_(proj[0].bias, 0)
init.constant_(proj[1].bias, 0)
self.projectors[ftname] = proj
self.rescaler = nn.BatchNorm1d(1, affine=True)
def forward(self, featmaps):
"""
Arguments:
featmaps: OrderedDict[Tensor], and in featmap_names you can choose which
featmaps to use
Returns:
tensor of size (BatchSize, dim), L2 normalized embeddings.
tensor of size (BatchSize, ) rescaled norm of embeddings, as class_logits.
"""
assert len(featmaps) == len(self.featmap_names)
if len(featmaps) == 1:
k, v = featmaps.items()[0]
v = self._flatten_fc_input(v)
embeddings = self.projectors[k](v)
norms = embeddings.norm(2, 1, keepdim=True)
embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
norms = self.rescaler(norms).squeeze()
return embeddings, norms
else:
outputs = []
for k, v in featmaps.items():
v = self._flatten_fc_input(v)
outputs.append(self.projectors[k](v))
embeddings = torch.cat(outputs, dim=1)
norms = embeddings.norm(2, 1, keepdim=True)
embeddings = embeddings / norms.expand_as(embeddings).clamp(min=1e-12)
norms = self.rescaler(norms).squeeze()
return embeddings, norms
def _flatten_fc_input(self, x):
if x.ndimension() == 4:
assert list(x.shape[2:]) == [1, 1]
return x.flatten(start_dim=1)
return x
def _split_embedding_dim(self):
parts = len(self.in_channels)
tmp = [self.dim // parts] * parts
if sum(tmp) == self.dim:
return tmp
else:
res = self.dim % parts
for i in range(1, res + 1):
tmp[-i] += 1
assert sum(tmp) == self.dim
return tmp
class BBoxRegressor(nn.Module):
"""
Bounding box regression layer.
"""
def __init__(self, in_channels, num_classes=2, bn_neck=True):
"""
Args:
in_channels (int): Input channels.
num_classes (int, optional): Defaults to 2 (background and pedestrian).
bn_neck (bool, optional): Whether to use BN after Linear. Defaults to True.
"""
super(BBoxRegressor, self).__init__()
if bn_neck:
self.bbox_pred = nn.Sequential(
nn.Linear(in_channels, 4 * num_classes), nn.BatchNorm1d(4 * num_classes)
)
init.normal_(self.bbox_pred[0].weight, std=0.01)
init.normal_(self.bbox_pred[1].weight, std=0.01)
init.constant_(self.bbox_pred[0].bias, 0)
init.constant_(self.bbox_pred[1].bias, 0)
else:
self.bbox_pred = nn.Linear(in_channels, 4 * num_classes)
init.normal_(self.bbox_pred.weight, std=0.01)
init.constant_(self.bbox_pred.bias, 0)
def forward(self, x):
if x.ndimension() == 4:
if list(x.shape[2:]) != [1, 1]:
x = F.adaptive_avg_pool2d(x, output_size=1)
x = x.flatten(start_dim=1)
bbox_deltas = self.bbox_pred(x)
return bbox_deltas
def detection_losses(
box_cls_scores_1st,
box_regs_1st,
box_labels_1st,
box_reg_targets_1st,
box_cls_scores_2nd,
box_regs_2nd,
box_labels_2nd,
box_reg_targets_2nd,
box_cls_scores_3rd,
box_regs_3rd,
box_labels_3rd,
box_reg_targets_3rd,
):
# --------------------- The first stage -------------------- #
box_labels_1st = torch.cat(box_labels_1st, dim=0)
box_reg_targets_1st = torch.cat(box_reg_targets_1st, dim=0)
loss_rcnn_cls_1st = F.cross_entropy(box_cls_scores_1st, box_labels_1st)
# get indices that correspond to the regression targets for the
# corresponding ground truth labels, to be used with advanced indexing
sampled_pos_inds_subset = torch.nonzero(box_labels_1st > 0).squeeze(1)
labels_pos = box_labels_1st[sampled_pos_inds_subset]
N = box_cls_scores_1st.size(0)
box_regs_1st = box_regs_1st.reshape(N, -1, 4)
loss_rcnn_reg_1st = F.smooth_l1_loss(
box_regs_1st[sampled_pos_inds_subset, labels_pos],
box_reg_targets_1st[sampled_pos_inds_subset],
reduction="sum",
)
loss_rcnn_reg_1st = loss_rcnn_reg_1st / box_labels_1st.numel()
# --------------------- The second stage -------------------- #
box_labels_2nd = torch.cat(box_labels_2nd, dim=0)
box_reg_targets_2nd = torch.cat(box_reg_targets_2nd, dim=0)
loss_rcnn_cls_2nd = F.binary_cross_entropy_with_logits(box_cls_scores_2nd, box_labels_2nd.float())
sampled_pos_inds_subset = torch.nonzero(box_labels_2nd > 0).squeeze(1)
labels_pos = box_labels_2nd[sampled_pos_inds_subset]
N = box_cls_scores_2nd.size(0)
box_regs_2nd = box_regs_2nd.reshape(N, -1, 4)
loss_rcnn_reg_2nd = F.smooth_l1_loss(
box_regs_2nd[sampled_pos_inds_subset, labels_pos],
box_reg_targets_2nd[sampled_pos_inds_subset],
reduction="sum",
)
loss_rcnn_reg_2nd = loss_rcnn_reg_2nd / box_labels_2nd.numel()
# --------------------- The third stage -------------------- #
box_labels_3rd = torch.cat(box_labels_3rd, dim=0)
box_reg_targets_3rd = torch.cat(box_reg_targets_3rd, dim=0)
loss_rcnn_cls_3rd = F.binary_cross_entropy_with_logits(box_cls_scores_3rd, box_labels_3rd.float())
sampled_pos_inds_subset = torch.nonzero(box_labels_3rd > 0).squeeze(1)
labels_pos = box_labels_3rd[sampled_pos_inds_subset]
N = box_cls_scores_3rd.size(0)
box_regs_3rd = box_regs_3rd.reshape(N, -1, 4)
loss_rcnn_reg_3rd = F.smooth_l1_loss(
box_regs_3rd[sampled_pos_inds_subset, labels_pos],
box_reg_targets_3rd[sampled_pos_inds_subset],
reduction="sum",
)
loss_rcnn_reg_3rd = loss_rcnn_reg_3rd / box_labels_3rd.numel()
return dict(
loss_rcnn_cls_1st=loss_rcnn_cls_1st,
loss_rcnn_reg_1st=loss_rcnn_reg_1st,
loss_rcnn_cls_2nd=loss_rcnn_cls_2nd,
loss_rcnn_reg_2nd=loss_rcnn_reg_2nd,
loss_rcnn_cls_3rd=loss_rcnn_cls_3rd,
loss_rcnn_reg_3rd=loss_rcnn_reg_3rd,
)

View File

@ -0,0 +1,53 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
from collections import OrderedDict
import torch.nn.functional as F
import torchvision
from torch import nn
class Backbone(nn.Sequential):
def __init__(self, resnet):
super(Backbone, self).__init__(
OrderedDict(
[
["conv1", resnet.conv1],
["bn1", resnet.bn1],
["relu", resnet.relu],
["maxpool", resnet.maxpool],
["layer1", resnet.layer1], # res2
["layer2", resnet.layer2], # res3
["layer3", resnet.layer3], # res4
]
)
)
self.out_channels = 1024
def forward(self, x):
# using the forward method from nn.Sequential
feat = super(Backbone, self).forward(x)
return OrderedDict([["feat_res4", feat]])
class Res5Head(nn.Sequential):
def __init__(self, resnet):
super(Res5Head, self).__init__(OrderedDict([["layer4", resnet.layer4]])) # res5
self.out_channels = [1024, 2048]
def forward(self, x):
feat = super(Res5Head, self).forward(x)
x = F.adaptive_max_pool2d(x, 1)
feat = F.adaptive_max_pool2d(feat, 1)
return OrderedDict([["feat_res4", x], ["feat_res5", feat]])
def build_resnet(name="resnet50", pretrained=True):
resnet = torchvision.models.resnet.__dict__[name](pretrained=pretrained)
# freeze layers
resnet.conv1.weight.requires_grad_(False)
resnet.bn1.weight.requires_grad_(False)
resnet.bn1.bias.requires_grad_(False)
return Backbone(resnet), Res5Head(resnet)

View File

@ -0,0 +1,300 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import math
import random
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.mask import exchange_token, exchange_patch, get_mask_box, jigsaw_token, cutout_patch, erase_patch, mixup_patch, jigsaw_patch
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class TransformerHead(nn.Module):
def __init__(
self,
cfg,
trans_names,
kernel_size,
use_feature_mask,
):
super(TransformerHead, self).__init__()
d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL
# Mask parameters
self.use_feature_mask = use_feature_mask
mask_shape = cfg.MODEL.MASK_SHAPE
mask_size = cfg.MODEL.MASK_SIZE
mask_mode = cfg.MODEL.MASK_MODE
self.bypass_mask = exchange_patch(mask_shape, mask_size, mask_mode)
self.get_mask_box = get_mask_box(mask_shape, mask_size, mask_mode)
self.transformer_encoder = Transformers(
cfg=cfg,
trans_names=trans_names,
kernel_size=kernel_size,
use_feature_mask=use_feature_mask,
)
self.conv0 = conv1x1(1024, 1024)
self.conv1 = conv1x1(1024, d_model)
self.conv2 = conv1x1(d_model, 2048)
def forward(self, box_features):
mask_box = self.get_mask_box(box_features)
if self.use_feature_mask:
skip_features = self.conv0(box_features)
if self.training:
skip_features = self.bypass_mask(skip_features)
else:
skip_features = box_features
trans_features = {}
trans_features["before_trans"] = F.adaptive_max_pool2d(skip_features, 1)
box_features = self.conv1(box_features)
box_features = self.transformer_encoder((box_features,mask_box))
box_features = self.conv2(box_features)
trans_features["after_trans"] = F.adaptive_max_pool2d(box_features, 1)
return trans_features
class Transformers(nn.Module):
def __init__(
self,
cfg,
trans_names,
kernel_size,
use_feature_mask,
):
super(Transformers, self).__init__()
d_model = cfg.MODEL.TRANSFORMER.DIM_MODEL
self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE
self.use_feature_mask = use_feature_mask
# If no conv before transformer, we do not use scales
if not cfg.MODEL.TRANSFORMER.USE_PATCH2VEC:
trans_names = ['scale1']
kernel_size = [(1,1)]
self.trans_names = trans_names
self.scale_size = len(self.trans_names)
hidden = d_model//(2*self.scale_size)
# kernel_size: (padding, stride)
kernels = {
(1,1): [(0,0),(1,1)],
(3,3): [(1,1),(1,1)]
}
padding = []
stride = []
for ksize in kernel_size:
if ksize not in [(1,1),(3,3)]:
raise ValueError('Undefined kernel size.')
padding.append(kernels[ksize][0])
stride.append(kernels[ksize][1])
self.use_output_layer = cfg.MODEL.TRANSFORMER.USE_OUTPUT_LAYER
self.use_global_shortcut = cfg.MODEL.TRANSFORMER.USE_GLOBAL_SHORTCUT
self.blocks = nn.ModuleDict()
for tname, ksize, psize, ssize in zip(self.trans_names, kernel_size, padding, stride):
transblock = Transformer(
cfg, d_model//self.scale_size, ksize, psize, ssize, hidden, use_feature_mask
)
self.blocks[tname] = nn.Sequential(transblock)
self.output_linear = nn.Sequential(
nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, inplace=True)
)
self.mask_para = [cfg.MODEL.MASK_SHAPE, cfg.MODEL.MASK_SIZE, cfg.MODEL.MASK_MODE]
def forward(self, inputs):
trans_feat = []
enc_feat, mask_box = inputs
if self.training and self.use_feature_mask and self.feature_aug_type == 'exchange_patch':
feature_mask = exchange_patch(self.mask_para[0], self.mask_para[1], self.mask_para[2])
enc_feat = feature_mask(enc_feat)
for tname, feat in zip(self.trans_names, torch.chunk(enc_feat, len(self.trans_names), dim=1)):
feat = self.blocks[tname]((feat, mask_box))
trans_feat.append(feat)
trans_feat = torch.cat(trans_feat, 1)
if self.use_output_layer:
trans_feat = self.output_linear(trans_feat)
if self.use_global_shortcut:
trans_feat = enc_feat + trans_feat
return trans_feat
class Transformer(nn.Module):
def __init__(self, cfg, channel, kernel_size, padding, stride, hidden, use_feature_mask
):
super(Transformer, self).__init__()
self.k = kernel_size[0]
stack_num = cfg.MODEL.TRANSFORMER.ENCODER_LAYERS
num_head = cfg.MODEL.TRANSFORMER.N_HEAD
dropout = cfg.MODEL.TRANSFORMER.DROPOUT
output_size = (14,14)
token_size = tuple(map(lambda x,y:x//y, output_size, stride))
blocks = []
self.transblock = TransformerBlock(token_size, hidden=hidden, num_head=num_head, dropout=dropout)
for _ in range(stack_num):
blocks.append(self.transblock)
self.transformer = nn.Sequential(*blocks)
self.patch2vec = nn.Conv2d(channel, hidden, kernel_size=kernel_size, stride=stride, padding=padding)
self.vec2patch = Vec2Patch(channel, hidden, output_size, kernel_size, stride, padding)
self.use_local_shortcut = cfg.MODEL.TRANSFORMER.USE_LOCAL_SHORTCUT
self.use_feature_mask = use_feature_mask
self.feature_aug_type = cfg.MODEL.FEATURE_AUG_TYPE
self.use_patch2vec = cfg.MODEL.TRANSFORMER.USE_PATCH2VEC
def forward(self, inputs):
enc_feat, mask_box = inputs
b, c, h, w = enc_feat.size()
trans_feat = self.patch2vec(enc_feat)
_, c, h, w = trans_feat.size()
trans_feat = trans_feat.view(b, c, -1).permute(0, 2, 1)
# For 1x1 & 3x3 kernels, exchange tokens
if self.training and self.use_feature_mask:
if self.feature_aug_type == 'exchange_token':
feature_mask = exchange_token()
trans_feat = feature_mask(trans_feat, mask_box)
elif self.feature_aug_type == 'cutout_patch':
feature_mask = cutout_patch()
trans_feat = feature_mask(trans_feat)
elif self.feature_aug_type == 'erase_patch':
feature_mask = erase_patch()
trans_feat = feature_mask(trans_feat)
elif self.feature_aug_type == 'mixup_patch':
feature_mask = mixup_patch()
trans_feat = feature_mask(trans_feat)
if self.use_feature_mask:
if self.feature_aug_type == 'jigsaw_patch':
feature_mask = jigsaw_patch()
trans_feat = feature_mask(trans_feat)
elif self.feature_aug_type == 'jigsaw_token':
feature_mask = jigsaw_token()
trans_feat = feature_mask(trans_feat)
trans_feat = self.transformer(trans_feat)
trans_feat = self.vec2patch(trans_feat)
if self.use_local_shortcut:
trans_feat = enc_feat + trans_feat
return trans_feat
class TransformerBlock(nn.Module):
"""
Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
"""
def __init__(self, tokensize, hidden=128, num_head=4, dropout=0.1):
super().__init__()
self.attention = MultiHeadedAttention(tokensize, d_model=hidden, head=num_head, p=dropout)
self.ffn = FeedForward(hidden, p=dropout)
self.norm1 = nn.LayerNorm(hidden)
self.norm2 = nn.LayerNorm(hidden)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x):
x = self.norm1(x)
x = x + self.dropout(self.attention(x))
y = self.norm2(x)
x = x + self.ffn(y)
return x
class Attention(nn.Module):
"""
Compute 'Scaled Dot Product Attention
"""
def __init__(self, p=0.1):
super(Attention, self).__init__()
self.dropout = nn.Dropout(p=p)
def forward(self, query, key, value):
scores = torch.matmul(query, key.transpose(-2, -1)
) / math.sqrt(query.size(-1))
p_attn = F.softmax(scores, dim=-1)
p_attn = self.dropout(p_attn)
p_val = torch.matmul(p_attn, value)
return p_val, p_attn
class Vec2Patch(nn.Module):
def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
super(Vec2Patch, self).__init__()
self.relu = nn.LeakyReLU(0.2, inplace=True)
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
self.embedding = nn.Linear(hidden, c_out)
self.to_patch = torch.nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
h, w = output_size
def forward(self, x):
feat = self.embedding(x)
b, n, c = feat.size()
feat = feat.permute(0, 2, 1)
feat = self.to_patch(feat)
return feat
class MultiHeadedAttention(nn.Module):
"""
Take in model size and number of heads.
"""
def __init__(self, tokensize, d_model, head, p=0.1):
super().__init__()
self.query_embedding = nn.Linear(d_model, d_model)
self.value_embedding = nn.Linear(d_model, d_model)
self.key_embedding = nn.Linear(d_model, d_model)
self.output_linear = nn.Linear(d_model, d_model)
self.attention = Attention(p=p)
self.head = head
self.h, self.w = tokensize
def forward(self, x):
b, n, c = x.size()
c_h = c // self.head
key = self.key_embedding(x)
query = self.query_embedding(x)
value = self.value_embedding(x)
key = key.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
query = query.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
value = value.view(b, n, self.head, c_h).permute(0, 2, 1, 3)
att, _ = self.attention(query, key, value)
att = att.permute(0, 2, 1, 3).contiguous().view(b, n, c)
output = self.output_linear(att)
return output
class FeedForward(nn.Module):
def __init__(self, d_model, p=0.1):
super(FeedForward, self).__init__()
self.conv = nn.Sequential(
nn.Linear(d_model, d_model * 4),
nn.ReLU(inplace=True),
nn.Dropout(p=p),
nn.Linear(d_model * 4, d_model),
nn.Dropout(p=p))
def forward(self, x):
x = self.conv(x)
return x

154
coat-pvtv2-b2/train.py Normal file
View File

@ -0,0 +1,154 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import argparse
import datetime
import os.path as osp
import time
import torch
import torch.utils.data
from datasets import build_test_loader, build_train_loader
from defaults import get_default_cfg
from engine import evaluate_performance, train_one_epoch
from models.coat import COAT
from utils.utils import mkdir, resume_from_ckpt, save_on_master, set_random_seed
from loss.softmax_loss import SoftmaxLoss
def main(args):
cfg = get_default_cfg()
if args.cfg_file:
cfg.merge_from_file(args.cfg_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
device = torch.device(cfg.DEVICE)
if cfg.SEED >= 0:
set_random_seed(cfg.SEED)
print("Creating model...")
model = COAT(cfg)
model.to(device)
print("Loading data...")
train_loader = build_train_loader(cfg)
gallery_loader, query_loader = build_test_loader(cfg)
softmax_criterion_s2 = None
softmax_criterion_s3 = None
if cfg.MODEL.LOSS.USE_SOFTMAX:
softmax_criterion_s2 = SoftmaxLoss(cfg)
softmax_criterion_s3 = SoftmaxLoss(cfg)
softmax_criterion_s2.to(device)
softmax_criterion_s3.to(device)
if args.eval:
assert args.ckpt, "--ckpt must be specified when --eval enabled"
resume_from_ckpt(args.ckpt, model)
evaluate_performance(
model,
gallery_loader,
query_loader,
device,
use_gt=cfg.EVAL_USE_GT,
use_cache=cfg.EVAL_USE_CACHE,
use_cbgm=cfg.EVAL_USE_CBGM,
gallery_size=cfg.EVAL_GALLERY_SIZE,
)
exit(0)
params = [p for p in model.parameters() if p.requires_grad]
if cfg.MODEL.LOSS.USE_SOFTMAX:
params_softmax_s2 = [p for p in softmax_criterion_s2.parameters() if p.requires_grad]
params_softmax_s3 = [p for p in softmax_criterion_s3.parameters() if p.requires_grad]
params.extend(params_softmax_s2)
params.extend(params_softmax_s3)
optimizer = torch.optim.SGD(
params,
lr=cfg.SOLVER.BASE_LR,
momentum=cfg.SOLVER.SGD_MOMENTUM,
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=cfg.SOLVER.LR_DECAY_MILESTONES, gamma=cfg.SOLVER.GAMMA
)
start_epoch = 0
if args.resume:
assert args.ckpt, "--ckpt must be specified when --resume enabled"
start_epoch = resume_from_ckpt(args.ckpt, model, optimizer, lr_scheduler) + 1
print("Creating output folder...")
output_dir = cfg.OUTPUT_DIR
mkdir(output_dir)
path = osp.join(output_dir, "config.yaml")
with open(path, "w") as f:
f.write(cfg.dump())
print(f"Full config is saved to {path}")
tfboard = None
if cfg.TF_BOARD:
from torch.utils.tensorboard import SummaryWriter
tf_log_path = osp.join(output_dir, "tf_log")
mkdir(tf_log_path)
tfboard = SummaryWriter(log_dir=tf_log_path)
print(f"TensorBoard files are saved to {tf_log_path}")
print("Start training...")
start_time = time.time()
for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS):
train_one_epoch(cfg, model, optimizer, train_loader, device, epoch, tfboard, softmax_criterion_s2, softmax_criterion_s3)
lr_scheduler.step()
# only save the last three checkpoints
if epoch >= cfg.SOLVER.MAX_EPOCHS - 3:
save_on_master(
{
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"lr_scheduler": lr_scheduler.state_dict(),
"epoch": epoch,
},
osp.join(output_dir, f"epoch_{epoch}.pth"),
)
# evaluate the current checkpoint
evaluate_performance(
model,
gallery_loader,
query_loader,
device,
use_gt=cfg.EVAL_USE_GT,
use_cache=cfg.EVAL_USE_CACHE,
use_cbgm=cfg.EVAL_USE_CBGM,
gallery_size=cfg.EVAL_GALLERY_SIZE,
)
if tfboard:
tfboard.close()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Total training time {total_time_str}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train a person search network.")
parser.add_argument("--cfg", dest="cfg_file", help="Path to configuration file.")
parser.add_argument(
"--eval", action="store_true", help="Evaluate the performance of a given checkpoint."
)
parser.add_argument(
"--resume", action="store_true", help="Resume from the specified checkpoint."
)
parser.add_argument("--ckpt", help="Path to checkpoint to resume or evaluate.")
parser.add_argument(
"opts", nargs=argparse.REMAINDER, help="Modify config options using the command-line"
)
args = parser.parse_args()
main(args)

150
coat-pvtv2-b2/utils/km.py Normal file
View File

@ -0,0 +1,150 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import random
import numpy as np
zero_threshold = 0.00000001
class KMNode(object):
def __init__(self, id, exception=0, match=None, visit=False):
self.id = id
self.exception = exception
self.match = match
self.visit = visit
class KuhnMunkres(object):
def __init__(self):
self.matrix = None
self.x_nodes = []
self.y_nodes = []
self.minz = float("inf")
self.x_length = 0
self.y_length = 0
self.index_x = 0
self.index_y = 1
def __del__(self):
pass
def set_matrix(self, x_y_values):
xs = set()
ys = set()
for x, y, value in x_y_values:
xs.add(x)
ys.add(y)
if len(xs) < len(ys):
self.index_x = 0
self.index_y = 1
else:
self.index_x = 1
self.index_y = 0
xs, ys = ys, xs
x_dic = {x: i for i, x in enumerate(xs)}
y_dic = {y: j for j, y in enumerate(ys)}
self.x_nodes = [KMNode(x) for x in xs]
self.y_nodes = [KMNode(y) for y in ys]
self.x_length = len(xs)
self.y_length = len(ys)
self.matrix = np.zeros((self.x_length, self.y_length))
for row in x_y_values:
x = row[self.index_x]
y = row[self.index_y]
value = row[2]
x_index = x_dic[x]
y_index = y_dic[y]
self.matrix[x_index, y_index] = value
for i in range(self.x_length):
self.x_nodes[i].exception = max(self.matrix[i, :])
def km(self):
for i in range(self.x_length):
while True:
self.minz = float("inf")
self.set_false(self.x_nodes)
self.set_false(self.y_nodes)
if self.dfs(i):
break
self.change_exception(self.x_nodes, -self.minz)
self.change_exception(self.y_nodes, self.minz)
def dfs(self, i):
x_node = self.x_nodes[i]
x_node.visit = True
for j in range(self.y_length):
y_node = self.y_nodes[j]
if not y_node.visit:
t = x_node.exception + y_node.exception - self.matrix[i][j]
if abs(t) < zero_threshold:
y_node.visit = True
if y_node.match is None or self.dfs(y_node.match):
x_node.match = j
y_node.match = i
return True
else:
if t >= zero_threshold:
self.minz = min(self.minz, t)
return False
def set_false(self, nodes):
for node in nodes:
node.visit = False
def change_exception(self, nodes, change):
for node in nodes:
if node.visit:
node.exception += change
def get_connect_result(self):
ret = []
for i in range(self.x_length):
x_node = self.x_nodes[i]
j = x_node.match
y_node = self.y_nodes[j]
x_id = x_node.id
y_id = y_node.id
value = self.matrix[i][j]
if self.index_x == 1 and self.index_y == 0:
x_id, y_id = y_id, x_id
ret.append((x_id, y_id, value))
return ret
def get_max_value_result(self):
ret = -100
for i in range(self.x_length):
j = self.x_nodes[i].match
ret = max(ret, self.matrix[i][j])
return ret
def run_kuhn_munkres(x_y_values):
process = KuhnMunkres()
process.set_matrix(x_y_values)
process.km()
return process.get_connect_result(), process.get_max_value_result()
def test():
values = []
random.seed(0)
for i in range(500):
for j in range(1000):
value = random.random()
values.append((i, j, value))
return run_kuhn_munkres(values)
if __name__ == "__main__":
values = [(1, 1, 3), (1, 3, 4), (2, 1, 2), (2, 2, 1), (2, 3, 3), (3, 2, 4), (3, 3, 5)]
print(run_kuhn_munkres(values))

325
coat-pvtv2-b2/utils/mask.py Normal file
View File

@ -0,0 +1,325 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import random
import torch
class exchange_token:
def __init__(self):
pass
def __call__(self, features, mask_box):
b, hw, c = features.size()
assert hw == 14*14
new_idx, mask_x1, mask_x2, mask_y1, mask_y2 = mask_box
features = features.view(b, 14, 14, c)
features[:, mask_x1 : mask_x2, mask_y1 : mask_y2, :] = features[new_idx, mask_x1 : mask_x2, mask_y1 : mask_y2, :]
features = features.view(b, hw, c)
return features
class jigsaw_token:
def __init__(self, shift=5, group=2, begin=1):
self.shift = shift
self.group = group
self.begin = begin
def __call__(self, features):
batchsize = features.size(0)
dim = features.size(2)
num_tokens = features.size(1)
if num_tokens == 196:
self.group = 2
elif num_tokens == 25:
self.group = 5
else:
raise Exception("Jigsaw - Unwanted number of tokens")
# Shift Operation
feature_random = torch.cat([features[:, self.begin-1+self.shift:, :], features[:, self.begin-1:self.begin-1+self.shift, :]], dim=1)
x = feature_random
# Patch Shuffle Operation
try:
x = x.view(batchsize, self.group, -1, dim)
except:
raise Exception("Jigsaw - Unwanted number of groups")
x = torch.transpose(x, 1, 2).contiguous()
x = x.view(batchsize, -1, dim)
return x
class get_mask_box:
def __init__(self, shape='stripe', mask_size=2, mode='random_direct'):
self.shape = shape
self.mask_size = mask_size
self.mode = mode
def __call__(self, features):
# Stripe mask
if self.shape == 'stripe':
if self.mode == 'horizontal':
mask_box = self.hstripe(features, self.mask_size)
elif self.mode == 'vertical':
mask_box = self.vstripe(features, self.mask_size)
elif self.mode == 'random_direction':
if random.random() < 0.5:
mask_box = self.hstripe(features, self.mask_size)
else:
mask_box = self.vstripe(features, self.mask_size)
else:
raise Exception("Unknown stripe mask mode name")
# Square mask
elif self.shape == 'square':
if self.mode == 'random_size':
self.mask_size = 4 if random.random() < 0.5 else 5
mask_box = self.square(features, self.mask_size)
# Random stripe/square mask
elif self.shape == 'random':
random_num = random.random()
if random_num < 0.25:
mask_box = self.hstripe(features, 2)
elif random_num < 0.5 and random_num >= 0.25:
mask_box = self.vstripe(features, 2)
elif random_num < 0.75 and random_num >= 0.5:
mask_box = self.square(features, 4)
else:
mask_box = self.square(features, 5)
else:
raise Exception("Unknown mask shape name")
return mask_box
def hstripe(self, features, mask_size):
"""
"""
# horizontal stripe
mask_x1 = 0
mask_x2 = features.shape[2]
y1_max = features.shape[3] - mask_size
mask_y1 = torch.randint(y1_max, (1,))
mask_y2 = mask_y1 + mask_size
new_idx = torch.randperm(features.shape[0])
mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)
return mask_box
def vstripe(self, features, mask_size):
"""
"""
# vertical stripe
mask_y1 = 0
mask_y2 = features.shape[3]
x1_max = features.shape[2] - mask_size
mask_x1 = torch.randint(x1_max, (1,))
mask_x2 = mask_x1 + mask_size
new_idx = torch.randperm(features.shape[0])
mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)
return mask_box
def square(self, features, mask_size):
"""
"""
# square
x1_max = features.shape[2] - mask_size
y1_max = features.shape[3] - mask_size
mask_x1 = torch.randint(x1_max, (1,))
mask_y1 = torch.randint(y1_max, (1,))
mask_x2 = mask_x1 + mask_size
mask_y2 = mask_y1 + mask_size
new_idx = torch.randperm(features.shape[0])
mask_box = (new_idx, mask_x1, mask_x2, mask_y1, mask_y2)
return mask_box
class exchange_patch:
def __init__(self, shape='stripe', mask_size=2, mode='random_direct'):
self.shape = shape
self.mask_size = mask_size
self.mode = mode
def __call__(self, features):
# Stripe mask
if self.shape == 'stripe':
if self.mode == 'horizontal':
features = self.xpatch_hstripe(features, self.mask_size)
elif self.mode == 'vertical':
features = self.xpatch_vstripe(features, self.mask_size)
elif self.mode == 'random_direction':
if random.random() < 0.5:
features = self.xpatch_hstripe(features, self.mask_size)
else:
features = self.xpatch_vstripe(features, self.mask_size)
else:
raise Exception("Unknown stripe mask mode name")
# Square mask
elif self.shape == 'square':
if self.mode == 'random_size':
self.mask_size = 4 if random.random() < 0.5 else 5
features = self.xpatch_square(features, self.mask_size)
# Random stripe/square mask
elif self.shape == 'random':
random_num = random.random()
if random_num < 0.25:
features = self.xpatch_hstripe(features, 2)
elif random_num < 0.5 and random_num >= 0.25:
features = self.xpatch_vstripe(features, 2)
elif random_num < 0.75 and random_num >= 0.5:
features = self.xpatch_square(features, 4)
else:
features = self.xpatch_square(features, 5)
else:
raise Exception("Unknown mask shape name")
return features
def xpatch_hstripe(self, features, mask_size):
"""
"""
# horizontal stripe
y1_max = features.shape[3] - mask_size
num_masks = 1
for i in range(num_masks):
mask_y1 = torch.randint(y1_max, (1,))
mask_y2 = mask_y1 + mask_size
new_idx = torch.randperm(features.shape[0])
features[:, :, :, mask_y1 : mask_y2] = features[new_idx, :, :, mask_y1 : mask_y2]
return features
def xpatch_vstripe(self, features, mask_size):
"""
"""
# vertical stripe
x1_max = features.shape[2] - mask_size
num_masks = 1
for i in range(num_masks):
mask_x1 = torch.randint(x1_max, (1,))
mask_x2 = mask_x1 + mask_size
new_idx = torch.randperm(features.shape[0])
features[:, :, mask_x1 : mask_x2, :] = features[new_idx, :, mask_x1 : mask_x2, :]
return features
def xpatch_square(self, features, mask_size):
"""
"""
# square
x1_max = features.shape[2] - mask_size
y1_max = features.shape[3] - mask_size
num_masks = 1
for i in range(num_masks):
mask_x1 = torch.randint(x1_max, (1,))
mask_y1 = torch.randint(y1_max, (1,))
mask_x2 = mask_x1 + mask_size
mask_y2 = mask_y1 + mask_size
new_idx = torch.randperm(features.shape[0])
features[:, :, mask_x1 : mask_x2, mask_y1 : mask_y2] = features[new_idx, :, mask_x1 : mask_x2, mask_y1 : mask_y2]
return features
class cutout_patch:
def __init__(self, mask_size=2):
self.mask_size = mask_size
def __call__(self, features):
if random.random() < 0.5:
y1_max = features.shape[3] - self.mask_size
num_masks = 1
for i in range(num_masks):
mask_y1 = torch.randint(y1_max, (features.shape[0],))
mask_y2 = mask_y1 + self.mask_size
for k in range(features.shape[0]):
features[k, :, :, mask_y1[k] : mask_y2[k]] = 0
else:
x1_max = features.shape[3] - self.mask_size
num_masks = 1
for i in range(num_masks):
mask_x1 = torch.randint(x1_max, (features.shape[0],))
mask_x2 = mask_x1 + self.mask_size
for k in range(features.shape[0]):
features[k, :, mask_x1[k] : mask_x2[k], :] = 0
return features
class erase_patch:
def __init__(self, mask_size=2):
self.mask_size = mask_size
def __call__(self, features):
std, mean = torch.std_mean(features.detach())
dim = features.shape[1]
if random.random() < 0.5:
y1_max = features.shape[3] - self.mask_size
num_masks = 1
for i in range(num_masks):
mask_y1 = torch.randint(y1_max, (features.shape[0],))
mask_y2 = mask_y1 + self.mask_size
for k in range(features.shape[0]):
features[k, :, :, mask_y1[k] : mask_y2[k]] = torch.normal(mean.repeat(dim,14,2), std.repeat(dim,14,2))
else:
x1_max = features.shape[3] - self.mask_size
num_masks = 1
for i in range(num_masks):
mask_x1 = torch.randint(x1_max, (features.shape[0],))
mask_x2 = mask_x1 + self.mask_size
for k in range(features.shape[0]):
features[k, :, mask_x1[k] : mask_x2[k], :] = torch.normal(mean.repeat(dim,2,14), std.repeat(dim,2,14))
return features
class mixup_patch:
def __init__(self, mask_size=2):
self.mask_size = mask_size
def __call__(self, features):
lam = random.uniform(0, 1)
if random.random() < 0.5:
y1_max = features.shape[3] - self.mask_size
num_masks = 1
for i in range(num_masks):
mask_y1 = torch.randint(y1_max, (1,))
mask_y2 = mask_y1 + self.mask_size
new_idx = torch.randperm(features.shape[0])
features[:, :, :, mask_y1 : mask_y2] = lam*features[:, :, :, mask_y1 : mask_y2] + (1-lam)*features[new_idx, :, :, mask_y1 : mask_y2]
else:
x1_max = features.shape[2] - self.mask_size
num_masks = 1
for i in range(num_masks):
mask_x1 = torch.randint(x1_max, (1,))
mask_x2 = mask_x1 + self.mask_size
new_idx = torch.randperm(features.shape[0])
features[:, :, mask_x1 : mask_x2, :] = lam*features[:, :, mask_x1 : mask_x2, :] + (1-lam)*features[new_idx, :, mask_x1 : mask_x2, :]
return features
class jigsaw_patch:
def __init__(self, shift=5, group=2):
self.shift = shift
self.group = group
def __call__(self, features):
batchsize = features.size(0)
dim = features.size(1)
features = features.view(batchsize, dim, -1)
# Shift Operation
feature_random = torch.cat([features[:, :, self.shift:], features[:, :, :self.shift]], dim=2)
x = feature_random
# Patch Shuffle Operation
try:
x = x.view(batchsize, dim, self.group, -1)
except:
x = torch.cat([x, x[:, -2:-1, :]], dim=1)
x = x.view(batchsize, self.group, -1, dim)
x = torch.transpose(x, 2, 3).contiguous()
x = x.view(batchsize, dim, -1)
x = x.view(batchsize, dim, 14, 14)
return x

View File

@ -0,0 +1,144 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import random
import math
import torch
import numpy as np
from copy import deepcopy
from torchvision.transforms import functional as F
def mixup_data(images, alpha=0.8):
if alpha > 0. and alpha < 1.:
lam = random.uniform(alpha, 1)
else:
lam = 1.
batch_size = len(images)
min_x = 9999
min_y = 9999
for i in range(batch_size):
min_x = min(min_x, images[i].shape[1])
min_y = min(min_y, images[i].shape[2])
shuffle_images = deepcopy(images)
random.shuffle(shuffle_images)
mixed_images = deepcopy(images)
for i in range(batch_size):
mixed_images[i][:, :min_x, :min_y] = lam * images[i][:, :min_x, :min_y] + (1 - lam) * shuffle_images[i][:, :min_x, :min_y]
return mixed_images
class Compose:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, image, target):
for t in self.transforms:
image, target = t(image, target)
return image, target
class RandomHorizontalFlip:
def __init__(self, prob=0.5):
self.prob = prob
def __call__(self, image, target):
if random.random() < self.prob:
height, width = image.shape[-2:]
image = image.flip(-1)
bbox = target["boxes"]
bbox[:, [0, 2]] = width - bbox[:, [2, 0]]
target["boxes"] = bbox
return image, target
class Cutout(object):
"""Randomly mask out one or more patches from an image.
https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
"""
def __init__(self, n_holes=2, length=100):
self.n_holes = n_holes
self.length = length
def __call__(self, img, target):
"""
Args:
img (Tensor): Tensor image of size (C, H, W).
Returns:
Tensor: Image with n_holes of dimension length x length cut out of it.
"""
h = img.size(1)
w = img.size(2)
mask = np.ones((h, w), np.float32)
for n in range(self.n_holes):
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img = img * mask
return img, target
class RandomErasing(object):
'''
https://github.com/zhunzhong07/CamStyle/blob/master/reid/utils/data/transforms.py
'''
def __init__(self, EPSILON=0.5, mean=[0.485, 0.456, 0.406]):
self.EPSILON = EPSILON
self.mean = mean
def __call__(self, img, target):
if random.uniform(0, 1) > self.EPSILON:
return img, target
for attempt in range(100):
area = img.size()[1] * img.size()[2]
target_area = random.uniform(0.02, 0.2) * area
aspect_ratio = random.uniform(0.3, 3)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w <= img.size()[2] and h <= img.size()[1]:
x1 = random.randint(0, img.size()[1] - h)
y1 = random.randint(0, img.size()[2] - w)
img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
return img, target
return img, target
class ToTensor:
def __call__(self, image, target):
# convert [0, 255] to [0, 1]
image = F.to_tensor(image)
return image, target
def build_transforms(cfg, is_train):
transforms = []
transforms.append(ToTensor())
if is_train:
transforms.append(RandomHorizontalFlip())
if cfg.INPUT.IMAGE_CUTOUT:
transforms.append(Cutout())
if cfg.INPUT.IMAGE_ERASE:
transforms.append(RandomErasing())
return Compose(transforms)

View File

@ -0,0 +1,436 @@
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import datetime
import errno
import json
import os
import os.path as osp
import pickle
import random
import time
from collections import defaultdict, deque
import numpy as np
import torch
import torch.distributed as dist
from tabulate import tabulate
# -------------------------------------------------------- #
# Logger #
# -------------------------------------------------------- #
class SmoothedValue(object):
"""
Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value,
)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append("{}: {}".format(name, str(meter)))
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(
"{} Total time: {} ({:.4f} s / it)".format(
header, total_time_str, total_time / len(iterable)
)
)
# -------------------------------------------------------- #
# Distributed training #
# -------------------------------------------------------- #
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
elif "SLURM_PROCID" in os.environ:
args.rank = int(os.environ["SLURM_PROCID"])
args.gpu = args.rank % torch.cuda.device_count()
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
# -------------------------------------------------------- #
# File operation #
# -------------------------------------------------------- #
def filename(path):
return osp.splitext(osp.basename(path))[0]
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def read_json(fpath):
with open(fpath, "r") as f:
obj = json.load(f)
return obj
def write_json(obj, fpath):
mkdir(osp.dirname(fpath))
_obj = obj.copy()
for k, v in _obj.items():
if isinstance(v, np.ndarray):
_obj.pop(k)
with open(fpath, "w") as f:
json.dump(_obj, f, indent=4, separators=(",", ": "))
def symlink(src, dst, overwrite=True, **kwargs):
if os.path.lexists(dst) and overwrite:
os.remove(dst)
os.symlink(src, dst, **kwargs)
# -------------------------------------------------------- #
# Misc #
# -------------------------------------------------------- #
def create_small_table(small_dict):
"""
Create a small table using the keys of small_dict as headers. This is only
suitable for small dictionaries.
Args:
small_dict (dict): a result dictionary of only a few items.
Returns:
str: the table as a string.
"""
keys, values = tuple(zip(*small_dict.items()))
table = tabulate(
[values],
headers=keys,
tablefmt="pipe",
floatfmt=".3f",
stralign="center",
numalign="center",
)
return table
def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor):
def f(x):
if x >= warmup_iters:
return 1
alpha = float(x) / warmup_iters
return warmup_factor * (1 - alpha) + alpha
return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
def resume_from_ckpt(ckpt_path, model, optimizer=None, lr_scheduler=None):
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt["model"], strict=False)
if optimizer is not None:
optimizer.load_state_dict(ckpt["optimizer"])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
print(f"loaded checkpoint {ckpt_path}")
print(f"model was trained for {ckpt['epoch']} epochs")
return ckpt["epoch"]
def set_random_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

25115
coat-pvtv2-b2/vis/results.json Normal file

File diff suppressed because it is too large Load Diff