import torch
import logging
from tqdm import tqdm
import numpy as np
import re
from datetime import datetime
from os.path import exists, join
from torch.utils.data import DataLoader
from pathlib import Path
from .base_pipeline import BasePipeline
from ..dataloaders import TorchDataloader, ConcatBatcher
from torch.utils.tensorboard import SummaryWriter
from ..utils import latest_torch_ckpt
from ...utils import make_dir, PIPELINE, LogRecord, get_runid, code2md
from ...datasets.utils import BEVBox3D
from ...metrics.mAP import mAP
logging.setLogRecordFactory(LogRecord)
logging.basicConfig(
level=logging.INFO,
format='%(levelname)s - %(asctime)s - %(module)s - %(message)s',
)
log = logging.getLogger(__name__)
[docs]class ObjectDetection(BasePipeline):
"""Pipeline for object detection."""
[docs] def __init__(self,
model,
dataset=None,
name='ObjectDetection',
main_log_dir='./logs/',
device='cuda',
split='train',
**kwargs):
super().__init__(model=model,
dataset=dataset,
name=name,
main_log_dir=main_log_dir,
device=device,
split=split,
**kwargs)
[docs] def run_inference(self, data):
"""Run inference on given data.
Args:
data: A raw data.
Returns:
Returns the inference results.
"""
model = self.model
model.eval()
# If run_inference is called on raw data.
if isinstance(data, dict):
batcher = ConcatBatcher(self.device, model.cfg.name)
data = batcher.collate_fn([{
'data': data,
'attr': {
'split': 'test'
}
}])
data.to(self.device)
with torch.no_grad():
results = model(data)
boxes = model.inference_end(results, data)
return boxes
[docs] def run_test(self):
"""Run test with test data split, computes mean average precision of the
prediction results.
"""
model = self.model
dataset = self.dataset
device = self.device
cfg = self.cfg
model.eval()
timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
log.info("DEVICE : {}".format(device))
log_file_path = join(cfg.logs_dir, 'log_test_' + timestamp + '.txt')
log.info("Logging in file : {}".format(log_file_path))
log.addHandler(logging.FileHandler(log_file_path))
batcher = ConcatBatcher(device, model.cfg.name)
test_split = TorchDataloader(dataset=dataset.get_split('test'),
preprocess=model.preprocess,
transform=model.transform,
use_cache=False,
shuffle=False)
test_loader = DataLoader(
test_split,
batch_size=cfg.test_batch_size,
num_workers=cfg.get('num_workers', 4),
pin_memory=cfg.get('pin_memory', True),
collate_fn=batcher.collate_fn,
worker_init_fn=lambda x: np.random.seed(x + np.uint32(
torch.utils.data.get_worker_info().seed)))
self.load_ckpt(model.cfg.ckpt_path)
if cfg.get('test_compute_metric', True):
self.run_valid()
log.info("Started testing")
self.test_ious = []
pred = []
with torch.no_grad():
for data in tqdm(test_loader, desc='testing'):
results = self.run_inference(data)
pred.extend(results)
dataset.save_test_result(results, data.attr)
[docs] def run_valid(self):
"""Run validation with validation data split, computes mean average
precision and the loss of the prediction results.
"""
model = self.model
dataset = self.dataset
device = self.device
cfg = self.cfg
model.eval()
timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
log.info("DEVICE : {}".format(device))
log_file_path = join(cfg.logs_dir, 'log_valid_' + timestamp + '.txt')
log.info("Logging in file : {}".format(log_file_path))
log.addHandler(logging.FileHandler(log_file_path))
batcher = ConcatBatcher(device, model.cfg.name)
valid_dataset = dataset.get_split('validation')
valid_split = TorchDataloader(dataset=valid_dataset,
preprocess=model.preprocess,
transform=model.transform,
use_cache=dataset.cfg.use_cache,
shuffle=True,
steps_per_epoch=dataset.cfg.get(
'steps_per_epoch_valid', None))
valid_loader = DataLoader(
valid_split,
batch_size=cfg.val_batch_size,
num_workers=cfg.get('num_workers', 4),
pin_memory=cfg.get('pin_memory', False),
collate_fn=batcher.collate_fn,
worker_init_fn=lambda x: np.random.seed(x + np.uint32(
torch.utils.data.get_worker_info().seed)))
log.info("Started validation")
self.valid_losses = {}
pred = []
gt = []
with torch.no_grad():
for data in tqdm(valid_loader, desc='validation'):
data.to(device)
results = model(data)
loss = model.loss(results, data)
for l, v in loss.items():
if not l in self.valid_losses:
self.valid_losses[l] = []
self.valid_losses[l].append(v.cpu().numpy())
# convert to bboxes for mAP evaluation
boxes = model.inference_end(results, data)
pred.extend([BEVBox3D.to_dicts(b) for b in boxes])
gt.extend([BEVBox3D.to_dicts(b) for b in data.bbox_objs])
sum_loss = 0
desc = "validation - "
for l, v in self.valid_losses.items():
desc += " %s: %.03f" % (l, np.mean(v))
sum_loss += np.mean(v)
desc += " > loss: %.03f" % sum_loss
log.info(desc)
overlaps = cfg.get("overlaps", [0.5])
similar_classes = cfg.get("similar_classes", {})
difficulties = cfg.get("difficulties", [0])
ap = mAP(pred,
gt,
model.classes,
difficulties,
overlaps,
similar_classes=similar_classes)
log.info("")
log.info("=============== mAP BEV ===============")
log.info(("class \\ difficulty " +
"{:>5} " * len(difficulties)).format(*difficulties))
for i, c in enumerate(model.classes):
log.info(("{:<20} " + "{:>5.2f} " * len(difficulties)).format(
c + ":", *ap[i, :, 0]))
log.info("Overall: {:.2f}".format(np.mean(ap[:, -1])))
self.valid_losses["mAP BEV"] = np.mean(ap[:, -1])
ap = mAP(pred,
gt,
model.classes,
difficulties,
overlaps,
similar_classes=similar_classes,
bev=False)
log.info("")
log.info("=============== mAP 3D ===============")
log.info(("class \\ difficulty " +
"{:>5} " * len(difficulties)).format(*difficulties))
for i, c in enumerate(model.classes):
log.info(("{:<20} " + "{:>5.2f} " * len(difficulties)).format(
c + ":", *ap[i, :, 0]))
log.info("Overall: {:.2f}".format(np.mean(ap[:, -1])))
self.valid_losses["mAP 3D"] = np.mean(ap[:, -1])
[docs] def run_train(self):
"""Run training with train data split."""
model = self.model
device = self.device
dataset = self.dataset
cfg = self.cfg
log.info("DEVICE : {}".format(device))
timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
log_file_path = join(cfg.logs_dir, 'log_train_' + timestamp + '.txt')
log.info("Logging in file : {}".format(log_file_path))
log.addHandler(logging.FileHandler(log_file_path))
batcher = ConcatBatcher(device, model.cfg.name)
train_dataset = dataset.get_split('training')
train_split = TorchDataloader(dataset=train_dataset,
preprocess=model.preprocess,
transform=model.transform,
use_cache=dataset.cfg.use_cache,
steps_per_epoch=dataset.cfg.get(
'steps_per_epoch_train', None))
train_loader = DataLoader(
train_split,
batch_size=cfg.batch_size,
num_workers=cfg.get('num_workers', 4),
pin_memory=cfg.get('pin_memory', False),
collate_fn=batcher.collate_fn,
worker_init_fn=lambda x: np.random.seed(x + np.uint32(
torch.utils.data.get_worker_info().seed))
) # numpy expects np.uint32, whereas torch returns np.uint64.
self.optimizer, self.scheduler = model.get_optimizer(cfg.optimizer)
is_resume = model.cfg.get('is_resume', True)
start_ep = self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume)
dataset_name = dataset.name if dataset is not None else ''
tensorboard_dir = join(
self.cfg.train_sum_dir,
model.__class__.__name__ + '_' + dataset_name + '_torch')
runid = get_runid(tensorboard_dir)
self.tensorboard_dir = join(self.cfg.train_sum_dir,
runid + '_' + Path(tensorboard_dir).name)
writer = SummaryWriter(self.tensorboard_dir)
self.save_config(writer)
log.info("Writing summary in {}.".format(self.tensorboard_dir))
log.info("Started training")
for epoch in range(start_ep, cfg.max_epoch + 1):
log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===')
model.train()
self.losses = {}
process_bar = tqdm(train_loader, desc='training')
for data in process_bar:
data.to(device)
results = model(data)
loss = model.loss(results, data)
loss_sum = sum(loss.values())
self.optimizer.zero_grad()
loss_sum.backward()
if model.cfg.get('grad_clip_norm', -1) > 0:
torch.nn.utils.clip_grad_value_(model.parameters(),
model.cfg.grad_clip_norm)
self.optimizer.step()
desc = "training - "
for l, v in loss.items():
if not l in self.losses:
self.losses[l] = []
self.losses[l].append(v.cpu().detach().numpy())
desc += " %s: %.03f" % (l, v.cpu().detach().numpy())
desc += " > loss: %.03f" % loss_sum.cpu().detach().numpy()
process_bar.set_description(desc)
process_bar.refresh()
if self.scheduler is not None:
self.scheduler.step()
# --------------------- validation
self.run_valid()
self.save_logs(writer, epoch)
if epoch % cfg.save_ckpt_freq == 0:
self.save_ckpt(epoch)
[docs] def save_logs(self, writer, epoch):
for key, val in self.losses.items():
writer.add_scalar("train/" + key, np.mean(val), epoch)
for key, val in self.valid_losses.items():
writer.add_scalar("valid/" + key, np.mean(val), epoch)
[docs] def load_ckpt(self, ckpt_path=None, is_resume=True):
train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint')
make_dir(train_ckpt_dir)
epoch = 0
if ckpt_path is None:
ckpt_path = latest_torch_ckpt(train_ckpt_dir)
if ckpt_path is not None and is_resume:
log.info('ckpt_path not given. Restore from the latest ckpt')
epoch = int(re.findall(r'\d+', ckpt_path)[-1]) + 1
else:
log.info('Initializing from scratch.')
return epoch
if not exists(ckpt_path):
raise FileNotFoundError(f' ckpt {ckpt_path} not found')
log.info(f'Loading checkpoint {ckpt_path}')
ckpt = torch.load(ckpt_path, map_location=self.device)
self.model.load_state_dict(ckpt['model_state_dict'])
if 'optimizer_state_dict' in ckpt and hasattr(self, 'optimizer'):
log.info(f'Loading checkpoint optimizer_state_dict')
self.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
if 'scheduler_state_dict' in ckpt and hasattr(self, 'scheduler'):
log.info(f'Loading checkpoint scheduler_state_dict')
self.scheduler.load_state_dict(ckpt['scheduler_state_dict'])
return epoch
[docs] def save_ckpt(self, epoch):
path_ckpt = join(self.cfg.logs_dir, 'checkpoint')
make_dir(path_ckpt)
torch.save(
dict(epoch=epoch,
model_state_dict=self.model.state_dict(),
optimizer_state_dict=self.optimizer.state_dict()),
#scheduler_state_dict=self.scheduler.state_dict()),
join(path_ckpt, f'ckpt_{epoch:05d}.pth'))
log.info(f'Epoch {epoch:3d}: save ckpt to {path_ckpt:s}')
[docs] def save_config(self, writer):
"""Save experiment configuration with tensorboard summary."""
writer.add_text("Description/CloudViewer-ML", self.cfg_tb['readme'], 0)
writer.add_text("Description/Command line", self.cfg_tb['cmd_line'], 0)
writer.add_text('Configuration/Dataset',
code2md(self.cfg_tb['dataset'], language='json'), 0)
writer.add_text('Configuration/Model',
code2md(self.cfg_tb['model'], language='json'), 0)
writer.add_text('Configuration/Pipeline',
code2md(self.cfg_tb['pipeline'], language='json'), 0)
PIPELINE._register_module(ObjectDetection, "torch")