COAT/coat-pvtv2-b2/datasets/base.py

43 lines
1.5 KiB
Python

# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import torch
from PIL import Image
class BaseDataset:
"""
Base class of person search dataset.
"""
def __init__(self, root, transforms, split):
self.root = root
self.transforms = transforms
self.split = split
assert self.split in ("train", "gallery", "query")
self.annotations = self._load_annotations()
def _load_annotations(self):
"""
For each image, load its annotation that is a dictionary with the following keys:
img_name (str): image name
img_path (str): image path
boxes (np.array[N, 4]): ground-truth boxes in (x1, y1, x2, y2) format
pids (np.array[N]): person IDs corresponding to these boxes
cam_id (int): camera ID (only for PRW dataset)
"""
raise NotImplementedError
def __getitem__(self, index):
anno = self.annotations[index]
img = Image.open(anno["img_path"]).convert("RGB")
boxes = torch.as_tensor(anno["boxes"], dtype=torch.float32)
labels = torch.as_tensor(anno["pids"], dtype=torch.int64)
target = {"img_name": anno["img_name"], "boxes": boxes, "labels": labels}
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
def __len__(self):
return len(self.annotations)