COAT/coat-pvtv2-b2/datasets/prw.py

98 lines
3.3 KiB
Python
Raw Permalink Normal View History

2023-10-10 21:52:30 +08:00
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import os.path as osp
import re
import numpy as np
from scipy.io import loadmat
from .base import BaseDataset
class PRW(BaseDataset):
def __init__(self, root, transforms, split):
self.name = "PRW"
self.img_prefix = osp.join(root, "frames")
super(PRW, self).__init__(root, transforms, split)
def _get_cam_id(self, img_name):
match = re.search(r"c\d", img_name).group().replace("c", "")
return int(match)
def _load_queries(self):
query_info = osp.join(self.root, "query_info.txt")
with open(query_info, "rb") as f:
raw = f.readlines()
queries = []
for line in raw:
linelist = str(line, "utf-8").split(" ")
pid = int(linelist[0])
x, y, w, h = (
float(linelist[1]),
float(linelist[2]),
float(linelist[3]),
float(linelist[4]),
)
roi = np.array([x, y, x + w, y + h]).astype(np.int32)
roi = np.clip(roi, 0, None) # several coordinates are negative
img_name = linelist[5][:-2] + ".jpg"
queries.append(
{
"img_name": img_name,
"img_path": osp.join(self.img_prefix, img_name),
"boxes": roi[np.newaxis, :],
"pids": np.array([pid]),
"cam_id": self._get_cam_id(img_name),
}
)
return queries
def _load_split_img_names(self):
"""
Load the image names for the specific split.
"""
assert self.split in ("train", "gallery")
if self.split == "train":
imgs = loadmat(osp.join(self.root, "frame_train.mat"))["img_index_train"]
else:
imgs = loadmat(osp.join(self.root, "frame_test.mat"))["img_index_test"]
return [img[0][0] + ".jpg" for img in imgs]
def _load_annotations(self):
if self.split == "query":
return self._load_queries()
annotations = []
imgs = self._load_split_img_names()
for img_name in imgs:
anno_path = osp.join(self.root, "annotations", img_name)
anno = loadmat(anno_path)
box_key = "box_new"
if box_key not in anno.keys():
box_key = "anno_file"
if box_key not in anno.keys():
box_key = "anno_previous"
rois = anno[box_key][:, 1:]
ids = anno[box_key][:, 0]
rois = np.clip(rois, 0, None) # several coordinates are negative
assert len(rois) == len(ids)
rois[:, 2:] += rois[:, :2]
ids[ids == -2] = 5555 # assign pid = 5555 for unlabeled people
annotations.append(
{
"img_name": img_name,
"img_path": osp.join(self.img_prefix, img_name),
"boxes": rois.astype(np.int32),
# (training pids) 1, 2,..., 478, 480, 481, 482, 483, 932, 5555
"pids": ids.astype(np.int32),
"cam_id": self._get_cam_id(img_name),
}
)
return annotations