# Common libs
import time
import numpy as np
import pickle
import torch
import yaml
from os import listdir
from os.path import exists, join, isdir
from ..models.kpconv import batch_grid_subsampling, batch_neighbors
from torch.utils.data import Sampler, get_worker_info
class KPConvBatch:
    """Batched results for KPConv.

    Stacks the per-sample point lists of several dataset samples into one
    cloud, builds the per-layer neighbor / pooling / upsampling index inputs
    with ``segmentation_inputs`` and stores them as torch tensors.
    """

    def __init__(self, batches):
        """Initialize.

        Args:
            batches: A list of samples. Each sample is a dict whose ``'data'``
                entry provides the lists ``p_list``, ``f_list``, ``l_list``,
                ``p0_list``, ``s_list``, ``R_list``, ``r_inds_list``,
                ``r_mask_list``, ``val_labels_list`` and the shared ``cfg``.

        Raises:
            ValueError: If ``cfg.in_features_dim`` is smaller than 1.
        """
        self.neighborhood_limits = []
        p_list = []
        f_list = []
        l_list = []
        # NOTE(review): nothing appends to fi_list, so frame_inds is always an
        # empty int32 array -- confirm whether samples should provide it.
        fi_list = []
        p0_list = []
        s_list = []
        R_list = []
        r_inds_list = []
        r_mask_list = []
        val_labels_list = []

        self.cfg = batches[0]['data']['cfg']
        batch_limit = int(self.cfg.batch_limit)

        batch_n = 0
        for batch in batches:
            data = batch['data']

            # Stop stacking once the running point count exceeds the limit,
            # but always keep at least one sample so that the concatenations
            # below never receive an empty list.
            batch_n += sum(p.shape[0] for p in data['p_list'])
            if batch_n > batch_limit and p_list:
                break

            p_list += data['p_list']
            f_list += data['f_list']
            l_list += data['l_list']
            p0_list += data['p0_list']
            s_list += data['s_list']
            R_list += data['R_list']
            r_inds_list += data['r_inds_list']
            r_mask_list += data['r_mask_list']
            val_labels_list += data['val_labels_list']

        ###################
        # Concatenate batch
        ###################
        stacked_points = np.concatenate(p_list, axis=0)
        features = np.concatenate(f_list, axis=0)
        labels = np.concatenate(l_list, axis=0)
        frame_inds = np.array(fi_list, dtype=np.int32)
        frame_centers = np.stack(p0_list, axis=0)
        stack_lengths = np.array([pp.shape[0] for pp in p_list], dtype=np.int32)
        scales = np.array(s_list, dtype=np.float32)
        rots = np.stack(R_list, axis=0)

        # Input features (Use reflectance, input height or all coordinates)
        stacked_features = np.ones_like(stacked_points[:, :1], dtype=np.float32)
        in_dim = self.cfg.in_features_dim
        if in_dim == 1:
            pass
        elif in_dim == 2:
            # Use original height coordinate
            stacked_features = np.hstack((stacked_features, features[:, 2:3]))
        elif in_dim == 3:
            # Use height + reflectance
            assert features.shape[1] > 3, (
                "feat from dataset can not be None "
                "or try to set in_features_dim = 1, 2, 4")
            stacked_features = np.hstack((stacked_features, features[:, 2:4]))
        elif in_dim == 4:
            # Use all coordinates
            stacked_features = np.hstack((stacked_features, features[:, :3]))
        elif in_dim == 5:
            assert features.shape[1] >= 6, (
                "feat from dataset should have "
                "at least 3 dims, or try to set in_features_dim = 1, 2, 4")
            # Use color + height
            stacked_features = np.hstack((stacked_features, features[:, 2:6]))
        elif in_dim >= 6:
            assert features.shape[1] > 3, (
                "feat from dataset can not be None "
                "or try to set in_features_dim = 1, 2, 4")
            # Use all coordinates + reflectance
            stacked_features = np.hstack((stacked_features, features))
        else:
            # Every value below 1 (including 0) ends up here.
            raise ValueError('in_features_dim should be >= 1')

        #######################
        # Create network inputs
        #######################
        # Flat list: points, neighbors, pools, upsamples and stack lengths
        # (num_layers entries each), then features and labels.
        input_list = self.segmentation_inputs(stacked_points, stacked_features,
                                              labels.astype(np.int64),
                                              stack_lengths)
        num_layers = int(self.cfg.num_layers)

        def _group(k):
            # k-th group of num_layers per-layer arrays, as tensors.
            start = k * num_layers
            return [
                torch.from_numpy(arr)
                for arr in input_list[start:start + num_layers]
            ]

        self.points = _group(0)
        self.neighbors = _group(1)
        self.pools = _group(2)
        self.upsamples = _group(3)
        self.lengths = _group(4)

        tail = 5 * num_layers
        self.features = torch.from_numpy(input_list[tail])
        self.labels = torch.from_numpy(input_list[tail + 1])

        # Scale and rotation are kept for testing.
        self.scales = torch.from_numpy(scales)
        self.rots = torch.from_numpy(rots)
        self.frame_inds = torch.from_numpy(frame_inds)
        self.frame_centers = torch.from_numpy(frame_centers)
        # Reprojection data stays as plain Python lists.
        self.reproj_inds = r_inds_list
        self.reproj_masks = r_mask_list
        self.val_labels = val_labels_list

    def big_neighborhood_filter(self, neighbors, layer):
        """Filter neighborhoods with max number of neighbors.

        Crops the neighbor matrix columns to ``neighborhood_limits[layer]``
        when limits are set; otherwise returns ``neighbors`` unchanged.
        """
        if len(self.neighborhood_limits) > 0:
            return neighbors[:, :self.neighborhood_limits[layer]]
        return neighbors

    def segmentation_inputs(self, stacked_points, stacked_features, labels,
                            stack_lengths):
        """Build the multi-layer inputs by walking ``cfg.architecture``.

        Blocks are accumulated until one containing 'pool', 'strided',
        'global' or 'upsample' closes the layer. For each layer, neighbor,
        pooling and upsampling indices are computed with ``batch_neighbors``,
        and ``batch_grid_subsampling`` produces the next layer's points.
        Walking stops after a 'global' or 'upsample' block.

        Returns:
            list: per-layer points, then per-layer neighbors, pools,
            upsamples and stack lengths, followed by ``stacked_features`` and
            ``labels``.
        """
        # Starting radius of convolutions
        r_normal = self.cfg.first_subsampling_dl * self.cfg.conv_radius

        layer_blocks = []
        input_points = []
        input_neighbors = []
        input_pools = []
        input_upsamples = []
        input_stack_lengths = []

        for block in self.cfg.architecture:
            # Get all blocks of the layer
            if not ('pool' in block or 'strided' in block or
                    'global' in block or 'upsample' in block):
                layer_blocks.append(block)
                continue

            # Convolution neighbors indices
            if layer_blocks:
                # Convolutions are done in this layer: compute the neighbors
                # with the radius matching the layer's blocks.
                if any('deformable' in blck for blck in layer_blocks):
                    r = r_normal * self.cfg.deform_radius / self.cfg.conv_radius
                else:
                    r = r_normal
                conv_i = batch_neighbors(stacked_points, stacked_points,
                                         stack_lengths, stack_lengths, r)
            else:
                # This layer only performs pooling, no neighbors required
                conv_i = np.zeros((0, 1), dtype=np.int32)

            # Pooling neighbors indices
            if 'pool' in block or 'strided' in block:
                # New subsampling length
                dl = 2 * r_normal / self.cfg.conv_radius
                # Subsampled points
                pool_p, pool_b = batch_grid_subsampling(stacked_points,
                                                        stack_lengths,
                                                        sampleDl=dl)
                # Radius of pooled neighbors
                if 'deformable' in block:
                    r = r_normal * self.cfg.deform_radius / self.cfg.conv_radius
                else:
                    r = r_normal
                # Subsample indices
                pool_i = batch_neighbors(pool_p, stacked_points, pool_b,
                                         stack_lengths, r)
                # Upsample indices (with the radius of the next layer to keep
                # wanted density)
                up_i = batch_neighbors(stacked_points, pool_p, stack_lengths,
                                       pool_b, 2 * r)
            else:
                # No pooling at the end of this layer, no indices required
                pool_i = np.zeros((0, 1), dtype=np.int32)
                pool_p = np.zeros((0, 3), dtype=np.float32)
                pool_b = np.zeros((0,), dtype=np.int32)
                up_i = np.zeros((0, 1), dtype=np.int32)

            # Reduce size of neighbors matrices by eliminating furthest point
            layer_idx = len(input_points)
            conv_i = self.big_neighborhood_filter(conv_i, layer_idx)
            pool_i = self.big_neighborhood_filter(pool_i, layer_idx)
            if up_i.shape[0] > 0:
                up_i = self.big_neighborhood_filter(up_i, layer_idx + 1)

            # Updating input lists
            input_points.append(stacked_points)
            input_neighbors.append(conv_i.astype(np.int64))
            input_pools.append(pool_i.astype(np.int64))
            input_upsamples.append(up_i.astype(np.int64))
            input_stack_lengths.append(stack_lengths)

            # New points for next layer
            stacked_points = pool_p
            stack_lengths = pool_b

            # Update radius and reset blocks
            r_normal *= 2
            layer_blocks = []

            # Stop when meeting a global pooling or upsampling
            if 'global' in block or 'upsample' in block:
                break

        li = (input_points + input_neighbors + input_pools + input_upsamples +
              input_stack_lengths)
        li += [stacked_features, labels]
        return li

    def pin_memory(self):
        """Manual pinning of the memory of all tensor attributes."""
        self.points = [in_tensor.pin_memory() for in_tensor in self.points]
        self.neighbors = [
            in_tensor.pin_memory() for in_tensor in self.neighbors
        ]
        self.pools = [in_tensor.pin_memory() for in_tensor in self.pools]
        self.upsamples = [
            in_tensor.pin_memory() for in_tensor in self.upsamples
        ]
        self.lengths = [in_tensor.pin_memory() for in_tensor in self.lengths]
        self.features = self.features.pin_memory()
        self.labels = self.labels.pin_memory()
        self.scales = self.scales.pin_memory()
        self.rots = self.rots.pin_memory()
        self.frame_inds = self.frame_inds.pin_memory()
        self.frame_centers = self.frame_centers.pin_memory()
        return self

    def to(self, device):
        """Move all tensor attributes to ``device`` and return ``self``."""
        self.points = [in_tensor.to(device) for in_tensor in self.points]
        self.neighbors = [in_tensor.to(device) for in_tensor in self.neighbors]
        self.pools = [in_tensor.to(device) for in_tensor in self.pools]
        self.upsamples = [in_tensor.to(device) for in_tensor in self.upsamples]
        self.lengths = [in_tensor.to(device) for in_tensor in self.lengths]
        self.features = self.features.to(device)
        self.labels = self.labels.to(device)
        self.scales = self.scales.to(device)
        self.rots = self.rots.to(device)
        self.frame_inds = self.frame_inds.to(device)
        self.frame_centers = self.frame_centers.to(device)
        return self

    def unstack_points(self, layer=None):
        """Unstack the points."""
        return self.unstack_elements('points', layer)

    def unstack_neighbors(self, layer=None):
        """Unstack the neighbors indices."""
        return self.unstack_elements('neighbors', layer)

    def unstack_pools(self, layer=None):
        """Unstack the pooling indices."""
        return self.unstack_elements('pools', layer)

    def unstack_elements(self, element_name, layer=None, to_numpy=True):
        """Split stacked elements of the batch per batch item.

        Args:
            element_name: 'points', 'neighbors' or 'pools'.
            layer: Layer to return; if None, all layers are returned.
            to_numpy: Convert each element to a numpy array (via CPU).

        Returns:
            For a given ``layer``, a list with one entry per batch item;
            otherwise a list (per layer) of such lists. Neighbor/pool indices
            >= the number of points of the layer are set to -1 and valid
            indices are made relative to their batch item. The stored
            tensors are not modified.

        Raises:
            ValueError: If ``element_name`` is unknown.
        """
        if element_name == 'points':
            elements = self.points
        elif element_name == 'neighbors':
            elements = self.neighbors
        elif element_name == 'pools':
            elements = self.pools[:-1]
        else:
            raise ValueError('Unknown element name: {:s}'.format(element_name))

        all_p_list = []
        for layer_i, layer_elems in enumerate(elements):
            if layer is not None and layer != layer_i:
                continue

            # Pool rows correspond to the points of the next layer.
            if element_name == 'pools':
                lengths = self.lengths[layer_i + 1]
            else:
                lengths = self.lengths[layer_i]
            num_points = self.points[layer_i].shape[0]

            i0 = 0
            p_list = []
            for b_i, length in enumerate(lengths):
                length = int(length)
                elem = layer_elems[i0:i0 + length]
                if element_name in ('neighbors', 'pools'):
                    # Clone: slicing gives a view, and the index rewriting
                    # below must not corrupt the batch's stored tensors.
                    elem = elem.clone()
                    elem[elem >= num_points] = -1
                    if element_name == 'neighbors':
                        elem[elem >= 0] -= i0
                    else:
                        elem[elem >= 0] -= torch.sum(
                            self.lengths[layer_i][:b_i])
                i0 += length

                p_list.append(elem.cpu().numpy() if to_numpy else elem)

            if layer == layer_i:
                return p_list
            all_p_list.append(p_list)

        return all_p_list
