Add File
This commit is contained in:
50
pcdet/models/detectors/transfusion.py
Normal file
50
pcdet/models/detectors/transfusion.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from .detector3d_template import Detector3DTemplate
|
||||
|
||||
|
||||
class TransFusion(Detector3DTemplate):
|
||||
def __init__(self, model_cfg, num_class, dataset):
|
||||
super().__init__(model_cfg=model_cfg, num_class=num_class, dataset=dataset)
|
||||
self.module_list = self.build_networks()
|
||||
|
||||
def forward(self, batch_dict):
|
||||
for cur_module in self.module_list:
|
||||
batch_dict = cur_module(batch_dict)
|
||||
|
||||
if self.training:
|
||||
loss, tb_dict, disp_dict = self.get_training_loss(batch_dict)
|
||||
|
||||
ret_dict = {
|
||||
'loss': loss
|
||||
}
|
||||
return ret_dict, tb_dict, disp_dict
|
||||
else:
|
||||
pred_dicts, recall_dicts = self.post_processing(batch_dict)
|
||||
return pred_dicts, recall_dicts
|
||||
|
||||
def get_training_loss(self,batch_dict):
|
||||
disp_dict = {}
|
||||
|
||||
loss_trans, tb_dict = batch_dict['loss'],batch_dict['tb_dict']
|
||||
tb_dict = {
|
||||
'loss_trans': loss_trans.item(),
|
||||
**tb_dict
|
||||
}
|
||||
|
||||
loss = loss_trans
|
||||
return loss, tb_dict, disp_dict
|
||||
|
||||
def post_processing(self, batch_dict):
|
||||
post_process_cfg = self.model_cfg.POST_PROCESSING
|
||||
batch_size = batch_dict['batch_size']
|
||||
final_pred_dict = batch_dict['final_box_dicts']
|
||||
recall_dict = {}
|
||||
for index in range(batch_size):
|
||||
pred_boxes = final_pred_dict[index]['pred_boxes']
|
||||
|
||||
recall_dict = self.generate_recall_record(
|
||||
box_preds=pred_boxes,
|
||||
recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
|
||||
thresh_list=post_process_cfg.RECALL_THRESH_LIST
|
||||
)
|
||||
|
||||
return final_pred_dict, recall_dict
|
||||
Reference in New Issue
Block a user