Add File
This commit is contained in:
112
tools/demo.py
Normal file
112
tools/demo.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import argparse
|
||||
import glob
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import open3d
|
||||
from visual_utils import open3d_vis_utils as V
|
||||
OPEN3D_FLAG = True
|
||||
except:
|
||||
import mayavi.mlab as mlab
|
||||
from visual_utils import visualize_utils as V
|
||||
OPEN3D_FLAG = False
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pcdet.config import cfg, cfg_from_yaml_file
|
||||
from pcdet.datasets import DatasetTemplate
|
||||
from pcdet.models import build_network, load_data_to_gpu
|
||||
from pcdet.utils import common_utils
|
||||
|
||||
|
||||
class DemoDataset(DatasetTemplate):
|
||||
def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None, ext='.bin'):
|
||||
"""
|
||||
Args:
|
||||
root_path:
|
||||
dataset_cfg:
|
||||
class_names:
|
||||
training:
|
||||
logger:
|
||||
"""
|
||||
super().__init__(
|
||||
dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
|
||||
)
|
||||
self.root_path = root_path
|
||||
self.ext = ext
|
||||
data_file_list = glob.glob(str(root_path / f'*{self.ext}')) if self.root_path.is_dir() else [self.root_path]
|
||||
|
||||
data_file_list.sort()
|
||||
self.sample_file_list = data_file_list
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_file_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.ext == '.bin':
|
||||
points = np.fromfile(self.sample_file_list[index], dtype=np.float32).reshape(-1, 4)
|
||||
elif self.ext == '.npy':
|
||||
points = np.load(self.sample_file_list[index])
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
input_dict = {
|
||||
'points': points,
|
||||
'frame_id': index,
|
||||
}
|
||||
|
||||
data_dict = self.prepare_data(data_dict=input_dict)
|
||||
return data_dict
|
||||
|
||||
|
||||
def parse_config():
|
||||
parser = argparse.ArgumentParser(description='arg parser')
|
||||
parser.add_argument('--cfg_file', type=str, default='cfgs/kitti_models/second.yaml',
|
||||
help='specify the config for demo')
|
||||
parser.add_argument('--data_path', type=str, default='demo_data',
|
||||
help='specify the point cloud data file or directory')
|
||||
parser.add_argument('--ckpt', type=str, default=None, help='specify the pretrained model')
|
||||
parser.add_argument('--ext', type=str, default='.bin', help='specify the extension of your point cloud data file')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
cfg_from_yaml_file(args.cfg_file, cfg)
|
||||
|
||||
return args, cfg
|
||||
|
||||
|
||||
def main():
|
||||
args, cfg = parse_config()
|
||||
logger = common_utils.create_logger()
|
||||
logger.info('-----------------Quick Demo of OpenPCDet-------------------------')
|
||||
demo_dataset = DemoDataset(
|
||||
dataset_cfg=cfg.DATA_CONFIG, class_names=cfg.CLASS_NAMES, training=False,
|
||||
root_path=Path(args.data_path), ext=args.ext, logger=logger
|
||||
)
|
||||
logger.info(f'Total number of samples: \t{len(demo_dataset)}')
|
||||
|
||||
model = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=demo_dataset)
|
||||
model.load_params_from_file(filename=args.ckpt, logger=logger, to_cpu=True)
|
||||
model.cuda()
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for idx, data_dict in enumerate(demo_dataset):
|
||||
logger.info(f'Visualized sample index: \t{idx + 1}')
|
||||
data_dict = demo_dataset.collate_batch([data_dict])
|
||||
load_data_to_gpu(data_dict)
|
||||
pred_dicts, _ = model.forward(data_dict)
|
||||
|
||||
V.draw_scenes(
|
||||
points=data_dict['points'][:, 1:], ref_boxes=pred_dicts[0]['pred_boxes'],
|
||||
ref_scores=pred_dicts[0]['pred_scores'], ref_labels=pred_dicts[0]['pred_labels']
|
||||
)
|
||||
|
||||
if not OPEN3D_FLAG:
|
||||
mlab.show(stop=True)
|
||||
|
||||
logger.info('Demo done.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user