# Source code for ml3d.datasets.samplers.semseg_spatially_regular

import numpy as np
from tqdm import tqdm
import random

from ...utils import SAMPLER


class SemSegSpatiallyRegularSampler(object):
    """Spatially regular sampler for semantic segmentation datasets.

    Every point of every cloud carries a "possibility" score. Clouds and patch
    centers are chosen where the score is lowest, and the points of each
    sampled patch have their scores raised (most for points nearest the
    center). Repeated sampling therefore moves on to less-visited regions.
    """

    def __init__(self, dataset):
        """Wrap ``dataset``.

        Args:
            dataset: Object supporting ``len()`` and exposing a ``split``
                attribute. ``initialize_with_dataloader`` additionally calls
                its ``get_attr(index)`` and ``get_data(index)`` methods.
        """
        self.dataset = dataset
        self.length = len(dataset)
        self.split = self.dataset.split

    def __len__(self):
        """Return ``self.length``.

        This is ``len(dataset)`` until ``initialize_with_dataloader``
        replaces it with ``len(dataloader)``.
        """
        return self.length

    def initialize_with_dataloader(self, dataloader):
        """Create random initial possibility scores for every cloud.

        Each cloud gets one score per point, drawn uniformly from
        ``[0, 1e-3)`` so that ties are broken randomly.
        ``self.min_possibilities`` caches the per-cloud minimum score.

        Args:
            dataloader: Object supporting ``len()`` and having
                ``cache_convert`` and ``preprocess`` attributes, each either a
                callable or falsy. Cached or preprocessed data is preferred
                over raw data when available (presumably so that the point
                count matches what the point sampler later receives).
                ``data['point']`` must have its points along the first axis.
        """
        self.min_possibilities = []
        self.possibilities = []
        self.length = len(dataloader)
        dataset = self.dataset
        for index in range(len(dataset)):
            attr = dataset.get_attr(index)
            if dataloader.cache_convert:
                data = dataloader.cache_convert(attr['name'])
            elif dataloader.preprocess:
                data = dataloader.preprocess(dataset.get_data(index), attr)
            else:
                data = dataset.get_data(index)
            num_points = data['point'].shape[0]
            scores = np.random.rand(num_points) * 1e-3
            self.possibilities.append(scores)
            self.min_possibilities.append(float(np.min(scores)))

    def get_cloud_sampler(self):
        """Return a generator of cloud indices to sample patches from.

        Each yielded index is also stored in ``self.cloud_id``, where the
        point sampler reads it.

        * For the splits ``train``/``training``/``validation``/``valid``, it
          yields ``self.length`` indices. Each time it picks the cloud with
          the lowest minimum possibility.
        * For any other split, it walks the clouds in order. It keeps yielding
          the current cloud until that cloud's minimum possibility exceeds
          0.5, which the point sampler raises as patches are drawn.
        """

        def gen_train():
            for _ in range(self.length):
                self.cloud_id = int(np.argmin(self.min_possibilities))
                yield self.cloud_id

        def gen_test():
            # self.length is len(dataloader), while min_possibilities has one
            # entry per dataset cloud. Nothing here guarantees the two are
            # equal, so the walk is bounded to avoid indexing past the end.
            num_clouds = min(self.length, len(self.min_possibilities))
            curr_cloud_id = 0
            while curr_cloud_id < num_clouds:
                if self.min_possibilities[curr_cloud_id] > 0.5:
                    curr_cloud_id += 1
                    continue
                self.cloud_id = curr_cloud_id
                yield self.cloud_id

        if self.split in ['train', 'validation', 'valid', 'training']:
            gen = gen_train
        else:
            gen = gen_test
        return gen()

    def get_point_sampler(self):
        """Return a callable that samples one patch from cloud ``self.cloud_id``.

        The returned ``_random_centered_gen(patchwise=True, **kwargs)``:

        * With ``patchwise=False``, it marks the whole current cloud as
          visited (all its possibilities set to 1) and returns ``None``.
        * Otherwise, it reads the keyword arguments ``pc`` (2-D points array,
          points along the first axis), ``num_points``, ``radius`` and
          ``search_tree``. It centers a patch on the point with the lowest
          possibility and selects neighbors:

          - by radius (``search_tree.query_radius``) when ``radius`` is given;
          - otherwise by k-NN (``search_tree.query``);
          - or, when the cloud has fewer than ``num_points`` points, by taking
            every point and padding with random repeats.

          It then raises the possibilities of the selected points and returns
          ``(patch_points, idxs, center_point)``. ``idxs`` is in shuffled
          order, and ``patch_points`` equals ``pc[idxs]``.

        Raises:
            KeyError: If ``pc`` or ``num_points`` is missing, or if both
                ``search_tree`` and ``radius`` are missing.
        """

        def _random_centered_gen(patchwise=True, **kwargs):
            if not patchwise:
                self.possibilities[self.cloud_id][:] = 1.
                self.min_possibilities[self.cloud_id] = 1.
                return
            pc = kwargs.get('pc', None)
            num_points = kwargs.get('num_points', None)
            radius = kwargs.get('radius', None)
            search_tree = kwargs.get('search_tree', None)
            if (pc is None or num_points is None or
                    (search_tree is None and radius is None)):
                raise KeyError(
                    "Please provide pc, num_points, and (search_tree or radius) "
                    "for point_sampler in SemSegSpatiallyRegularSampler")
            # NOTE(review): both the radius branch and the k-NN branch call
            # search_tree. Passing radius without search_tree therefore still
            # fails below with an AttributeError.

            cloud_id = self.cloud_id
            possibilities = self.possibilities[cloud_id]  # mutated in place
            n = 0
            while n < 2:
                center_id = np.argmin(possibilities)
                center_point = pc[center_id, :].reshape(1, -1)
                if radius is not None:
                    idxs = search_tree.query_radius(center_point, r=radius)[0]
                elif pc.shape[0] < num_points:
                    # Too few points: take them all, then pad with random
                    # repeats up to num_points.
                    num_available = pc.shape[0]
                    idxs = np.arange(num_available)
                    extra = random.choices(idxs, k=num_points - num_available)
                    idxs = np.concatenate(
                        [idxs, np.asarray(extra, dtype=idxs.dtype)])
                else:
                    idxs = search_tree.query(center_point, k=num_points)[1][0]
                n = len(idxs)
                if n < 2:
                    # Bump this center so argmin picks a different one on the
                    # next attempt.
                    possibilities[center_id] += 0.001

            random.shuffle(idxs)
            patch = pc[idxs]
            dists = np.sum(np.square((patch - center_point).astype(np.float32)),
                           axis=1)
            max_dist = np.max(dists)
            if max_dist > 0:
                # Weight 1 at the center, falling to 0 at the farthest point.
                delta = np.square(1 - dists / max_dist)
            else:
                # All patch points coincide with the center (e.g. a one-point
                # cloud padded with repeats). Dividing by max_dist would give
                # 0/0 = NaN and corrupt the scores; NaN is never > 0.5, so the
                # test-split walk would never leave this cloud. Use the limit
                # value 1 instead.
                delta = np.ones_like(dists)
            possibilities[idxs] += delta
            self.min_possibilities[cloud_id] = float(np.min(possibilities))
            return patch, idxs, center_point

        return _random_centered_gen
# Register this sampler with the SAMPLER registry imported from ...utils
# (presumably so configs can refer to it by class name — confirm in utils).
SAMPLER._register_module(SemSegSpatiallyRegularSampler)