class SparseConvUnetBatch:
    """Batched results for SparseConvUnet.

    Keeps each sample's point, feature and label entries in lists (no
    concatenation) together with the number of points per sample.
    """

    def __init__(self, batches):
        """Initialize.

        Args:
            batches: A list of samples; each sample is a dict whose ``'data'``
                entry provides ``'point'``, ``'feat'`` and ``'label'``.
        """
        self.point = []
        self.feat = []
        self.label = []
        self.batch_lengths = []
        for batch in batches:
            data = batch['data']
            self.point.append(data['point'])
            self.feat.append(data['feat'])
            self.label.append(data['label'])
            self.batch_lengths.append(data['point'].shape[0])

    def pin_memory(self):
        """Pin point, feature and label entries and return ``self``.

        Assumes the entries are torch tensors -- TODO confirm with dataset.
        """
        self.point = [pc.pin_memory() for pc in self.point]
        self.feat = [feat.pin_memory() for feat in self.feat]
        self.label = [label.pin_memory() for label in self.label]
        return self

    def to(self, device):
        """Move point, feature and label entries to ``device``.

        Returns ``self`` so the call can be chained, like ``pin_memory``.
        """
        self.point = [pc.to(device) for pc in self.point]
        self.feat = [feat.to(device) for feat in self.feat]
        self.label = [label.to(device) for label in self.label]
        return self
