Add File
This commit is contained in:
34
pcdet/models/model_utils/basic_block_2d.py
Normal file
34
pcdet/models/model_utils/basic_block_2d.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class BasicBlock2D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
"""
|
||||
Initializes convolutional block
|
||||
Args:
|
||||
in_channels: int, Number of input channels
|
||||
out_channels: int, Number of output channels
|
||||
**kwargs: Dict, Extra arguments for nn.Conv2d
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
**kwargs)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, features):
|
||||
"""
|
||||
Applies convolutional block
|
||||
Args:
|
||||
features: (B, C_in, H, W), Input features
|
||||
Returns:
|
||||
x: (B, C_out, H, W), Output features
|
||||
"""
|
||||
x = self.conv(features)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
Reference in New Issue
Block a user