R3PM-Net / dataloader /user_data.py
YasiiKB's picture
initial commit
97aa5af verified
import os
import numpy as np
import torch
class ClassificationData:
def __init__(self, data_dict):
self.data_dict = data_dict
self.pcs = self.find_attribute('pcs')
self.labels = self.find_attribute('labels')
self.check_data()
def find_attribute(self, attribute):
try:
attribute_data = self.data_dict[attribute]
except:
print("Given data directory has no key attribute \"{}\"".format(attribute))
return attribute_data
def check_data(self):
assert 1 < len(self.pcs.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.pcs.shape)
assert 0 < len(self.labels.shape) < 3, "Error in dimension of labels! Given data dimension: {}".format(self.labels.shape)
if len(self.pcs.shape)==2: self.pcs = self.pcs.reshape(1, -1, 3)
if len(self.labels.shape) == 1: self.labels = self.labels.reshape(1, -1)
assert self.pcs.shape[0] == self.labels.shape[0], "Inconsistency in the number of point clouds and number of ground truth labels!"
def __len__(self):
return self.pcs.shape[0]
def __getitem__(self, index):
return torch.tensor(self.pcs[index]).float(), torch.from_numpy(self.labels[idx]).type(torch.LongTensor)
class RegistrationData:
def __init__(self, data_dict):
self.data_dict = data_dict
self.template = self.find_attribute('template')
self.source = self.find_attribute('source')
self.transformation = self.find_attribute('transformation')
self.check_data()
# def find_attribute(self, attribute):
# try:
# attribute_data = self.data[attribute]
# except:
# print("Given data directory has no key attribute \"{}\"".format(attribute))
# return attribute_data
def find_attribute(self, attribute):
attribute_data = None
if attribute in self.data_dict:
attribute_data = self.data_dict[attribute]
else:
print("Given data directory has no key attribute \"{}\"".format(attribute))
return attribute_data
def check_data(self):
assert 1 < len(self.template.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.template.shape)
assert 1 < len(self.source.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.source.shape)
assert 1 < len(self.transformation.shape) < 4, "Error in dimension of transformations! Given data dimension: {}".format(self.transformation.shape)
if len(self.template.shape)==2: self.template = self.template.reshape(1, -1, 3)
if len(self.source.shape)==2: self.source = self.source.reshape(1, -1, 3)
if len(self.transformation.shape) == 2: self.transformation = self.transformation.reshape(1, 4, 4)
assert self.template.shape[0] == self.source.shape[0], "Inconsistency in the number of template and source point clouds!"
assert self.source.shape[0] == self.transformation.shape[0], "Inconsistency in the number of transformation and source point clouds!"
def __len__(self):
return self.template.shape[0]
def __getitem__(self, index):
return torch.tensor(self.template[index]).float(), torch.tensor(self.source[index]).float(), torch.tensor(self.transformation[index]).float()
class FlowData:
def __init__(self, data_dict):
self.data_dict = data_dict
self.frame1 = self.find_attribute('frame1')
self.frame2 = self.find_attribute('frame2')
self.flow = self.find_attribute('flow')
self.check_data()
def find_attribute(self, attribute):
try:
attribute_data = self.data[attribute]
except:
print("Given data directory has no key attribute \"{}\"".format(attribute))
return attribute_data
def check_data(self):
assert 1 < len(self.frame1.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame1.shape)
assert 1 < len(self.frame2.shape) < 4, "Error in dimension of point clouds! Given data dimension: {}".format(self.frame2.shape)
assert 1 < len(self.flow.shape) < 4, "Error in dimension of flow! Given data dimension: {}".format(self.flow.shape)
if len(self.frame1.shape)==2: self.frame1 = self.frame1.reshape(1, -1, 3)
if len(self.frame2.shape)==2: self.frame2 = self.frame2.reshape(1, -1, 3)
if len(self.flow.shape) == 2: self.flow = self.flow.reshape(1, -1, 3)
assert self.frame1.shape[0] == self.frame2.shape[0], "Inconsistency in the number of frame1 and frame2 point clouds!"
assert self.frame2.shape[0] == self.flow.shape[0], "Inconsistency in the number of flow and frame2 point clouds!"
def __len__(self):
return self.frame1.shape[0]
def __getitem__(self, index):
return torch.tensor(self.frame1[index]).float(), torch.tensor(self.frame2[index]).float(), torch.tensor(self.flow[index]).float()
class UserData:
def __init__(self, application, data_dict):
self.application = application
if self.application == 'classification':
self.data_class = ClassificationData(data_dict)
elif self.application == 'registration':
self.data_class = RegistrationData(data_dict)
elif self.application == 'flow_estimation':
self.data_class = FlowData(data_dict)
def __len__(self):
return len(self.data_class)
def __getitem__(self, index):
return self.data_class[index]