105 lines
3.5 KiB
Python
105 lines
3.5 KiB
Python
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
|
|
|
import torch
|
|
from utils.transforms import build_transforms
|
|
from utils.utils import create_small_table
|
|
from .cuhk_sysu import CUHKSYSU
|
|
from .prw import PRW
|
|
|
|
def print_statistics(dataset):
    """
    Print dataset statistics.

    Collects the image count, the total box count and a summary of the person
    IDs found in ``dataset.annotations``, then prints them as a table built by
    ``create_small_table``.

    The person-ID summary is skipped for the CUHK-SYSU query split. For any
    other query split every ID is counted as labeled. For non-query splits the
    largest ID is reported as the unlabeled ID and excluded from the labeled
    summary. When no IDs are present at all the ID summary is omitted, and
    when no labeled IDs remain the min/max entries are omitted, instead of
    raising on an empty sequence.

    Args:
        dataset: object exposing ``name``, ``split`` and ``annotations``; each
            annotation is indexed with ``"boxes"`` (its ``.shape[0]`` is the
            box count) and ``"pids"`` (an iterable of IDs).
    """
    annotations = dataset.annotations
    num_boxes = 0
    pid_set = set()
    for anno in annotations:
        num_boxes += anno["boxes"].shape[0]
        pid_set.update(anno["pids"])

    statistics = {
        "dataset": dataset.name,
        "split": dataset.split,
        "num_images": len(annotations),
        "num_boxes": num_boxes,
    }

    is_query = dataset.split == "query"
    skip_pid_summary = dataset.name == "CUHK-SYSU" and is_query
    if pid_set and not skip_pid_summary:
        labeled_pids = sorted(pid_set)
        unlabeled_pid = None
        if not is_query:
            # The largest pid is taken to be the "unlabeled" marker and is
            # removed before summarising the labeled ids.
            unlabeled_pid = labeled_pids.pop()

        statistics["num_labeled_pids"] = len(labeled_pids)
        if labeled_pids:
            # The list is sorted, so its ends are the min and max.
            statistics["min_labeled_pid"] = int(labeled_pids[0])
            statistics["max_labeled_pid"] = int(labeled_pids[-1])
        if unlabeled_pid is not None:
            statistics["unlabeled_pid"] = int(unlabeled_pid)

    print(f"=> {dataset.name}-{dataset.split} loaded:\n" + create_small_table(statistics))
|
|
|
|
|
|
def build_dataset(dataset_name, root, transforms, split, verbose=True):
    """
    Build the dataset object selected by ``dataset_name``.

    Args:
        dataset_name: ``"CUHK-SYSU"`` or ``"PRW"``.
        root: forwarded unchanged to the dataset constructor.
        transforms: forwarded unchanged to the dataset constructor.
        split: forwarded unchanged to the dataset constructor.
        verbose: when true, print the dataset statistics after construction.

    Returns:
        The constructed ``CUHKSYSU`` or ``PRW`` instance.

    Raises:
        NotImplementedError: if ``dataset_name`` is not one of the supported
            names.
    """
    if dataset_name == "CUHK-SYSU":
        dataset = CUHKSYSU(root, transforms, split)
    elif dataset_name == "PRW":
        dataset = PRW(root, transforms, split)
    else:
        raise NotImplementedError(f"Unknown dataset: {dataset_name}")

    if verbose:
        print_statistics(dataset)
    return dataset
|
|
|
|
|
|
def collate_fn(batch):
    """
    Transpose a batch of samples into a tuple of per-field tuples.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``; an empty batch
    yields an empty tuple.
    """
    per_field = zip(*batch)
    return tuple(per_field)
|
|
|
|
|
|
def build_train_loader(cfg):
    """
    Build the training ``DataLoader`` described by ``cfg``.

    Reads ``cfg.INPUT.DATASET``, ``cfg.INPUT.DATA_ROOT``,
    ``cfg.INPUT.BATCH_SIZE_TRAIN`` and ``cfg.INPUT.NUM_WORKERS_TRAIN``. The
    loader shuffles, pins memory, drops the last incomplete batch and collates
    with ``collate_fn``.
    """
    train_transforms = build_transforms(cfg, is_train=True)
    train_set = build_dataset(
        cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, train_transforms, "train"
    )
    loader_options = {
        "batch_size": cfg.INPUT.BATCH_SIZE_TRAIN,
        "shuffle": True,
        "num_workers": cfg.INPUT.NUM_WORKERS_TRAIN,
        "pin_memory": True,
        "drop_last": True,
        "collate_fn": collate_fn,
    }
    return torch.utils.data.DataLoader(train_set, **loader_options)
|
|
|
|
|
|
def build_test_loader(cfg):
    """
    Build the evaluation loaders described by ``cfg``.

    Both loaders use ``cfg.INPUT.BATCH_SIZE_TEST`` and
    ``cfg.INPUT.NUM_WORKERS_TEST``, keep the dataset order (no shuffling), pin
    memory and collate with ``collate_fn``.

    Returns:
        A ``(gallery_loader, query_loader)`` tuple.
    """
    eval_transforms = build_transforms(cfg, is_train=False)

    # Both datasets are built before either loader, gallery first.
    gallery_set = build_dataset(
        cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, eval_transforms, "gallery"
    )
    query_set = build_dataset(
        cfg.INPUT.DATASET, cfg.INPUT.DATA_ROOT, eval_transforms, "query"
    )

    def _make_loader(subset):
        # Identical settings for the gallery and query loaders.
        return torch.utils.data.DataLoader(
            subset,
            batch_size=cfg.INPUT.BATCH_SIZE_TEST,
            shuffle=False,
            num_workers=cfg.INPUT.NUM_WORKERS_TEST,
            pin_memory=True,
            collate_fn=collate_fn,
        )

    gallery_loader = _make_loader(gallery_set)
    query_loader = _make_loader(query_set)
    return gallery_loader, query_loader
|