Source code for ml3d.torch.models.point_transformer

import numpy as np
import torch
import torch.nn as nn
import torch.utils.dlpack
import cloudViewer.core as o3c

from sklearn.neighbors import KDTree
from cloudViewer.ml.torch.ops import knn_search

from .base_model import BaseModel
from ...utils import MODEL
from ..modules.losses import filter_valid_label
from ...datasets.augment import SemsegAugmentation
from ...datasets.utils import DataProcessing
from ..utils.pointnet.pointnet2_utils import furthest_point_sample_v2


class PointTransformer(BaseModel):
    """Semantic Segmentation model.

    Based on the PointTransformer architecture
    https://arxiv.org/pdf/2012.09164.pdf

    Uses an Encoder-Decoder architecture with Transformer layers.

    Attributes:
        name: Name of model. Default to "PointTransformer".
        blocks: Number of layers per encoder stage (5 entries are used).
            Defaults to ``[2, 2, 2, 2, 2]``.
        in_channels: Number of features (default 6).
        num_classes: Number of classes.
        voxel_size: Voxel length for subsampling.
        max_voxels: Maximum number of voxels.
        batcher: Batching method for dataloader.
        augment: Dictionary for augmentation.
    """

    def __init__(self,
                 name="PointTransformer",
                 blocks=None,
                 in_channels=6,
                 num_classes=13,
                 voxel_size=0.04,
                 max_voxels=80000,
                 batcher='ConcatBatcher',
                 augment=None,
                 **kwargs):
        """Build the five encoder stages, five decoder stages and the
        classification head.

        See the class docstring for the meaning of the arguments. Extra
        keyword arguments are forwarded to `BaseModel`.
        """
        # `None` stands in for the default so that the default list is not
        # shared between instances (mutable default argument).
        if blocks is None:
            blocks = [2, 2, 2, 2, 2]

        super().__init__(name=name,
                         blocks=blocks,
                         in_channels=in_channels,
                         num_classes=num_classes,
                         voxel_size=voxel_size,
                         max_voxels=max_voxels,
                         batcher=batcher,
                         augment=augment,
                         **kwargs)
        cfg = self.cfg
        self.in_channels = in_channels
        self.augmenter = SemsegAugmentation(cfg.augment)

        # `self.in_planes` tracks the channel count of the most recently
        # built stage; `_make_enc` / `_make_dec` read and update it, so the
        # construction order below matters.
        self.in_planes = in_channels
        planes = [32, 64, 128, 256, 512]
        share_planes = 8
        stride = [1, 4, 4, 4, 4]
        nsample = [8, 16, 16, 16, 16]
        block = Bottleneck

        # Point counts after each stage: N/1, N/4, N/16, N/64, N/256.
        self.encoders = nn.ModuleList()
        for i in range(5):
            self.encoders.append(
                self._make_enc(block,
                               planes[i],
                               blocks[i],
                               share_planes,
                               stride=stride[i],
                               nsample=nsample[i]))

        # Decoders are stored deepest-first: decoders[0] is the head that
        # works on the output of encoders[4].
        self.decoders = nn.ModuleList()
        for i in range(4, -1, -1):
            self.decoders.append(
                self._make_dec(block,
                               planes[i],
                               2,
                               share_planes,
                               nsample=nsample[i],
                               is_head=(i == 4)))

        self.cls = nn.Sequential(nn.Linear(planes[0], planes[0]),
                                 nn.BatchNorm1d(planes[0]),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(planes[0], num_classes))

    def _make_enc(self,
                  block,
                  planes,
                  blocks,
                  share_planes=8,
                  stride=1,
                  nsample=16):
        """Private method to create encoder.

        Args:
            block: Bottleneck block consisting transformer layers.
            planes: Feature dimension of this stage.
            blocks: Number of layers in the stage. The first one is the
                `TransitionDown`, followed by ``blocks - 1`` `block` layers.
            share_planes: Number of common planes for transformer.
            stride: stride for pooling.
            nsample: number of neighbour to sample.

        Returns:
            Returns encoder object.
        """
        out_planes = planes * block.expansion
        layers = [TransitionDown(self.in_planes, out_planes, stride, nsample)]
        self.in_planes = out_planes
        for _ in range(1, blocks):
            layers.append(
                block(self.in_planes,
                      self.in_planes,
                      share_planes,
                      nsample=nsample))
        return nn.Sequential(*layers)

    def _make_dec(self,
                  block,
                  planes,
                  blocks,
                  share_planes=8,
                  nsample=16,
                  is_head=False):
        """Private method to create decoder.

        Args:
            block: Bottleneck block consisting transformer layers.
            planes: Feature dimension of this stage.
            blocks: Number of layers in the stage. The first one is the
                `TransitionUp`, followed by ``blocks - 1`` `block` layers.
            share_planes: Number of common planes for transformer.
            nsample: number of neighbour to sample.
            is_head: bool type for head layer.

        Returns:
            Returns decoder object.
        """
        out_planes = planes * block.expansion
        layers = [
            TransitionUp(self.in_planes, None if is_head else out_planes)
        ]
        self.in_planes = out_planes
        for _ in range(1, blocks):
            layers.append(
                block(self.in_planes,
                      self.in_planes,
                      share_planes,
                      nsample=nsample))
        return nn.Sequential(*layers)

    def forward(self, batch):
        """Forward pass for the model.

        Args:
            batch: Batch object exposing the attributes
                point: Input pointcloud (n, 3)
                feat: Input features (n, c)
                row_splits: row splits for batches (b+1,)

        Returns:
            Per-point output of the classification head `self.cls`, i.e. the
            raw scores of its last Linear layer (no softmax is applied here).
        """
        points = [batch.point]  # (n, 3)
        feats = [batch.feat]  # (n, c)
        row_splits = [batch.row_splits]  # (b+1,)

        # With 3 input channels only the coordinates are used as features,
        # otherwise the coordinates are prepended to the given features.
        if self.in_channels == 3:
            feats[0] = points[0]
        else:
            feats[0] = torch.cat((points[0], feats[0]), 1)

        # Index i + 1 of the lists holds the output of encoder i.
        for i in range(5):
            p, f, r = self.encoders[i]([points[i], feats[i], row_splits[i]])
            points.append(p)
            feats.append(f)
            row_splits.append(r)

        for i in range(4, -1, -1):
            decoder = self.decoders[4 - i]
            current = [points[i + 1], feats[i + 1], row_splits[i + 1]]
            if i == 4:
                # Head: no coarser level to interpolate from.
                upsampled = decoder[0](current)
            else:
                coarser = [points[i + 2], feats[i + 2], row_splits[i + 2]]
                upsampled = decoder[0](current, coarser)
            # decoder[1:] are the Bottleneck layers; they return
            # [point, feat, row_splits], of which only the features are kept.
            feats[i + 1] = decoder[1:](
                [points[i + 1], upsampled, row_splits[i + 1]])[1]

        return self.cls(feats[1])

    def preprocess(self, data, attr):
        """Data preprocessing function.

        This function is called before training to preprocess the data from a
        dataset. It consists of subsampling pointcloud with voxelization.

        Args:
            data: A sample from the dataset. Reads the keys 'point', and
                optionally 'label' and 'feat'.
            attr: The corresponding attributes. Reads the key 'split'.

        Returns:
            Returns the preprocessed data: a new dict with the keys 'point',
            'feat', 'label', 'search_tree' and, for the test split,
            'proj_inds'.
        """
        cfg = self.cfg

        points = np.array(data['point'], dtype=np.float32)

        if data.get('label') is None:
            labels = np.zeros((points.shape[0],), dtype=np.int32)
        else:
            labels = np.array(data['label'], dtype=np.int32).reshape((-1,))

        if data.get('feat') is None:
            feat = None
        else:
            feat = np.array(data['feat'], dtype=np.float32)

        data = dict()

        if cfg.voxel_size:
            # Shift the cloud so that its minimum corner is the origin
            # before grid subsampling.
            points -= np.min(points, 0)

            if feat is None:
                sub_points, sub_labels = DataProcessing.grid_subsampling(
                    points, labels=labels, grid_size=cfg.voxel_size)
                sub_feat = None
            else:
                sub_points, sub_feat, sub_labels = \
                    DataProcessing.grid_subsampling(
                        points,
                        features=feat,
                        labels=labels,
                        grid_size=cfg.voxel_size)
        else:
            sub_points, sub_feat, sub_labels = points, feat, labels

        search_tree = KDTree(sub_points)

        data['point'] = sub_points
        data['feat'] = sub_feat
        data['label'] = sub_labels
        data['search_tree'] = search_tree

        if attr['split'] in ["test", "testing"]:
            # For every original point, the index of its nearest subsampled
            # point. Used to project predictions back in `inference_end`.
            proj_inds = np.squeeze(
                search_tree.query(points, return_distance=False))
            data['proj_inds'] = proj_inds.astype(np.int32)

        return data

    def transform(self, data, attr):
        """Transform function for the point cloud and features.

        This function is called after preprocess method. It consists
        of calling augmentation and normalizing the pointcloud.

        Args:
            data: A sample from the dataset (output of `preprocess`). The
                dict is updated in place and returned.
            attr: The corresponding attributes. Reads the key 'split'.

        Returns:
            Returns dictionary data with keys (point, feat, label) converted
            to torch tensors.
        """
        cfg = self.cfg
        points = data['point']
        feat = data['feat']
        labels = data['label']

        if attr['split'] in ['training', 'train']:
            points, feat, labels = self.augmenter.augment(
                points, feat, labels, self.cfg.get('augment', None))

        if attr['split'] not in ['test', 'testing']:
            # Compare against the (possibly augmented) labels that are
            # actually cropped below, not against the untouched input.
            if cfg.max_voxels and labels.shape[0] > cfg.max_voxels:
                # Random crop centre while training, middle point otherwise.
                if 'train' in attr['split']:
                    init_idx = np.random.randint(labels.shape[0])
                else:
                    init_idx = labels.shape[0] // 2
                # Keep the `max_voxels` points closest to the crop centre.
                sq_dist = np.sum(np.square(points - points[init_idx]), 1)
                crop_idx = np.argsort(sq_dist)[:cfg.max_voxels]
                if feat is not None:
                    feat = feat[crop_idx]
                points, labels = points[crop_idx], labels[crop_idx]

        # Centre the cloud on the middle of its bounding box.
        points_min, points_max = np.min(points, 0), np.max(points, 0)
        points -= (points_min + points_max) / 2.0

        data['point'] = torch.from_numpy(points).to(torch.float32)

        if feat is not None:
            data['feat'] = torch.from_numpy(feat).to(torch.float32) / 255.0

        data['label'] = torch.from_numpy(labels).to(torch.int64)

        return data

    def update_probs(self, inputs, results, test_probs):
        """Convert raw model output to per-point class probabilities.

        Args:
            inputs: Unused.
            results: Output of the model; reshaped to (-1, num_classes).
            test_probs: Unused.

        Returns:
            Numpy array of softmax probabilities of shape (-1, num_classes).
        """
        result = results.reshape(-1, self.cfg.num_classes)
        probs = torch.nn.functional.softmax(result, dim=-1).cpu().data.numpy()

        # NOTE(review): `trans_point_sampler` is not defined in this class;
        # presumably it is attached by the test pipeline -- confirm.
        self.trans_point_sampler(patchwise=False)

        return probs

    def inference_begin(self, data):
        """Prepare a single sample for inference.

        Runs `preprocess` and `transform` with the 'test' split and stores
        the result in `self.inference_input`.

        Args:
            data: A sample in the format accepted by `preprocess`.
        """
        data = self.preprocess(data, {'split': 'test'})
        data = self.transform(data, {'split': 'test'})

        self.inference_input = data

    def inference_preprocess(self):
        """Return the input prepared by `inference_begin`."""
        return self.inference_input

    def inference_end(self, inputs, results):
        """Turn the model output into labels and scores for the sample
        stored by `inference_begin`.

        Args:
            inputs: Unused.
            results: Output of the model; reshaped to (-1, num_classes).

        Returns:
            Dict with 'predict_labels' (argmax class per point) and
            'predict_scores' (softmax probabilities), both indexed through
            the stored 'proj_inds'.
        """
        results = torch.reshape(results, (-1, self.cfg.num_classes))
        probs = torch.nn.functional.softmax(results, dim=-1)
        probs = probs.cpu().data.numpy()

        # Project the predictions of the subsampled cloud back onto the
        # points that were given to `preprocess`.
        reproj_inds = self.inference_input['proj_inds']
        probs = probs[reproj_inds]

        pred_l = np.argmax(probs, 1)

        return {'predict_labels': pred_l, 'predict_scores': probs}

    def get_loss(self, sem_seg_loss, results, inputs, device):
        """Calculate the loss on output of the model.

        Args:
            sem_seg_loss: Object of type `SemSegLoss`.
            results: Output of the model.
            inputs: Input of the model. The labels are read from
                ``inputs['data'].label``.
            device: device(cpu or cuda).

        Returns:
            Returns loss, labels and scores.
        """
        cfg = self.cfg
        labels = inputs['data'].label

        scores, labels = filter_valid_label(results, labels, cfg.num_classes,
                                            cfg.ignored_label_inds, device)

        loss = sem_seg_loss.weighted_CrossEntropyLoss(scores, labels)

        return loss, labels, scores

    def get_optimizer(self, cfg_pipeline):
        """Create the optimizer and learning-rate scheduler.

        Args:
            cfg_pipeline: Pipeline config. `optimizer` is expanded as keyword
                arguments of `torch.optim.SGD`; `max_epoch` fixes the
                scheduler milestones.

        Returns:
            Tuple of (SGD optimizer, MultiStepLR scheduler). The learning
            rate is multiplied by 0.1 at 60% and 80% of `max_epoch`.
        """
        optimizer = torch.optim.SGD(self.parameters(),
                                    **cfg_pipeline.optimizer)
        milestones = [
            int(cfg_pipeline.max_epoch * 0.6),
            int(cfg_pipeline.max_epoch * 0.8)
        ]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=0.1)

        return optimizer, scheduler
