COAT/coat-pvtv2-b2/train.py

155 lines
5.1 KiB
Python
Raw Normal View History

2023-10-10 21:52:30 +08:00
# This file is part of COAT, and is distributed under the
# OSI-approved BSD 3-Clause License. See top-level LICENSE file or
# https://github.com/Kitware/COAT/blob/master/LICENSE for details.
import argparse
import datetime
import os.path as osp
import time
import torch
import torch.utils.data
from datasets import build_test_loader, build_train_loader
from defaults import get_default_cfg
from engine import evaluate_performance, train_one_epoch
from models.coat import COAT
from utils.utils import mkdir, resume_from_ckpt, save_on_master, set_random_seed
from loss.softmax_loss import SoftmaxLoss
def _evaluate(cfg, model, gallery_loader, query_loader, device):
    """Run ``evaluate_performance`` with the EVAL_* options taken from ``cfg``.

    Shared by the evaluation-only path and the end-of-training evaluation so
    both use identical settings.
    """
    return evaluate_performance(
        model,
        gallery_loader,
        query_loader,
        device,
        use_gt=cfg.EVAL_USE_GT,
        use_cache=cfg.EVAL_USE_CACHE,
        use_cbgm=cfg.EVAL_USE_CBGM,
        gallery_size=cfg.EVAL_GALLERY_SIZE,
    )


def _build_softmax_criteria(cfg, device):
    """Return the two SoftmaxLoss criteria (s2, s3) on ``device``.

    Returns ``(None, None)`` when ``cfg.MODEL.LOSS.USE_SOFTMAX`` is disabled.
    """
    if not cfg.MODEL.LOSS.USE_SOFTMAX:
        return None, None
    criterion_s2 = SoftmaxLoss(cfg)
    criterion_s3 = SoftmaxLoss(cfg)
    criterion_s2.to(device)
    criterion_s3.to(device)
    return criterion_s2, criterion_s3


def _restore_softmax_criteria(ckpt_path, criteria):
    """Load SoftmaxLoss weights stored in the checkpoint at ``ckpt_path``.

    ``criteria`` maps checkpoint keys to criterion modules (or ``None`` when
    softmax is disabled). Keys missing from the checkpoint are skipped, so
    checkpoints written before these weights were saved still resume.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    for key, criterion in criteria.items():
        if criterion is not None and key in ckpt:
            criterion.load_state_dict(ckpt[key])


def main(args):
    """Train a COAT person-search model, or only evaluate it with ``--eval``.

    Args:
        args: parsed command-line namespace providing ``cfg_file``, ``opts``,
            ``eval``, ``resume`` and ``ckpt``.

    Raises:
        ValueError: if ``--eval`` or ``--resume`` is given without ``--ckpt``.
    """
    # Validate the CLI before any expensive model/data construction. A raise
    # (unlike assert) survives `python -O`.
    if args.eval and not args.ckpt:
        raise ValueError("--ckpt must be specified when --eval enabled")
    if args.resume and not args.ckpt:
        raise ValueError("--ckpt must be specified when --resume enabled")

    cfg = get_default_cfg()
    if args.cfg_file:
        cfg.merge_from_file(args.cfg_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    device = torch.device(cfg.DEVICE)
    if cfg.SEED >= 0:
        set_random_seed(cfg.SEED)

    print("Creating model...")
    model = COAT(cfg)
    model.to(device)

    print("Loading data...")
    train_loader = build_train_loader(cfg)
    gallery_loader, query_loader = build_test_loader(cfg)

    if args.eval:
        resume_from_ckpt(args.ckpt, model)
        _evaluate(cfg, model, gallery_loader, query_loader, device)
        return

    # The softmax criteria are only needed for training, so they are built
    # after the evaluation-only early return.
    softmax_criterion_s2, softmax_criterion_s3 = _build_softmax_criteria(cfg, device)
    # Checkpoint keys for the criteria. Their parameters are optimized, so
    # their weights must be saved and restored together with the optimizer state.
    softmax_criteria = {
        "softmax_criterion_s2": softmax_criterion_s2,
        "softmax_criterion_s3": softmax_criterion_s3,
    }

    # Parameter order (model, then s2, then s3) must stay stable so that the
    # optimizer state loaded on resume lines up with the same parameters.
    params = [p for p in model.parameters() if p.requires_grad]
    for criterion in (softmax_criterion_s2, softmax_criterion_s3):
        if criterion is not None:
            params.extend(p for p in criterion.parameters() if p.requires_grad)

    optimizer = torch.optim.SGD(
        params,
        lr=cfg.SOLVER.BASE_LR,
        momentum=cfg.SOLVER.SGD_MOMENTUM,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=cfg.SOLVER.LR_DECAY_MILESTONES, gamma=cfg.SOLVER.GAMMA
    )

    start_epoch = 0
    if args.resume:
        # resume_from_ckpt returns the last completed epoch; training continues
        # with the next one.
        start_epoch = resume_from_ckpt(args.ckpt, model, optimizer, lr_scheduler) + 1
        _restore_softmax_criteria(args.ckpt, softmax_criteria)

    print("Creating output folder...")
    output_dir = cfg.OUTPUT_DIR
    mkdir(output_dir)
    path = osp.join(output_dir, "config.yaml")
    with open(path, "w", encoding="utf-8") as f:
        f.write(cfg.dump())
    print(f"Full config is saved to {path}")

    tfboard = None
    if cfg.TF_BOARD:
        # Imported lazily so TensorBoard is only required when enabled.
        from torch.utils.tensorboard import SummaryWriter

        tf_log_path = osp.join(output_dir, "tf_log")
        mkdir(tf_log_path)
        tfboard = SummaryWriter(log_dir=tf_log_path)
        print(f"TensorBoard files are saved to {tf_log_path}")

    print("Start training...")
    start_time = time.time()
    try:
        for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS):
            train_one_epoch(
                cfg,
                model,
                optimizer,
                train_loader,
                device,
                epoch,
                tfboard,
                softmax_criterion_s2,
                softmax_criterion_s3,
            )
            lr_scheduler.step()

            # only save (and evaluate) the last three checkpoints
            if epoch >= cfg.SOLVER.MAX_EPOCHS - 3:
                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch,
                }
                for key, criterion in softmax_criteria.items():
                    if criterion is not None:
                        checkpoint[key] = criterion.state_dict()
                save_on_master(checkpoint, osp.join(output_dir, f"epoch_{epoch}.pth"))

                # evaluate the current checkpoint
                _evaluate(cfg, model, gallery_loader, query_loader, device)
    finally:
        # Close the event writer even if training is interrupted or raises.
        if tfboard is not None:
            tfboard.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Total training time {total_time_str}")
if __name__ == "__main__":
    # Command-line entry point. The boolean flags pick the run mode (evaluate
    # or resume). The trailing positional `opts` collects every remaining token
    # (argparse.REMAINDER) as config overrides for main().
    cli = argparse.ArgumentParser(description="Train a person search network.")
    cli.add_argument("--cfg", dest="cfg_file", help="Path to configuration file.")
    for flag, flag_help in (
        ("--eval", "Evaluate the performance of a given checkpoint."),
        ("--resume", "Resume from the specified checkpoint."),
    ):
        cli.add_argument(flag, action="store_true", help=flag_help)
    cli.add_argument("--ckpt", help="Path to checkpoint to resume or evaluate.")
    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        help="Modify config options using the command-line",
    )
    main(cli.parse_args())