43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
|
# This file is part of COAT, and is distributed under the
|
||
|
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
|
||
|
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
|
||
|
|
||
|
import torch
|
||
|
from PIL import Image
|
||
|
|
||
|
class BaseDataset:
|
||
|
"""
|
||
|
Base class of person search dataset.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, root, transforms, split):
|
||
|
self.root = root
|
||
|
self.transforms = transforms
|
||
|
self.split = split
|
||
|
assert self.split in ("train", "gallery", "query")
|
||
|
self.annotations = self._load_annotations()
|
||
|
|
||
|
def _load_annotations(self):
|
||
|
"""
|
||
|
For each image, load its annotation that is a dictionary with the following keys:
|
||
|
img_name (str): image name
|
||
|
img_path (str): image path
|
||
|
boxes (np.array[N, 4]): ground-truth boxes in (x1, y1, x2, y2) format
|
||
|
pids (np.array[N]): person IDs corresponding to these boxes
|
||
|
cam_id (int): camera ID (only for PRW dataset)
|
||
|
"""
|
||
|
raise NotImplementedError
|
||
|
|
||
|
def __getitem__(self, index):
|
||
|
anno = self.annotations[index]
|
||
|
img = Image.open(anno["img_path"]).convert("RGB")
|
||
|
boxes = torch.as_tensor(anno["boxes"], dtype=torch.float32)
|
||
|
labels = torch.as_tensor(anno["pids"], dtype=torch.int64)
|
||
|
target = {"img_name": anno["img_name"], "boxes": boxes, "labels": labels}
|
||
|
if self.transforms is not None:
|
||
|
img, target = self.transforms(img, target)
|
||
|
return img, target
|
||
|
|
||
|
def __len__(self):
|
||
|
return len(self.annotations)
|