# Make the model available through the MODEL registry for the torch framework.
MODEL._register_module(PointTransformer, 'torch')


def _apply_pointwise(layers, x):
    """Apply `layers` one after another to a (n, nsample, c) tensor.

    `nn.BatchNorm1d` normalises over dim 1, so the tensor is transposed to
    (n, c, nsample) around every batch-norm layer and transposed back
    afterwards. All other layers act on the last dimension directly.
    """
    for layer in layers:
        if isinstance(layer, nn.BatchNorm1d):
            x = layer(x.transpose(1, 2).contiguous()).transpose(
                1, 2).contiguous()
        else:
            x = layer(x)
    return x


class Transformer(nn.Module):
    """Transformer layer of the model, uses self attention."""

    def __init__(self, in_planes, out_planes, share_planes=8, nsample=16):
        """Constructor for Transformer Layer.

        Args:
            in_planes (int): Number of input planes.
            out_planes (int): Number of output planes.
            share_planes (int): Number of shared planes.
            nsample (int): Number of neighbours.
        """
        super().__init__()

        self.mid_planes = mid_planes = out_planes // 1
        self.out_planes = out_planes
        self.share_planes = share_planes
        self.nsample = nsample

        self.linear_q = nn.Linear(in_planes, mid_planes)
        self.linear_k = nn.Linear(in_planes, mid_planes)
        self.linear_v = nn.Linear(in_planes, out_planes)

        # Encodes relative neighbour positions (3 -> out_planes).
        self.linear_p = nn.Sequential(
            nn.Linear(3, 3),
            nn.BatchNorm1d(3),
            nn.ReLU(inplace=True),
            nn.Linear(3, out_planes),
        )

        # Maps the attention input to one weight per group of channels.
        self.linear_w = nn.Sequential(
            nn.BatchNorm1d(mid_planes),
            nn.ReLU(inplace=True),
            nn.Linear(mid_planes, mid_planes // share_planes),
            nn.BatchNorm1d(mid_planes // share_planes),
            nn.ReLU(inplace=True),
            nn.Linear(out_planes // share_planes,
                      out_planes // share_planes),  # Verify
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, pxo):
        """Forward call for Transformer.

        Args:
            pxo: [point, feat, row_splits] with shapes
                (n, 3), (n, c) and (b+1,)

        Returns:
            Transformer features.
        """
        point, feat, row_splits = pxo  # (n, 3), (n, c), (b+1)

        feat_q = self.linear_q(feat)  # (n, c)
        feat_k = self.linear_k(feat)  # (n, c)
        feat_v = self.linear_v(feat)  # (n, c)

        feat_k = queryandgroup(self.nsample,
                               point,
                               point,
                               feat_k,
                               None,
                               row_splits,
                               row_splits,
                               use_xyz=True)  # (n, nsample, 3+c)
        feat_v = queryandgroup(self.nsample,
                               point,
                               point,
                               feat_v,
                               None,
                               row_splits,
                               row_splits,
                               use_xyz=False)  # (n, nsample, c)

        # The first three channels are the neighbour offsets.
        point_r, feat_k = feat_k[:, :, 0:3], feat_k[:, :, 3:]

        point_r = _apply_pointwise(self.linear_p, point_r)  # (n, nsample, c)

        w = feat_k - feat_q.unsqueeze(1) + point_r.view(
            point_r.shape[0], point_r.shape[1],
            self.out_planes // self.mid_planes,
            self.mid_planes).sum(2)  # (n, nsample, c)

        w = _apply_pointwise(self.linear_w, w)

        # Normalise the weights over the neighbours (dim 1).
        w = self.softmax(w)  # (n, nsample, c // share_planes)
        n, nsample, c = feat_v.shape
        s = self.share_planes
        # Each weight is shared by `s` channel groups of the values.
        feat = ((feat_v + point_r).view(n, nsample, s, c // s) *
                w.unsqueeze(2)).sum(1).view(n, c)

        return feat


class TransitionDown(nn.Module):
    """TransitionDown layer for PointTransformer.

    Subsamples points and increase receptive field.
    """

    def __init__(self, in_planes, out_planes, stride=1, nsample=16):
        """Constructor for TransitionDown Layer.

        Args:
            in_planes (int): Number of input planes.
            out_planes (int): Number of output planes.
            stride (int): subsampling factor.
            nsample (int): Number of neighbours.
        """
        super().__init__()

        self.stride, self.nsample = stride, nsample
        if stride != 1:
            # Grouped features carry 3 extra channels of relative positions.
            self.linear = nn.Linear(3 + in_planes, out_planes, bias=False)
            self.pool = nn.MaxPool1d(nsample)
        else:
            self.linear = nn.Linear(in_planes, out_planes, bias=False)

        self.bn = nn.BatchNorm1d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, pxo):
        """Forward call for TransitionDown

        Args:
            pxo: [point, feat, row_splits] with shapes
                (n, 3), (n, c) and (b+1,)

        Returns:
            List of point, feat, row_splits.
        """
        point, feat, row_splits = pxo  # (n, 3), (n, c), (b+1)

        if self.stride == 1:
            feat = self.relu(self.bn(self.linear(feat)))  # (n, c)
            return [point, feat, row_splits]

        # Every batch element keeps floor(size / stride) points. A single
        # `tolist()` avoids one tensor-to-host read per boundary.
        splits = row_splits.tolist()
        new_splits = [0]
        for start, end in zip(splits[:-1], splits[1:]):
            new_splits.append(new_splits[-1] + (end - start) // self.stride)
        new_row_splits = torch.LongTensor(new_splits).to(row_splits.device)

        idx = furthest_point_sample_v2(point, row_splits,
                                       new_row_splits)  # (m)
        new_point = point[idx.long(), :]  # (m, 3)
        feat = queryandgroup(self.nsample,
                             point,
                             new_point,
                             feat,
                             None,
                             row_splits,
                             new_row_splits,
                             use_xyz=True)  # (m, nsample, 3+c)
        feat = self.relu(
            self.bn(self.linear(feat).transpose(
                1, 2).contiguous()))  # (m, c, nsample)
        # Max over the neighbourhood of every sampled point.
        feat = self.pool(feat).squeeze(-1)  # (m, c)

        return [new_point, feat, new_row_splits]


class TransitionUp(nn.Module):
    """Decoder layer for PointTransformer.

    Interpolate points based on corresponding encoder layer.
    """

    def __init__(self, in_planes, out_planes=None):
        """Constructor for TransitionUp Layer.

        Args:
            in_planes (int): Number of input planes.
            out_planes (int): Number of output planes. `None` builds the
                head variant, which has no coarser level to interpolate from.
        """
        super().__init__()

        if out_planes is None:
            self.linear1 = nn.Sequential(nn.Linear(2 * in_planes, in_planes),
                                         nn.BatchNorm1d(in_planes),
                                         nn.ReLU(inplace=True))
            self.linear2 = nn.Sequential(nn.Linear(in_planes, in_planes),
                                         nn.ReLU(inplace=True))
        else:
            self.linear1 = nn.Sequential(nn.Linear(out_planes, out_planes),
                                         nn.BatchNorm1d(out_planes),
                                         nn.ReLU(inplace=True))
            self.linear2 = nn.Sequential(nn.Linear(in_planes, out_planes),
                                         nn.BatchNorm1d(out_planes),
                                         nn.ReLU(inplace=True))

    def forward(self, pxo1, pxo2=None):
        """Forward call for TransitionUp

        Args:
            pxo1: [point, feat, row_splits] with shapes
                (n, 3), (n, c) and (b+1,)
            pxo2: [point, feat, row_splits] with shapes
                (n, 3), (n, c) and (b+1,)

        Returns:
            Interpolated features.
        """
        if pxo2 is None:
            _, feat, row_splits = pxo1  # (n, 3), (n, c), (b+1)
            splits = row_splits.tolist()
            pieces = []
            for start, end in zip(splits[:-1], splits[1:]):
                count = end - start
                feat_b = feat[start:end, :]
                # Mean feature of the batch element, transformed and
                # appended to every one of its points.
                pooled = self.linear2(feat_b.sum(0, True) / count)
                pieces.append(torch.cat((feat_b, pooled.repeat(count, 1)), 1))
            feat = self.linear1(torch.cat(pieces, 0))
        else:
            point_1, feat_1, row_splits_1 = pxo1
            point_2, feat_2, row_splits_2 = pxo2
            # Skip features plus features interpolated from the coarser
            # level (pxo2) onto the points of this level (pxo1).
            feat = self.linear1(feat_1) + interpolation(
                point_2, point_1, self.linear2(feat_2), row_splits_2,
                row_splits_1)

        return feat


class Bottleneck(nn.Module):
    """Bottleneck layer for PointTransformer.

    Block of layers using Transformer layer as building block.
    """
    expansion = 1

    def __init__(self, in_planes, planes, share_planes=8, nsample=16):
        """Constructor for Bottleneck Layer.

        Args:
            in_planes (int): Number of input planes.
            planes (int): Number of output planes.
            share_planes (int): Number of shared planes.
            nsample (int): Number of neighbours.
        """
        super().__init__()

        self.linear1 = nn.Linear(in_planes, planes, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.transformer2 = Transformer(planes, planes, share_planes, nsample)
        self.bn2 = nn.BatchNorm1d(planes)
        self.linear3 = nn.Linear(planes, planes * self.expansion, bias=False)
        self.bn3 = nn.BatchNorm1d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, pxo):
        """Forward call for Bottleneck

        Args:
            pxo: [point, feat, row_splits] with shapes
                (n, 3), (n, c) and (b+1,)

        Returns:
            List of point, feat, row_splits.
        """
        point, feat, row_splits = pxo  # (n, 3), (n, c), (b+1)
        identity = feat
        feat = self.relu(self.bn1(self.linear1(feat)))
        feat = self.relu(self.bn2(self.transformer2([point, feat,
                                                     row_splits])))
        feat = self.bn3(self.linear3(feat))
        # Residual connection around the three layers above.
        feat += identity
        feat = self.relu(feat)

        return [point, feat, row_splits]


def queryandgroup(nsample,
                  points,
                  queries,
                  feat,
                  idx,
                  points_row_splits,
                  queries_row_splits,
                  use_xyz=True):
    """Find nearest neighbours and returns grouped features.

    Args:
        nsample: Number of neighbours (k).
        points: Input pointcloud (n, 3).
        queries: Queries for Knn (m, 3). `None` uses `points` as queries.
        feat: features (n, c).
        idx: Optional knn index list. Computed with `knn_batch` when `None`.
        points_row_splits: row_splits for batching points.
        queries_row_splits: row_splits for batching queries.
        use_xyz: Whether to return xyz concatenated with features.

    Returns:
        Returns grouped features (m, nsample, c) or (m, nsample, 3+c).

    Raises:
        ValueError: If `points`, `queries` or `feat` is not contiguous.
    """
    # Resolve the default before `queries` is dereferenced below.
    if queries is None:
        queries = points

    if not (points.is_contiguous() and queries.is_contiguous() and
            feat.is_contiguous()):
        raise ValueError("queryandgroup (points/queries/feat not contiguous)")

    if idx is None:
        idx = knn_batch(points,
                        queries,
                        k=nsample,
                        points_row_splits=points_row_splits,
                        queries_row_splits=queries_row_splits,
                        return_distances=False)

    m, c = queries.shape[0], feat.shape[1]
    flat_idx = idx.reshape(-1).long()

    grouped_feat = feat[flat_idx, :].view(m, nsample, c)  # (m, nsample, c)
    if not use_xyz:
        return grouped_feat

    grouped_xyz = points[flat_idx, :].view(m, nsample, 3)  # (m, nsample, 3)
    # Neighbour positions relative to their query point.
    grouped_xyz -= queries.unsqueeze(1)

    return torch.cat((grouped_xyz, grouped_feat), -1)  # (m, nsample, 3+c)


def knn_batch(points,
              queries,
              k,
              points_row_splits,
              queries_row_splits,
              return_distances=True):
    """K nearest neighbour with batch support.

    The search itself is run on CPU copies of `points` and `queries`; the
    results are moved back to the device `points` was given on.

    Args:
        points: Input pointcloud.
        queries: Queries for Knn.
        k: Number of neighbours.
        points_row_splits: row_splits for batching points.
        queries_row_splits: row_splits for batching queries.
        return_distances: Whether to return distance with neighbours.

    Returns:
        Neighbour indices of shape (-1, k) as int64, or a tuple of
        (indices, distances) when `return_distances` is True.

    Raises:
        ValueError: If the two row_splits have different lengths.
    """
    if points_row_splits.shape[0] != queries_row_splits.shape[0]:
        raise ValueError("KNN(points and queries must have same batch size)")

    # Remember where the caller keeps its data instead of assuming CUDA.
    device = points.device

    # NOTE(review): the row_splits are passed through unchanged; presumably
    # `knn_search` expects them on the CPU -- confirm against the batcher.
    # ml3d knn.
    ans = knn_search(points.cpu(),
                     queries.cpu(),
                     k=k,
                     points_row_splits=points_row_splits,
                     queries_row_splits=queries_row_splits,
                     return_distances=True)

    index = ans.neighbors_index.reshape(-1, k).long().to(device)
    if return_distances:
        return index, ans.neighbors_distance.reshape(-1, k).to(device)
    return index


def interpolation(points,
                  queries,
                  feat,
                  points_row_splits,
                  queries_row_splits,
                  k=3):
    """Interpolation of features with nearest neighbours.

    Every query receives the average of the features of its `k` nearest
    points, weighted by the normalised inverse of the distances returned by
    `knn_batch`.

    Args:
        points: Input pointcloud (m, 3).
        queries: Queries for Knn (n, 3).
        feat: features (m, c).
        points_row_splits: row_splits for batching points.
        queries_row_splits: row_splits for batching queries.
        k: Number of neighbours.

    Returns:
        Returns interpolated features (n, c).

    Raises:
        ValueError: If `points`, `queries` or `feat` is not contiguous.
    """
    if not (points.is_contiguous() and queries.is_contiguous() and
            feat.is_contiguous()):
        raise ValueError("Interpolation (points/queries/feat not contiguous)")

    idx, dist = knn_batch(points,
                          queries,
                          k=k,
                          points_row_splits=points_row_splits,
                          queries_row_splits=queries_row_splits,
                          return_distances=True)  # (n, k), (n, k)

    idx, dist = idx.reshape(-1, k), dist.reshape(-1, k)

    # The epsilon keeps the weight finite for a zero distance.
    dist_recip = 1.0 / (dist + 1e-8)  # (n, k)
    norm = torch.sum(dist_recip, dim=1, keepdim=True)
    weight = dist_recip / norm  # (n, k)

    # Allocated directly on the target device (float32, as before).
    new_feat = torch.zeros(queries.shape[0],
                           feat.shape[1],
                           dtype=torch.float32,
                           device=feat.device)
    for i in range(k):
        new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1)

    return new_feat