Add File
This commit is contained in:
444
pcdet/datasets/once/once_dataset.py
Normal file
444
pcdet/datasets/once/once_dataset.py
Normal file
@@ -0,0 +1,444 @@
|
||||
import copy
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pathlib import Path
|
||||
|
||||
from ..dataset import DatasetTemplate
|
||||
from ...ops.roiaware_pool3d import roiaware_pool3d_utils
|
||||
from ...utils import box_utils
|
||||
from .once_toolkits import Octopus
|
||||
|
||||
class ONCEDataset(DatasetTemplate):
|
||||
def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
|
||||
"""
|
||||
Args:
|
||||
root_path:
|
||||
dataset_cfg:
|
||||
class_names:
|
||||
training:
|
||||
logger:
|
||||
"""
|
||||
super().__init__(
|
||||
dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
|
||||
)
|
||||
self.split = dataset_cfg.DATA_SPLIT['train'] if training else dataset_cfg.DATA_SPLIT['test']
|
||||
assert self.split in ['train', 'val', 'test', 'raw_small', 'raw_medium', 'raw_large']
|
||||
|
||||
split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
|
||||
self.sample_seq_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None
|
||||
self.cam_names = ['cam01', 'cam03', 'cam05', 'cam06', 'cam07', 'cam08', 'cam09']
|
||||
self.cam_tags = ['top', 'top2', 'left_back', 'left_front', 'right_front', 'right_back', 'back']
|
||||
self.toolkits = Octopus(self.root_path)
|
||||
|
||||
self.once_infos = []
|
||||
self.include_once_data(self.split)
|
||||
|
||||
def include_once_data(self, split):
|
||||
if self.logger is not None:
|
||||
self.logger.info('Loading ONCE dataset')
|
||||
once_infos = []
|
||||
|
||||
for info_path in self.dataset_cfg.INFO_PATH[split]:
|
||||
info_path = self.root_path / info_path
|
||||
if not info_path.exists():
|
||||
continue
|
||||
with open(info_path, 'rb') as f:
|
||||
infos = pickle.load(f)
|
||||
once_infos.extend(infos)
|
||||
|
||||
def check_annos(info):
|
||||
return 'annos' in info
|
||||
|
||||
if self.split != 'raw':
|
||||
once_infos = list(filter(check_annos,once_infos))
|
||||
|
||||
self.once_infos.extend(once_infos)
|
||||
|
||||
if self.logger is not None:
|
||||
self.logger.info('Total samples for ONCE dataset: %d' % (len(once_infos)))
|
||||
|
||||
def set_split(self, split):
|
||||
super().__init__(
|
||||
dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training, root_path=self.root_path, logger=self.logger
|
||||
)
|
||||
self.split = split
|
||||
|
||||
split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
|
||||
self.sample_seq_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None
|
||||
|
||||
def get_lidar(self, sequence_id, frame_id):
|
||||
return self.toolkits.load_point_cloud(sequence_id, frame_id)
|
||||
|
||||
def get_image(self, sequence_id, frame_id, cam_name):
|
||||
return self.toolkits.load_image(sequence_id, frame_id, cam_name)
|
||||
|
||||
def project_lidar_to_image(self, sequence_id, frame_id):
|
||||
return self.toolkits.project_lidar_to_image(sequence_id, frame_id)
|
||||
|
||||
def point_painting(self, points, info):
|
||||
semseg_dir = './' # add your own seg directory
|
||||
used_classes = [0,1,2,3,4,5]
|
||||
num_classes = len(used_classes)
|
||||
frame_id = str(info['frame_id'])
|
||||
seq_id = str(info['sequence_id'])
|
||||
painted = np.zeros((points.shape[0], num_classes)) # classes + bg
|
||||
for cam_name in self.cam_names:
|
||||
img_path = Path(semseg_dir) / Path(seq_id) / Path(cam_name) / Path(frame_id+'_label.png')
|
||||
calib_info = info['calib'][cam_name]
|
||||
cam_2_velo = calib_info['cam_to_velo']
|
||||
cam_intri = np.hstack([calib_info['cam_intrinsic'], np.zeros((3, 1), dtype=np.float32)])
|
||||
point_xyz = points[:, :3]
|
||||
points_homo = np.hstack(
|
||||
[point_xyz, np.ones(point_xyz.shape[0], dtype=np.float32).reshape((-1, 1))])
|
||||
points_lidar = np.dot(points_homo, np.linalg.inv(cam_2_velo).T)
|
||||
mask = points_lidar[:, 2] > 0
|
||||
points_lidar = points_lidar[mask]
|
||||
points_img = np.dot(points_lidar, cam_intri.T)
|
||||
points_img = points_img / points_img[:, [2]]
|
||||
uv = points_img[:, [0,1]]
|
||||
#depth = points_img[:, [2]]
|
||||
seg_map = np.array(Image.open(img_path)) # (H, W)
|
||||
H, W = seg_map.shape
|
||||
seg_feats = np.zeros((H*W, num_classes))
|
||||
seg_map = seg_map.reshape(-1)
|
||||
for cls_i in used_classes:
|
||||
seg_feats[seg_map==cls_i, cls_i] = 1
|
||||
seg_feats = seg_feats.reshape(H, W, num_classes).transpose(2, 0, 1)
|
||||
uv[:, 0] = (uv[:, 0] - W / 2) / (W / 2)
|
||||
uv[:, 1] = (uv[:, 1] - H / 2) / (H / 2)
|
||||
uv_tensor = torch.from_numpy(uv).unsqueeze(0).unsqueeze(0) # [1,1,N,2]
|
||||
seg_feats = torch.from_numpy(seg_feats).unsqueeze(0) # [1,C,H,W]
|
||||
proj_scores = F.grid_sample(seg_feats, uv_tensor, mode='bilinear', padding_mode='zeros') # [1, C, 1, N]
|
||||
proj_scores = proj_scores.squeeze(0).squeeze(1).transpose(0, 1).contiguous() # [N, C]
|
||||
painted[mask] = proj_scores.numpy()
|
||||
return np.concatenate([points, painted], axis=1)
|
||||
|
||||
def __len__(self):
|
||||
if self._merge_all_iters_to_one_epoch:
|
||||
return len(self.once_infos) * self.total_epochs
|
||||
|
||||
return len(self.once_infos)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self._merge_all_iters_to_one_epoch:
|
||||
index = index % len(self.once_infos)
|
||||
|
||||
info = copy.deepcopy(self.once_infos[index])
|
||||
frame_id = info['frame_id']
|
||||
seq_id = info['sequence_id']
|
||||
points = self.get_lidar(seq_id, frame_id)
|
||||
|
||||
if self.dataset_cfg.get('POINT_PAINTING', False):
|
||||
points = self.point_painting(points, info)
|
||||
|
||||
input_dict = {
|
||||
'points': points,
|
||||
'frame_id': frame_id,
|
||||
}
|
||||
|
||||
if 'annos' in info:
|
||||
annos = info['annos']
|
||||
input_dict.update({
|
||||
'gt_names': annos['name'],
|
||||
'gt_boxes': annos['boxes_3d'],
|
||||
'num_points_in_gt': annos.get('num_points_in_gt', None)
|
||||
})
|
||||
|
||||
data_dict = self.prepare_data(data_dict=input_dict)
|
||||
data_dict.pop('num_points_in_gt', None)
|
||||
return data_dict
|
||||
|
||||
def get_infos(self, num_workers=4, sample_seq_list=None):
|
||||
import concurrent.futures as futures
|
||||
import json
|
||||
root_path = self.root_path
|
||||
cam_names = self.cam_names
|
||||
|
||||
"""
|
||||
# dataset json format
|
||||
{
|
||||
'meta_info':
|
||||
'calib': {
|
||||
'cam01': {
|
||||
'cam_to_velo': list
|
||||
'cam_intrinsic': list
|
||||
'distortion': list
|
||||
}
|
||||
...
|
||||
}
|
||||
'frames': [
|
||||
{
|
||||
'frame_id': timestamp,
|
||||
'annos': {
|
||||
'names': list
|
||||
'boxes_3d': list of list
|
||||
'boxes_2d': {
|
||||
'cam01': list of list
|
||||
...
|
||||
}
|
||||
}
|
||||
'pose': list
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
# open pcdet format
|
||||
{
|
||||
'meta_info':
|
||||
'sequence_id': seq_idx
|
||||
'frame_id': timestamp
|
||||
'timestamp': timestamp
|
||||
'lidar': path
|
||||
'cam01': path
|
||||
...
|
||||
'calib': {
|
||||
'cam01': {
|
||||
'cam_to_velo': np.array
|
||||
'cam_intrinsic': np.array
|
||||
'distortion': np.array
|
||||
}
|
||||
...
|
||||
}
|
||||
'pose': np.array
|
||||
'annos': {
|
||||
'name': np.array
|
||||
'boxes_3d': np.array
|
||||
'boxes_2d': {
|
||||
'cam01': np.array
|
||||
....
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
def process_single_sequence(seq_idx):
|
||||
print('%s seq_idx: %s' % (self.split, seq_idx))
|
||||
seq_infos = []
|
||||
seq_path = Path(root_path) / 'data' / seq_idx
|
||||
json_path = seq_path / ('%s.json' % seq_idx)
|
||||
with open(json_path, 'r') as f:
|
||||
info_this_seq = json.load(f)
|
||||
meta_info = info_this_seq['meta_info']
|
||||
calib = info_this_seq['calib']
|
||||
for f_idx, frame in enumerate(info_this_seq['frames']):
|
||||
frame_id = frame['frame_id']
|
||||
if f_idx == 0:
|
||||
prev_id = None
|
||||
else:
|
||||
prev_id = info_this_seq['frames'][f_idx-1]['frame_id']
|
||||
if f_idx == len(info_this_seq['frames'])-1:
|
||||
next_id = None
|
||||
else:
|
||||
next_id = info_this_seq['frames'][f_idx+1]['frame_id']
|
||||
pc_path = str(seq_path / 'lidar_roof' / ('%s.bin' % frame_id))
|
||||
pose = np.array(frame['pose'])
|
||||
frame_dict = {
|
||||
'sequence_id': seq_idx,
|
||||
'frame_id': frame_id,
|
||||
'timestamp': int(frame_id),
|
||||
'prev_id': prev_id,
|
||||
'next_id': next_id,
|
||||
'meta_info': meta_info,
|
||||
'lidar': pc_path,
|
||||
'pose': pose
|
||||
}
|
||||
calib_dict = {}
|
||||
for cam_name in cam_names:
|
||||
cam_path = str(seq_path / cam_name / ('%s.jpg' % frame_id))
|
||||
frame_dict.update({cam_name: cam_path})
|
||||
calib_dict[cam_name] = {}
|
||||
calib_dict[cam_name]['cam_to_velo'] = np.array(calib[cam_name]['cam_to_velo'])
|
||||
calib_dict[cam_name]['cam_intrinsic'] = np.array(calib[cam_name]['cam_intrinsic'])
|
||||
calib_dict[cam_name]['distortion'] = np.array(calib[cam_name]['distortion'])
|
||||
frame_dict.update({'calib': calib_dict})
|
||||
|
||||
if 'annos' in frame:
|
||||
annos = frame['annos']
|
||||
boxes_3d = np.array(annos['boxes_3d'])
|
||||
if boxes_3d.shape[0] == 0:
|
||||
print(frame_id)
|
||||
continue
|
||||
boxes_2d_dict = {}
|
||||
for cam_name in cam_names:
|
||||
boxes_2d_dict[cam_name] = np.array(annos['boxes_2d'][cam_name])
|
||||
annos_dict = {
|
||||
'name': np.array(annos['names']),
|
||||
'boxes_3d': boxes_3d,
|
||||
'boxes_2d': boxes_2d_dict
|
||||
}
|
||||
|
||||
points = self.get_lidar(seq_idx, frame_id)
|
||||
corners_lidar = box_utils.boxes_to_corners_3d(np.array(annos['boxes_3d']))
|
||||
num_gt = boxes_3d.shape[0]
|
||||
num_points_in_gt = -np.ones(num_gt, dtype=np.int32)
|
||||
for k in range(num_gt):
|
||||
flag = box_utils.in_hull(points[:, 0:3], corners_lidar[k])
|
||||
num_points_in_gt[k] = flag.sum()
|
||||
annos_dict['num_points_in_gt'] = num_points_in_gt
|
||||
|
||||
frame_dict.update({'annos': annos_dict})
|
||||
seq_infos.append(frame_dict)
|
||||
return seq_infos
|
||||
|
||||
sample_seq_list = sample_seq_list if sample_seq_list is not None else self.sample_seq_list
|
||||
with futures.ThreadPoolExecutor(num_workers) as executor:
|
||||
infos = executor.map(process_single_sequence, sample_seq_list)
|
||||
all_infos = []
|
||||
for info in infos:
|
||||
all_infos.extend(info)
|
||||
return all_infos
|
||||
|
||||
def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'):
|
||||
import torch
|
||||
|
||||
database_save_path = Path(self.root_path) / ('gt_database' if split == 'train' else ('gt_database_%s' % split))
|
||||
db_info_save_path = Path(self.root_path) / ('once_dbinfos_%s.pkl' % split)
|
||||
|
||||
database_save_path.mkdir(parents=True, exist_ok=True)
|
||||
all_db_infos = {}
|
||||
|
||||
with open(info_path, 'rb') as f:
|
||||
infos = pickle.load(f)
|
||||
|
||||
for k in range(len(infos)):
|
||||
if 'annos' not in infos[k]:
|
||||
continue
|
||||
print('gt_database sample: %d' % (k + 1))
|
||||
info = infos[k]
|
||||
frame_id = info['frame_id']
|
||||
seq_id = info['sequence_id']
|
||||
points = self.get_lidar(seq_id, frame_id)
|
||||
annos = info['annos']
|
||||
names = annos['name']
|
||||
gt_boxes = annos['boxes_3d']
|
||||
|
||||
num_obj = gt_boxes.shape[0]
|
||||
point_indices = roiaware_pool3d_utils.points_in_boxes_cpu(
|
||||
torch.from_numpy(points[:, 0:3]), torch.from_numpy(gt_boxes)
|
||||
).numpy() # (nboxes, npoints)
|
||||
|
||||
for i in range(num_obj):
|
||||
filename = '%s_%s_%d.bin' % (frame_id, names[i], i)
|
||||
filepath = database_save_path / filename
|
||||
gt_points = points[point_indices[i] > 0]
|
||||
|
||||
gt_points[:, :3] -= gt_boxes[i, :3]
|
||||
with open(filepath, 'w') as f:
|
||||
gt_points.tofile(f)
|
||||
|
||||
db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin
|
||||
db_info = {'name': names[i], 'path': db_path, 'gt_idx': i,
|
||||
'box3d_lidar': gt_boxes[i], 'num_points_in_gt': gt_points.shape[0]}
|
||||
if names[i] in all_db_infos:
|
||||
all_db_infos[names[i]].append(db_info)
|
||||
else:
|
||||
all_db_infos[names[i]] = [db_info]
|
||||
|
||||
for k, v in all_db_infos.items():
|
||||
print('Database %s: %d' % (k, len(v)))
|
||||
|
||||
with open(db_info_save_path, 'wb') as f:
|
||||
pickle.dump(all_db_infos, f)
|
||||
|
||||
@staticmethod
|
||||
def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
|
||||
def get_template_prediction(num_samples):
|
||||
ret_dict = {
|
||||
'name': np.zeros(num_samples), 'score': np.zeros(num_samples),
|
||||
'boxes_3d': np.zeros((num_samples, 7))
|
||||
}
|
||||
return ret_dict
|
||||
|
||||
def generate_single_sample_dict(box_dict):
|
||||
pred_scores = box_dict['pred_scores'].cpu().numpy()
|
||||
pred_boxes = box_dict['pred_boxes'].cpu().numpy()
|
||||
pred_labels = box_dict['pred_labels'].cpu().numpy()
|
||||
pred_dict = get_template_prediction(pred_scores.shape[0])
|
||||
if pred_scores.shape[0] == 0:
|
||||
return pred_dict
|
||||
|
||||
pred_dict['name'] = np.array(class_names)[pred_labels - 1]
|
||||
pred_dict['score'] = pred_scores
|
||||
pred_dict['boxes_3d'] = pred_boxes
|
||||
return pred_dict
|
||||
|
||||
annos = []
|
||||
for index, box_dict in enumerate(pred_dicts):
|
||||
frame_id = batch_dict['frame_id'][index]
|
||||
single_pred_dict = generate_single_sample_dict(box_dict)
|
||||
single_pred_dict['frame_id'] = frame_id
|
||||
annos.append(single_pred_dict)
|
||||
|
||||
if output_path is not None:
|
||||
raise NotImplementedError
|
||||
return annos
|
||||
|
||||
def evaluation(self, det_annos, class_names, **kwargs):
|
||||
from .once_eval.evaluation import get_evaluation_results
|
||||
|
||||
eval_det_annos = copy.deepcopy(det_annos)
|
||||
eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.once_infos]
|
||||
ap_result_str, ap_dict = get_evaluation_results(eval_gt_annos, eval_det_annos, class_names)
|
||||
|
||||
return ap_result_str, ap_dict
|
||||
|
||||
def create_once_infos(dataset_cfg, class_names, data_path, save_path, workers=4):
|
||||
dataset = ONCEDataset(dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, training=False)
|
||||
|
||||
splits = ['train', 'val', 'test', 'raw_small', 'raw_medium', 'raw_large']
|
||||
ignore = ['test']
|
||||
|
||||
print('---------------Start to generate data infos---------------')
|
||||
for split in splits:
|
||||
if split in ignore:
|
||||
continue
|
||||
|
||||
filename = 'once_infos_%s.pkl' % split
|
||||
filename = save_path / Path(filename)
|
||||
dataset.set_split(split)
|
||||
once_infos = dataset.get_infos(num_workers=workers)
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(once_infos, f)
|
||||
print('ONCE info %s file is saved to %s' % (split, filename))
|
||||
|
||||
train_filename = save_path / 'once_infos_train.pkl'
|
||||
print('---------------Start create groundtruth database for data augmentation---------------')
|
||||
dataset.set_split('train')
|
||||
dataset.create_groundtruth_database(train_filename, split='train')
|
||||
print('---------------Data preparation Done---------------')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='arg parser')
|
||||
parser.add_argument('--cfg_file', type=str, default=None, help='specify the config of dataset')
|
||||
parser.add_argument('--func', type=str, default='create_waymo_infos', help='')
|
||||
parser.add_argument('--runs_on', type=str, default='server', help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.func == 'create_once_infos':
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from easydict import EasyDict
|
||||
dataset_cfg = EasyDict(yaml.load(open(args.cfg_file)))
|
||||
|
||||
|
||||
ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()
|
||||
once_data_path = ROOT_DIR / 'data' / 'once'
|
||||
once_save_path = ROOT_DIR / 'data' / 'once'
|
||||
|
||||
if args.runs_on == 'cloud':
|
||||
once_data_path = Path('/cache/once/')
|
||||
once_save_path = Path('/cache/once/')
|
||||
dataset_cfg.DATA_PATH = dataset_cfg.CLOUD_DATA_PATH
|
||||
|
||||
create_once_infos(
|
||||
dataset_cfg=dataset_cfg,
|
||||
class_names=['Car', 'Bus', 'Truck', 'Pedestrian', 'Bicycle'],
|
||||
data_path=once_data_path,
|
||||
save_path=once_save_path
|
||||
)
|
||||
Reference in New Issue
Block a user