import torch
from torch import nn
import torch.nn.functional as F

def clip_sigmoid(x, eps=1e-4):
    """Apply sigmoid and clamp the result to (eps, 1 - eps).

    Note: ``x.sigmoid_()`` is in-place, so the input tensor is overwritten
    with its sigmoid before clamping.
    """
    y = torch.clamp(x.sigmoid_(), min=eps, max=1 - eps)
    return y

class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """

    def __init__(self, input_channel, num_pos_feats=288):
        super().__init__()
        self.position_embedding_head = nn.Sequential(
            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
            nn.BatchNorm1d(num_pos_feats),
            nn.ReLU(inplace=True),
            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))

    def forward(self, xyz):
        # (B, P, input_channel) -> (B, input_channel, P) for Conv1d
        xyz = xyz.transpose(1, 2).contiguous()
        position_embedding = self.position_embedding_head(xyz)
        return position_embedding

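# Illustrative usage sketch (not from the original file): PositionEmbeddingLearned
# maps coordinates of shape (batch, num_points, input_channel) to embeddings of
# shape (batch, num_pos_feats, num_points). The sizes below are arbitrary
# assumptions chosen for the example:
#
#   pos_embed = PositionEmbeddingLearned(input_channel=2, num_pos_feats=128)
#   coords = torch.rand(4, 200, 2)   # (batch, points, 2-D coordinates)
#   out = pos_embed(coords)          # -> torch.Size([4, 128, 200])
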
class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 self_posembed=None, cross_posembed=None, cross_only=False):
        super().__init__()
        self.cross_only = cross_only
        if not self.cross_only:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        def _get_activation_fn(activation):
            """Return an activation function given a string"""
            if activation == "relu":
                return F.relu
            if activation == "gelu":
                return F.gelu
            if activation == "glu":
                return F.glu
            raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

        self.activation = _get_activation_fn(activation)

        self.self_posembed = self_posembed
        self.cross_posembed = cross_posembed

    def with_pos_embed(self, tensor, pos_embed):
        return tensor if pos_embed is None else tensor + pos_embed

    def forward(self, query, key, query_pos, key_pos, key_padding_mask=None, attn_mask=None):
        # Position embeddings: N x C x P -> P x N x C
        if self.self_posembed is not None:
            query_pos_embed = self.self_posembed(query_pos).permute(2, 0, 1)
        else:
            query_pos_embed = None
        if self.cross_posembed is not None:
            key_pos_embed = self.cross_posembed(key_pos).permute(2, 0, 1)
        else:
            key_pos_embed = None

        # N x C x P -> P x N x C (nn.MultiheadAttention expects sequence-first input)
        query = query.permute(2, 0, 1)
        key = key.permute(2, 0, 1)

        if not self.cross_only:
            # Self-attention over the queries, with residual connection and norm
            q = k = v = self.with_pos_embed(query, query_pos_embed)
            query2 = self.self_attn(q, k, value=v)[0]
            query = query + self.dropout1(query2)
            query = self.norm1(query)

        # Cross-attention from queries to keys
        query2 = self.multihead_attn(query=self.with_pos_embed(query, query_pos_embed),
                                     key=self.with_pos_embed(key, key_pos_embed),
                                     value=self.with_pos_embed(key, key_pos_embed),
                                     key_padding_mask=key_padding_mask, attn_mask=attn_mask)[0]
        query = query + self.dropout2(query2)
        query = self.norm2(query)

        # Feedforward block with residual connection and norm
        query2 = self.linear2(self.dropout(self.activation(self.linear1(query))))
        query = query + self.dropout3(query2)
        query = self.norm3(query)

        # P x N x C -> N x C x P
        query = query.permute(1, 2, 0)
        return query
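

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. This block is illustrative only and
    # not part of the original module; every size below (batch of 2, d_model
    # of 128, 16 queries, 200 keys, 2-D positions) is an arbitrary assumption
    # chosen for the example.
    d_model = 128
    layer = TransformerDecoderLayer(
        d_model=d_model, nhead=8, dim_feedforward=256, dropout=0.1,
        self_posembed=PositionEmbeddingLearned(2, d_model),
        cross_posembed=PositionEmbeddingLearned(2, d_model))

    query = torch.rand(2, d_model, 16)   # N x C x P_query
    key = torch.rand(2, d_model, 200)    # N x C x P_key
    query_pos = torch.rand(2, 16, 2)     # N x P_query x 2 coordinates
    key_pos = torch.rand(2, 200, 2)      # N x P_key x 2 coordinates

    out = layer(query, key, query_pos, key_pos)
    print(out.shape)                     # expected: torch.Size([2, 128, 16])

    # clip_sigmoid keeps scores strictly inside (eps, 1 - eps); note that it
    # also overwrites `heatmap` in place via sigmoid_().
    heatmap = torch.randn(2, 1, 8, 8)
    clipped = clip_sigmoid(heatmap)
    print(clipped.min().item(), clipped.max().item())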