class ObjectDetectBatch:
    """Batched results for object detection models.

    Keeps per-sample lists of point tensors, labels, bounding boxes, bbox
    objects, calibration data and attributes, all aligned by index.
    """

    def __init__(self, batches):
        """Initialize.

        Args:
            batches: A list of samples; each is a dict with ``'attr'`` (which
                has a ``'split'`` entry) and ``'data'`` (with ``'point'`` and
                optionally ``'labels'``, ``'bboxes'``, ``'bbox_objs'``,
                ``'calib'``). Samples from a split not containing 'test' that
                have no bounding boxes are skipped.
        """
        self.point = []
        self.labels = []
        self.bboxes = []
        self.bbox_objs = []
        self.calib = []
        self.attr = []

        for batch in batches:
            data = batch['data']
            attr = batch['attr']
            # Skip training batch with no bounding box.
            if 'test' not in attr['split'] and len(data['bboxes']) == 0:
                continue

            # Append attr only for kept samples so every list stays aligned.
            self.attr.append(attr)
            self.point.append(torch.tensor(data['point'], dtype=torch.float32))
            self.labels.append(
                torch.tensor(data['labels'], dtype=torch.int64)
                if 'labels' in data else None)
            self.bboxes.append(
                torch.tensor(data['bboxes'], dtype=torch.float32)
                if 'bboxes' in data else None)
            self.bbox_objs.append(data.get('bbox_objs'))
            self.calib.append(data.get('calib'))

    def pin_memory(self):
        """Pin point, label and bbox tensors (None entries are kept)."""
        self.point = [p.pin_memory() for p in self.point]
        self.labels = [
            None if t is None else t.pin_memory() for t in self.labels
        ]
        self.bboxes = [
            None if t is None else t.pin_memory() for t in self.bboxes
        ]
        return self

    def to(self, device):
        """Move point, label and bbox tensors to ``device``; return ``self``."""
        self.point = [p.to(device) for p in self.point]
        self.labels = [
            None if t is None else t.to(device) for t in self.labels
        ]
        self.bboxes = [
            None if t is None else t.to(device) for t in self.bboxes
        ]
        return self
