import torch, pickle
import torch.nn as nn
import numpy as np
import logging
import sys
from datetime import datetime
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, IterableDataset, DataLoader
from pathlib import Path
from sklearn.metrics import confusion_matrix
from os.path import exists, join, isfile, dirname, abspath
from .base_pipeline import BasePipeline
from ..dataloaders import get_sampler, TorchDataloader, DefaultBatcher, ConcatBatcher
from ..utils import latest_torch_ckpt
from ..modules.losses import SemSegLoss
from ..modules.metrics import SemSegMetric
from ...utils import make_dir, LogRecord, Config, PIPELINE, get_runid, code2md
from ...datasets.utils import DataProcessing
from ...datasets import InferenceDummySplit
# Make the logging module build every new record with the project's LogRecord
# class (imported above from the package utils).
logging.setLogRecordFactory(LogRecord)

# Root-logger configuration: INFO and above, with level, time, module and
# message in every line.
logging.basicConfig(
    format='%(levelname)s - %(asctime)s - %(module)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by all pipeline methods in this file.
log = logging.getLogger(__name__)
[docs]class SemanticSegmentation(BasePipeline):
"""This class allows you to perform semantic segmentation for both training
and inference using PyTorch. This pipeline has multiple stages:
pre-processing, loading the dataset, testing, and inference or training.
**Example:**
This example loads the Semantic Segmentation and performs a training using the SemanticKITTI dataset.
import torch, pickle
import torch.nn as nn
from .base_pipeline import BasePipeline
from torch.utils.tensorboard import SummaryWriter
from ..dataloaders import get_sampler, TorchDataloader, DefaultBatcher, ConcatBatcher
Mydataset = TorchDataloader(dataset=dataset.get_split('training'))
MyModel = SemanticSegmentation(model, dataset=Mydataset,
name='MySemanticSegmentation',
batch_size=4,
val_batch_size=4,
test_batch_size=3,
max_epoch=100,
learning_rate=1e-2,
lr_decays=0.95,
save_ckpt_freq=20,
adam_lr=1e-2,
scheduler_gamma=0.95,
momentum=0.98,
main_log_dir='./logs/',
device='gpu',
split='train',
train_sum_dir='train_log')
**Args:**
dataset: The 3D ML dataset class. You can use the base dataset, sample datasets, or a custom dataset.
model: The model to be used for building the pipeline.
name: The name of the current training.
batch_size: The batch size to be used for training.
val_batch_size: The batch size to be used for validation.
test_batch_size: The batch size to be used for testing.
max_epoch: The maximum number of epochs to be used for training.
learning_rate: The hyperparameter that controls the size of the weight updates during training. Also known as step size.
lr_decays: The learning rate decay for the training.
save_ckpt_freq: The frequency in which the checkpoint should be saved.
adam_lr: The learning rate to be applied for Adam optimization.
scheduler_gamma: The decaying factor associated with the scheduler.
momentum: The momentum that accelerates the training rate schedule.
main_log_dir: The directory where logs are stored.
device: The device to be used for training.
split: The dataset split to be used. In this example, we have used "train".
train_sum_dir: The directory where the training summary is stored.
**Returns:**
class: The corresponding class.
"""
def __init__(
        self,
        model,
        dataset=None,
        name='SemanticSegmentation',
        batch_size=4,
        val_batch_size=4,
        test_batch_size=3,
        max_epoch=100,
        learning_rate=1e-2,
        lr_decays=0.95,
        save_ckpt_freq=20,
        adam_lr=1e-2,
        scheduler_gamma=0.95,
        momentum=0.98,
        main_log_dir='./logs/',
        device='gpu',
        split='train',
        train_sum_dir='train_log',
        **kwargs):
    """Create the pipeline by forwarding every setting to the base pipeline.

    Nothing is stored or validated here: all named arguments, followed by
    any extra ``kwargs``, are handed by keyword to ``super().__init__``.

    Args:
        model: The model to be used for building the pipeline.
        dataset: The dataset object the pipeline works on.
        name: The name of the current run.
        batch_size: Batch size for training.
        val_batch_size: Batch size for validation.
        test_batch_size: Batch size for testing.
        max_epoch: Maximum epoch during training.
        learning_rate: Initial learning rate.
        lr_decays: Learning rate decay.
        save_ckpt_freq: Checkpoint saving frequency, in epochs.
        adam_lr: Learning rate for the Adam optimizer.
        scheduler_gamma: Decay factor of the scheduler.
        momentum: Momentum setting.
        main_log_dir: Directory where logs are stored.
        device: Device to run on.
        split: Dataset split to be used.
        train_sum_dir: Directory where the training summary is stored.
        **kwargs: Any further settings, passed through untouched.
    """
    # Kept in the original keyword order so the base class receives its
    # keyword arguments in exactly the same sequence as before.
    pipeline_settings = dict(
        name=name,
        batch_size=batch_size,
        val_batch_size=val_batch_size,
        test_batch_size=test_batch_size,
        max_epoch=max_epoch,
        learning_rate=learning_rate,
        lr_decays=lr_decays,
        save_ckpt_freq=save_ckpt_freq,
        adam_lr=adam_lr,
        scheduler_gamma=scheduler_gamma,
        momentum=momentum,
        main_log_dir=main_log_dir,
        device=device,
        split=split,
        train_sum_dir=train_sum_dir,
    )
    super().__init__(model=model,
                     dataset=dataset,
                     **pipeline_settings,
                     **kwargs)
"""
Run inference on given data.
Args:
data: The raw input data to run inference on.
Returns:
Returns the inference results.
"""
def run_inference(self, data):
    """Run inference on given data.

    The data is wrapped in an ``InferenceDummySplit``, fed through the
    model batch by batch without gradient tracking, and every batch result
    is accumulated through ``update_tests``.

    Args:
        data: The raw data to run inference on.

    Returns:
        A dict with the keys ``predict_labels`` and ``predict_scores``,
        each taken from the end of the result lists filled by
        ``update_tests``.
    """
    cfg = self.cfg
    net = self.model
    device = self.device

    # Put the model on the target device and switch it to evaluation mode.
    net.to(device)
    net.device = device
    net.eval()

    batcher = self.get_batcher(device)

    # Wrap the raw data so it can be served like a regular dataset split.
    dummy_split = InferenceDummySplit(data)
    self.dataset_split = dummy_split
    sampler = dummy_split.sampler

    torch_split = TorchDataloader(
        dataset=dummy_split,
        preprocess=net.preprocess,
        transform=net.transform,
        sampler=sampler,
        use_cache=False,
    )
    loader = DataLoader(
        torch_split,
        batch_size=cfg.batch_size,
        sampler=get_sampler(sampler),
        collate_fn=batcher.collate_fn,
    )

    net.trans_point_sampler = sampler.get_point_sampler()

    # Reset the accumulators that update_tests fills in.
    self.curr_cloud_id = -1
    self.test_probs = []
    self.test_labels = []
    self.ori_test_probs = []
    self.ori_test_labels = []

    with torch.no_grad():
        for inputs in loader:
            outputs = net(inputs['data'])
            self.update_tests(sampler, inputs, outputs)

    # Labels are popped before scores, matching the fill order.
    labels = self.ori_test_labels.pop()
    scores = self.ori_test_probs.pop()
    return {'predict_labels': labels, 'predict_scores': scores}
"""
Run testing on the 'test' split of the pipeline's dataset.
"""
def run_test(self):
    """Run the model over the 'test' split and save each finished result.

    A log file handler is attached, the checkpoint given by
    ``model.cfg.ckpt_path`` is loaded through ``load_ckpt``, and every
    batch is accumulated through ``update_tests``. Whenever
    ``self.complete_infer`` is set after a batch, the latest labels and
    scores are handed to ``dataset.save_test_result`` together with the
    attributes of the current cloud.
    """
    net = self.model
    dataset = self.dataset
    device = self.device
    cfg = self.cfg

    net.device = device
    net.to(device)
    net.eval()

    stamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    # NOTE(review): this metric object is never used below; the
    # construction is kept so behaviour stays exactly as before.
    metric = SemSegMetric()

    log.info("DEVICE : {}".format(device))
    log_file_path = join(cfg.logs_dir, 'log_test_' + stamp + '.txt')
    log.info("Logging in file : {}".format(log_file_path))
    log.addHandler(logging.FileHandler(log_file_path))

    batcher = self.get_batcher(device)

    split_data = dataset.get_split('test')
    sampler = split_data.sampler
    torch_split = TorchDataloader(
        dataset=split_data,
        preprocess=net.preprocess,
        transform=net.transform,
        sampler=sampler,
        use_cache=dataset.cfg.use_cache,
    )
    loader = DataLoader(
        torch_split,
        batch_size=cfg.test_batch_size,
        sampler=get_sampler(sampler),
        collate_fn=batcher.collate_fn,
    )

    self.dataset_split = split_data
    self.load_ckpt(net.cfg.ckpt_path)
    net.trans_point_sampler = sampler.get_point_sampler()

    # Reset the accumulators that update_tests fills in.
    self.curr_cloud_id = -1
    self.test_probs = []
    self.test_labels = []
    self.ori_test_probs = []
    self.ori_test_labels = []

    log.info("Started testing")
    with torch.no_grad():
        for inputs in loader:
            payload = inputs['data']
            if hasattr(payload, 'to'):
                # The return value is discarded, as in the original.
                payload.to(device)
            outputs = net(payload)
            self.update_tests(sampler, inputs, outputs)

            if not self.complete_infer:
                continue

            # A cloud was completed by this batch: persist its result.
            inference_result = {
                'predict_labels': self.ori_test_labels.pop(),
                'predict_scores': self.ori_test_probs.pop()
            }
            attr = self.dataset_split.get_attr(sampler.cloud_id)
            dataset.save_test_result(inference_result, attr)
    log.info("Finshed testing")
"""
Update tests using sampler, inputs, and results.
"""
def update_tests(self, sampler, inputs, results):
    """Fold one batch of results into the per-cloud accumulators.

    When the sampler has moved to a new cloud, a progress bar and fresh
    zero-filled probability/label arrays are created for it. The model's
    ``update_probs`` then merges the batch into those arrays. Once the
    split is 'test' and every possibility of the cloud is above the
    threshold, the arrays are re-indexed with ``proj_inds`` (or left in
    order when the preprocessed data has none), appended to
    ``ori_test_probs`` / ``ori_test_labels``, and ``complete_infer`` is
    set.

    Args:
        sampler: The sampler of the split being processed.
        inputs: The batch that was fed to the model.
        results: The model output for that batch.
    """
    split = sampler.split
    end_threshold = 0.5

    if self.curr_cloud_id != sampler.cloud_id:
        # First batch of a new cloud: set up its progress bar and buffers.
        self.curr_cloud_id = sampler.cloud_id
        num_points = sampler.possibilities[sampler.cloud_id].shape[0]
        bar_title = "{} {}/{}".format(split, self.curr_cloud_id,
                                      len(sampler.dataset))
        self.pbar = tqdm(total=num_points, desc=bar_title)
        self.pbar_update = 0
        self.test_probs.append(
            np.zeros(shape=[num_points, self.model.cfg.num_classes],
                     dtype=np.float16))
        self.test_labels.append(
            np.zeros(shape=[num_points], dtype=np.int16))
        self.complete_infer = False

    this_possiblility = sampler.possibilities[sampler.cloud_id]

    def num_above_threshold():
        # Recomputed on every call, exactly where the original recomputed.
        return this_possiblility[this_possiblility >
                                 end_threshold].shape[0]

    # Advance the bar by the number of entries newly above the threshold.
    self.pbar.update(num_above_threshold() - self.pbar_update)
    self.pbar_update = num_above_threshold()

    cloud = self.curr_cloud_id
    merged_probs, merged_labels = self.model.update_probs(
        inputs, results, self.test_probs[cloud], self.test_labels[cloud])
    self.test_probs[cloud] = merged_probs
    self.test_labels[cloud] = merged_labels

    # Nothing more to do unless this is the test split and the whole cloud
    # is above the threshold.
    if split != 'test' or \
            num_above_threshold() != this_possiblility.shape[0]:
        return

    preprocessed = self.model.preprocess(
        self.dataset_split.get_data(self.curr_cloud_id), {'split': split})
    proj_inds = preprocessed.get('proj_inds', None)
    if proj_inds is None:
        proj_inds = np.arange(self.test_probs[self.curr_cloud_id].shape[0])

    self.ori_test_probs.append(
        self.test_probs[self.curr_cloud_id][proj_inds])
    self.ori_test_labels.append(
        self.test_labels[self.curr_cloud_id][proj_inds])
    self.complete_infer = True
"""
Run the training on the self model.
"""
def run_train(self):
    """Train the model and validate it after every epoch.

    Builds the 'train' and 'validation' loaders from ``self.dataset``,
    takes the optimizer and scheduler from ``model.get_optimizer``,
    restores a checkpoint through ``load_ckpt``, and then, for each epoch,
    runs a training pass, one scheduler step and a validation pass.
    Metrics and losses are written through ``save_logs`` and a checkpoint
    is saved every ``cfg.save_ckpt_freq`` epochs. The TensorBoard writer is
    closed when training ends or fails.
    """
    model = self.model
    device = self.device
    model.device = device
    dataset = self.dataset
    cfg = self.cfg
    model.to(device)

    log.info("DEVICE : {}".format(device))
    timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    log_file_path = join(cfg.logs_dir, 'log_train_' + timestamp + '.txt')
    log.info("Logging in file : {}".format(log_file_path))
    log.addHandler(logging.FileHandler(log_file_path))

    criterion = SemSegLoss(self, model, dataset, device)
    self.metric_train = SemSegMetric()
    self.metric_val = SemSegMetric()
    self.batcher = self.get_batcher(device)

    train_dataset = dataset.get_split('train')
    train_sampler = train_dataset.sampler
    train_split = TorchDataloader(dataset=train_dataset,
                                  preprocess=model.preprocess,
                                  transform=model.transform,
                                  sampler=train_sampler,
                                  use_cache=dataset.cfg.use_cache,
                                  steps_per_epoch=dataset.cfg.get(
                                      'steps_per_epoch_train', None))
    train_loader = DataLoader(train_split,
                              batch_size=cfg.batch_size,
                              sampler=get_sampler(train_sampler),
                              num_workers=cfg.get('num_workers', 2),
                              pin_memory=cfg.get('pin_memory', True),
                              collate_fn=self.batcher.collate_fn,
                              worker_init_fn=self._seed_worker)

    valid_dataset = dataset.get_split('validation')
    valid_sampler = valid_dataset.sampler
    valid_split = TorchDataloader(dataset=valid_dataset,
                                  preprocess=model.preprocess,
                                  transform=model.transform,
                                  sampler=valid_sampler,
                                  use_cache=dataset.cfg.use_cache,
                                  steps_per_epoch=dataset.cfg.get(
                                      'steps_per_epoch_valid', None))
    valid_loader = DataLoader(valid_split,
                              batch_size=cfg.val_batch_size,
                              sampler=get_sampler(valid_sampler),
                              num_workers=cfg.get('num_workers', 2),
                              pin_memory=cfg.get('pin_memory', True),
                              collate_fn=self.batcher.collate_fn,
                              worker_init_fn=self._seed_worker)

    self.optimizer, self.scheduler = model.get_optimizer(cfg)

    is_resume = model.cfg.get('is_resume', True)
    self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume)

    dataset_name = dataset.name if dataset is not None else ''
    tensorboard_dir = join(
        self.cfg.train_sum_dir,
        model.__class__.__name__ + '_' + dataset_name + '_torch')
    runid = get_runid(tensorboard_dir)
    self.tensorboard_dir = join(self.cfg.train_sum_dir,
                                runid + '_' + Path(tensorboard_dir).name)

    writer = SummaryWriter(self.tensorboard_dir)
    try:
        self.save_config(writer)
        log.info("Writing summary in {}.".format(self.tensorboard_dir))

        log.info("Started training")
        # The range is inclusive of max_epoch, as in the original.
        for epoch in range(0, cfg.max_epoch + 1):
            log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===')
            model.train()
            self.metric_train.reset()
            self.metric_val.reset()
            self.losses = []
            model.trans_point_sampler = train_sampler.get_point_sampler()

            for inputs in tqdm(train_loader, desc='training'):
                if hasattr(inputs['data'], 'to'):
                    # Return value intentionally unused, as before.
                    # NOTE(review): presumably the batch moves itself in
                    # place — confirm against the batcher classes.
                    inputs['data'].to(device)
                self.optimizer.zero_grad()
                results = model(inputs['data'])
                loss, gt_labels, predict_scores = model.get_loss(
                    criterion, results, inputs, device)

                # Skip batches that produced no scores.
                if predict_scores.size()[-1] == 0:
                    continue

                loss.backward()
                if model.cfg.get('grad_clip_norm', -1) > 0:
                    # NOTE(review): the option is named "norm" but the
                    # gradients are clipped by value; kept as is.
                    torch.nn.utils.clip_grad_value_(
                        model.parameters(), model.cfg.grad_clip_norm)
                self.optimizer.step()

                self.metric_train.update(predict_scores, gt_labels)
                self.losses.append(loss.cpu().item())

            self.scheduler.step()

            # --------------------- validation
            model.eval()
            self.valid_losses = []
            model.trans_point_sampler = valid_sampler.get_point_sampler()

            with torch.no_grad():
                for inputs in tqdm(valid_loader, desc='validation'):
                    if hasattr(inputs['data'], 'to'):
                        inputs['data'].to(device)
                    results = model(inputs['data'])
                    loss, gt_labels, predict_scores = model.get_loss(
                        criterion, results, inputs, device)

                    if predict_scores.size()[-1] == 0:
                        continue

                    self.metric_val.update(predict_scores, gt_labels)
                    self.valid_losses.append(loss.cpu().item())

            self.save_logs(writer, epoch)

            if epoch % cfg.save_ckpt_freq == 0:
                self.save_ckpt(epoch)
    finally:
        # Flush and release the event file even if training fails.
        writer.close()

