diff --git a/pcdet/utils/commu_utils.py b/pcdet/utils/commu_utils.py
new file mode 100644
index 0000000..d9e866f
--- /dev/null
+++ b/pcdet/utils/commu_utils.py
@@ -0,0 +1,182 @@
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+
+deeply borrow from maskrcnn-benchmark and ST3D
+"""
+
+import pickle
+import time
+
+import torch
+import torch.distributed as dist
+
+
+def get_world_size():
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def synchronize():
+    """
+    Helper function to synchronize (barrier) among all processes when
+    using distributed training
+    """
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return
+    dist.barrier()
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    origin_size = None
+    if not isinstance(data, torch.Tensor):
+        buffer = pickle.dumps(data)
+        storage = torch.ByteStorage.from_buffer(buffer)
+        tensor = torch.ByteTensor(storage).to("cuda")
+    else:
+        origin_size = data.size()
+        tensor = data.reshape(-1)
+
+    tensor_type = tensor.dtype
+
+    # obtain Tensor size of each rank
+    local_size = torch.LongTensor([tensor.numel()]).to("cuda")
+    size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.FloatTensor(size=(max_size,)).cuda().to(tensor_type))
+    if local_size != max_size:
+        padding = torch.FloatTensor(size=(max_size - local_size,)).cuda().to(tensor_type)
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        if origin_size is None:
+            buffer = tensor.cpu().numpy().tobytes()[:size]
+            data_list.append(pickle.loads(buffer))
+        else:
+            buffer = tensor[:size]
+            data_list.append(buffer)
+
+    if origin_size is not None:
+        new_shape = [-1] + list(origin_size[1:])
+        resized_list = []
+        for data in data_list:
+            # suppose the difference of tensor size exist in first dimension
+            data = data.reshape(new_shape)
+            resized_list.append(data)
+
+        return resized_list
+    else:
+        return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that process with rank
+    0 has the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+ """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def average_reduce_value(data): + data_list = all_gather(data) + return sum(data_list) / len(data_list) + + +def all_reduce(data, op="sum", average=False): + + def op_map(op): + op_dict = { + "SUM": dist.ReduceOp.SUM, + "MAX": dist.ReduceOp.MAX, + "MIN": dist.ReduceOp.MIN, + "PRODUCT": dist.ReduceOp.PRODUCT, + } + return op_dict[op] + + world_size = get_world_size() + if world_size > 1: + reduced_data = data.clone() + dist.all_reduce(reduced_data, op=op_map(op.upper())) + if average: + assert op.upper() == 'SUM' + return reduced_data / world_size + else: + return reduced_data + return data + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output