class ConcatBatcher(object):
    """Collate helper that builds model-specific batch objects.

    Supports KPConv/KPFCNN, SparseConvUnet and PointPillars/PointRCNN.
    """

    def __init__(self, device, model='KPConv'):
        """Initialize.

        Args:
            device: torch device passed to the KPConv batch's ``.to()``
                (e.g. 'cpu' or a cuda device).
            model: Name of the model the batches are built for; selects the
                batch class used in ``collate_fn``.
        """
        super().__init__()
        self.device = device
        self.model = model

    def collate_fn(self, batches):
        """Collate function called by original PyTorch dataloader.

        Args:
            batches: a batch of data

        Returns:
            For KPConv/KPFCNN and SparseConvUnet, a dict with ``'data'``
            (the batch object) and ``'attr'``; for PointPillars/PointRCNN,
            the ``ObjectDetectBatch`` itself.

        Raises:
            Exception: If no collate logic exists for ``self.model``.
        """
        if self.model in ("KPConv", "KPFCNN"):
            batching_result = KPConvBatch(batches)
            batching_result.to(self.device)
            return {'data': batching_result, 'attr': []}
        if self.model == "SparseConvUnet":
            return {'data': SparseConvUnetBatch(batches), 'attr': {}}
        if self.model in ("PointPillars", "PointRCNN"):
            return ObjectDetectBatch(batches)
        raise Exception(
            f"Please define collate_fn for {self.model}, or use Default Batcher"
        )