@staticmethod
def _seed_worker(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    The seed is the worker id plus the worker's torch seed, wrapped into
    the 32-bit range that ``np.random.seed`` accepts. Wrapping with ``%``
    avoids converting an out-of-range integer to ``np.uint32``. Being a
    named function rather than a lambda, it can also be pickled when
    workers are started with spawn.

    Args:
        worker_id: The worker index passed in by the DataLoader.
    """
    worker_seed = torch.utils.data.get_worker_info().seed
    np.random.seed((worker_id + worker_seed) % 2**32)
"""
Get the batcher to be used based on the device and split.
"""
def get_batcher(self, device, split='training'):
    """Build the batcher named by ``self.model.cfg.batcher``.

    Args:
        device: Passed to ``ConcatBatcher`` when that batcher is selected.
        split: Accepted but not used by this method.

    Returns:
        A ``DefaultBatcher`` or ``ConcatBatcher`` instance, or ``None``
        when the configured name is neither of the two.
    """
    model_cfg = self.model.cfg
    requested = model_cfg.batcher

    if requested == 'DefaultBatcher':
        return DefaultBatcher()
    if requested == 'ConcatBatcher':
        return ConcatBatcher(device, model_cfg.name)
    # Unknown batcher names fall through without raising.
    return None
"""
Save logs from the training and send results to TensorBoard.
"""
def save_logs(self, writer, epoch):
    """Write the epoch's losses and overall metrics to the writer and log.

    The mean training and validation losses are written first, followed by
    the last accuracy pair and the last IoU pair returned by the train and
    validation metrics; the same values are then logged.

    Args:
        writer: Object whose ``add_scalar(tag, value, step)`` is called.
        epoch: The step index used for every scalar.
    """
    train_accs = self.metric_train.acc()
    val_accs = self.metric_val.acc()
    train_ious = self.metric_train.iou()
    val_ious = self.metric_val.iou()

    mean_train_loss = np.mean(self.losses)
    mean_val_loss = np.mean(self.valid_losses)

    # Pair train/validation values position by position; only the last
    # pair of each list is reported.
    acc_pairs = list(zip(train_accs, val_accs))
    iou_pairs = list(zip(train_ious, val_ious))

    writer.add_scalar('Training loss', mean_train_loss, epoch)
    writer.add_scalar('Validation loss', mean_val_loss, epoch)

    overall_train_acc, overall_val_acc = acc_pairs[-1]
    for key, val in (('Training accuracy', overall_train_acc),
                     ('Validation accuracy', overall_val_acc)):
        writer.add_scalar("{}/ Overall".format(key), val, epoch)

    overall_train_iou, overall_val_iou = iou_pairs[-1]
    for key, val in (('Training IoU', overall_train_iou),
                     ('Validation IoU', overall_val_iou)):
        writer.add_scalar("{}/ Overall".format(key), val, epoch)

    log.info(f"Loss train: {mean_train_loss:.3f} "
             f" eval: {mean_val_loss:.3f}")
    log.info(f"Mean acc train: {overall_train_acc:.3f} "
             f" eval: {overall_val_acc:.3f}")
    log.info(f"Mean IoU train: {overall_train_iou:.3f} "
             f" eval: {overall_val_iou:.3f}")
def load_ckpt(self, ckpt_path=None, is_resume=True):
    """Restore model (and, if present, optimizer/scheduler) state.

    The checkpoint directory under ``cfg.logs_dir`` is always created.
    Without an explicit path, the latest checkpoint in that directory is
    used when one exists and ``is_resume`` is true; otherwise nothing is
    loaded.

    Args:
        ckpt_path: Path of the checkpoint file, or ``None`` to look for
            the latest one.
        is_resume: Only consulted when ``ckpt_path`` is ``None``.

    Raises:
        FileNotFoundError: If the resolved checkpoint path does not exist.
    """
    train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint')
    make_dir(train_ckpt_dir)

    if ckpt_path is None:
        ckpt_path = latest_torch_ckpt(train_ckpt_dir)
        if ckpt_path is None or not is_resume:
            log.info('Initializing from scratch.')
            return
        log.info('ckpt_path not given. Restore from the latest ckpt')

    if not exists(ckpt_path):
        raise FileNotFoundError(f' ckpt {ckpt_path} not found')

    log.info(f'Loading checkpoint {ckpt_path}')
    ckpt = torch.load(ckpt_path, map_location=self.device)
    self.model.load_state_dict(ckpt['model_state_dict'])

    # Optimizer and scheduler are restored only when the checkpoint has
    # their state and the pipeline already owns the matching object.
    for state_key, owner_attr in (('optimizer_state_dict', 'optimizer'),
                                  ('scheduler_state_dict', 'scheduler')):
        if state_key in ckpt and hasattr(self, owner_attr):
            log.info(f'Loading checkpoint {state_key}')
            getattr(self, owner_attr).load_state_dict(ckpt[state_key])
"""
Save a checkpoint at the passed epoch.
"""
def save_ckpt(self, epoch):
    """Save model, optimizer and scheduler state for the given epoch.

    The file is written as ``ckpt_<epoch, 5 digits>.pth`` inside the
    ``checkpoint`` directory under ``cfg.logs_dir``, which is created if
    needed.

    Args:
        epoch: The epoch number stored in the file and used in its name.
    """
    ckpt_dir = join(self.cfg.logs_dir, 'checkpoint')
    make_dir(ckpt_dir)

    snapshot = {
        'epoch': epoch,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'scheduler_state_dict': self.scheduler.state_dict(),
    }
    torch.save(snapshot, join(ckpt_dir, f'ckpt_{epoch:05d}.pth'))
    log.info(f'Epoch {epoch:3d}: save ckpt to {ckpt_dir:s}')
"""
Save experiment configuration with Torch summary.
"""
def save_config(self, writer):
    """Write the experiment description and configuration as text.

    The readme and command line from ``self.cfg_tb`` are written as is;
    the dataset, model and pipeline entries are first converted with
    ``code2md`` using the 'json' language. Everything is written at
    step 0.

    Args:
        writer: Object whose ``add_text(tag, text, step)`` is called.
    """
    tb_cfg = self.cfg_tb

    writer.add_text("Description/CloudViewer-ML", tb_cfg['readme'], 0)
    writer.add_text("Description/Command line", tb_cfg['cmd_line'], 0)

    # Same tag order as before: dataset, model, pipeline.
    config_sections = (
        ('Configuration/Dataset', 'dataset'),
        ('Configuration/Model', 'model'),
        ('Configuration/Pipeline', 'pipeline'),
    )
    for tag, section in config_sections:
        writer.add_text(tag, code2md(tb_cfg[section], language='json'), 0)
# Hand the pipeline class to the PIPELINE registry together with the "torch"
# tag. NOTE(review): presumably this makes the class discoverable as the
# torch implementation of the pipeline — confirm against _register_module.
PIPELINE._register_module(SemanticSegmentation, "torch")