Source code for ml3d.torch.dataloaders.torch_dataloader

from abc import abstractmethod
from tqdm import tqdm
import torch
from torch.multiprocessing import Pool
from torch.utils.data import Dataset
from collections import namedtuple

from ...utils import Cache, get_hash


[docs]class TorchDataloader(Dataset): """This class allows you to load datasets for a PyTorch framework. Example: This example loads the SemanticKITTI dataset using the Torch dataloader: import torch from torch.utils.data import Dataset, DataLoader train_split = TorchDataloader(dataset=dataset.get_split('training')) """
[docs] def __init__(self, dataset=None, preprocess=None, transform=None, sampler=None, use_cache=True, steps_per_epoch=None, **kwargs): """Initialize. Args: dataset: The 3D ML dataset class. You can use the base dataset, sample datasets , or a custom dataset. preprocess: The model's preprocess method. transform: The model's transform method. use_cache: Indicates if preprocessed data should be cached. steps_per_epoch: The number of steps per epoch that indicates the bactches of samples to train. If it is None, then the step number will be the number of samples in the data. Returns: class: The corresponding class. """ self.dataset = dataset self.preprocess = preprocess self.steps_per_epoch = steps_per_epoch if preprocess is not None and use_cache: cache_dir = getattr(dataset.cfg, 'cache_dir') assert cache_dir is not None, 'cache directory is not given' self.cache_convert = Cache(preprocess, cache_dir=cache_dir, cache_key=get_hash(repr(preprocess))) uncached = [ idx for idx in range(len(dataset)) if dataset.get_attr(idx) ['name'] not in self.cache_convert.cached_ids ] if len(uncached) > 0: for idx in tqdm(range(len(dataset)), desc='preprocess'): attr = dataset.get_attr(idx) name = attr['name'] if name in self.cache_convert.cached_ids: continue data = dataset.get_data(idx) # cache the data self.cache_convert(name, data, attr) else: self.cache_convert = None self.transform = transform if sampler is not None: sampler.initialize_with_dataloader(self)
[docs] def __getitem__(self, index): """Returns the item at index position (idx).""" dataset = self.dataset index = index % len(dataset) attr = dataset.get_attr(index) if self.cache_convert: data = self.cache_convert(attr['name']) elif self.preprocess: data = self.preprocess(dataset.get_data(index), attr) else: data = dataset.get_data(index) if self.transform is not None: data = self.transform(data, attr) inputs = {'data': data, 'attr': attr} return inputs
[docs] def __len__(self): """Returns the number of steps for an epoch.""" if self.steps_per_epoch is not None: steps_per_epoch = self.steps_per_epoch else: steps_per_epoch = len(self.dataset) return steps_per_epoch