"""
|
|
Mostly copy-paste from
|
|
https://github.com/open-mmlab/mmdetection/blob/ecac3a77becc63f23d9f6980b2a36f86acd00a8a/mmdet/models/layers/transformer/utils.py
|
|
|
|
"""
|
|
|
|
import collections.abc
import copy
import math
import warnings
from collections import OrderedDict
from itertools import repeat
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# From PyTorch internals.
# Note: iterables (including strings) are returned unchanged; scalars are
# repeated n times.
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)

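# Illustrative usage (a sketch, not part of the upstream file): `to_2tuple`
# turns a scalar into a pair and leaves iterables untouched.
# >>> to_2tuple(16)
# (16, 16)
# >>> to_2tuple((3, 5))
# (3, 5)
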
def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
    """Initialize the weight (and bias, if present) of ``module`` to
    constants."""
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def trunc_normal_init(module: nn.Module,
                      mean: float = 0,
                      std: float = 1,
                      a: float = -2,
                      b: float = 2,
                      bias: float = 0) -> None:
    """Initialize ``module.weight`` from a truncated normal distribution and
    ``module.bias`` to a constant."""
    if hasattr(module, 'weight') and module.weight is not None:
        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)  # type: ignore


def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
                           b: float) -> Tensor:
    # Method based on
    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # Modified from
    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [lower, upper], then
        # translate to [2 * lower - 1, 2 * upper - 1].
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

        # Use the inverse CDF transform of the normal distribution to get a
        # truncated standard normal
        tensor.erfinv_()

        # Transform to the requested mean and std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure the values are in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor: Tensor,
                  mean: float = 0.,
                  std: float = 1.,
                  a: float = -2.,
                  b: float = 2.) -> Tensor:
    r"""Fills the input Tensor with values drawn from a truncated normal
    distribution. The values are effectively drawn from the normal
    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
    outside :math:`[a, b]` redrawn until they are within the bounds. The
    method used for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    Modified from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py

    Args:
        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
        mean (float): the mean of the normal distribution.
        std (float): the standard deviation of the normal distribution.
        a (float): the minimum cutoff value.
        b (float): the maximum cutoff value.
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)

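# Illustrative usage (sketch): initialize a weight tensor in-place; all
# values land inside the default cutoffs [a, b] = [-2, 2].
# >>> w = torch.empty(3, 5)
# >>> _ = trunc_normal_(w, std=.02)
# >>> bool((w >= -2).all() and (w <= 2).all())
# True
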
def drop_path(x: torch.Tensor,
              drop_prob: float = 0.,
              training: bool = False) -> torch.Tensor:
    """Drop paths (Stochastic Depth) per sample (when applied in the main
    path of residual blocks).

    We follow the implementation in
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # Handle tensors of any dimensionality, not just 4D tensors: broadcast
    # the per-sample mask over all non-batch dimensions.
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    # Each sample is kept with probability `keep_prob`; `floor()` binarizes
    # the mask to 0 or 1.
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    # Scale by 1 / keep_prob so the expected value of the output matches
    # the input.
    output = x.div(keep_prob) * random_tensor.floor()
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main
    path of residual blocks).

    We follow the implementation in
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501

    Args:
        drop_prob (float): Probability of the path to be zeroed.
            Default: 0.1.
    """

    def __init__(self, drop_prob: float = 0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return drop_path(x, self.drop_prob, self.training)

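# Illustrative usage (sketch): in training mode roughly `drop_prob` of the
# samples in a batch are zeroed (and the rest rescaled); in eval mode the
# input passes through unchanged.
# >>> layer = DropPath(drop_prob=0.2)
# >>> _ = layer.train()
# >>> out = layer(torch.rand(8, 197, 256))
# >>> _ = layer.eval()
# >>> torch.equal(layer(torch.ones(2, 4)), torch.ones(2, 4))
# True
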
class FFN(nn.Module):
    """Implements a feed-forward network (FFN) with an identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Default: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Default: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU'). Note: this implementation currently
            ignores `act_cfg` and always uses GELU.
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default: 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        dropout_layer (dict, optional): The dropout config used when adding
            the shortcut, e.g. dict(drop_prob=0.1). If None, no drop path
            is applied. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super().__init__()
        self._is_init = False
        self.init_cfg = copy.deepcopy(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        # `act_cfg` is ignored here; GELU is always used.
        self.activate = nn.GELU()

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    nn.Linear(in_channels, feedforward_channels),
                    self.activate, nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        layers.append(nn.Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = nn.Sequential(*layers)
        # Guard against the default `dropout_layer=None`, which would
        # otherwise raise a TypeError when subscripted below.
        self.dropout_layer = DropPath(
            dropout_layer['drop_prob']) if dropout_layer else nn.Identity()
        self.add_identity = add_identity

    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        The function adds `x` to the output tensor when `identity` is None.
        """
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)

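# Illustrative usage (sketch): a two-layer FFN with a residual connection
# and a stochastic-depth shortcut, as configured in transformer blocks.
# >>> ffn = FFN(embed_dims=256, feedforward_channels=1024,
# >>>           dropout_layer=dict(drop_prob=0.1))
# >>> x = torch.rand(2, 196, 256)
# >>> ffn(x).shape
# torch.Size([2, 196, 256])
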
def nlc_to_nchw(x, hw_shape):
    """Convert a [N, L, C] shape tensor to a [N, C, H, W] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, L, C] before conversion.
        hw_shape (Sequence[int]): The height and width of the output
            feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W] after conversion.
    """
    H, W = hw_shape
    assert len(x.shape) == 3
    B, L, C = x.shape
    assert L == H * W, 'The seq_len does not match H, W'
    return x.transpose(1, 2).reshape(B, C, H, W).contiguous()


def nchw_to_nlc(x):
    """Flatten a [N, C, H, W] shape tensor to a [N, L, C] shape tensor.

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.

    Returns:
        Tensor: The output tensor of shape [N, L, C] after conversion.
    """
    assert len(x.shape) == 4
    return x.flatten(2).transpose(1, 2).contiguous()

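# Illustrative usage (sketch): the two converters are inverses of each
# other whenever L == H * W.
# >>> feat = torch.rand(2, 64, 14, 14)          # N, C, H, W
# >>> seq = nchw_to_nlc(feat)                   # shape (2, 196, 64)
# >>> torch.equal(nlc_to_nchw(seq, (14, 14)), feat)
# True
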
class AdaptivePadding(nn.Module):
    """Applies padding to the input (if needed) so that the input can be
    fully covered by the specified filter. It supports two modes, "same"
    and "corner". The "same" mode matches the "SAME" padding mode in
    TensorFlow, padding zeros around the input; the "corner" mode pads
    zeros to the bottom right.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1.
        padding (str): Support "same" and "corner". The "corner" mode pads
            zeros to the bottom right, and the "same" mode pads zeros
            around the input. Default: "corner".
    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):

        super(AdaptivePadding, self).__init__()

        assert padding in ('same', 'corner')

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        self.padding = padding
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

    def get_pad_shape(self, input_shape):
        """Compute the (pad_h, pad_w) needed to fully cover ``input_shape``."""
        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        output_h = math.ceil(input_h / stride_h)
        output_w = math.ceil(input_w / stride_w)
        pad_h = max((output_h - 1) * stride_h +
                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
        pad_w = max((output_w - 1) * stride_w +
                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
        return pad_h, pad_w

    def forward(self, x):
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h > 0 or pad_w > 0:
            if self.padding == 'corner':
                x = F.pad(x, [0, pad_w, 0, pad_h])
            elif self.padding == 'same':
                x = F.pad(x, [
                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2
                ])
        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    We use a conv layer to implement PatchEmbed.

    Args:
        in_channels (int): The num of input channels. Default: 3.
        embed_dims (int): The dimensions of embedding. Default: 768.
        conv_type (str): The type of the embedding conv layer.
            Default: "Conv2d". Note: this implementation always uses
            ``nn.Conv2d`` regardless of this value.
        kernel_size (int): The kernel_size of embedding conv. Default: 16.
        stride (int): The stride of embedding conv. Default: 16. If None,
            it will be set to ``kernel_size``.
        padding (int | tuple | str): The padding length of embedding conv.
            When it is a string, it means the mode of adaptive padding;
            "same" and "corner" are supported. Default: "corner".
        dilation (int): The dilation rate of embedding conv. Default: 1.
        bias (bool): Bias of embed conv. Default: True.
        norm_cfg (dict, optional): Config dict for the normalization layer.
            If not None, ``nn.LayerNorm`` is applied. Default: None.
        input_size (int | tuple | None): The size of the input, which will
            be used to calculate the out size. Default: None.
        init_cfg (`mmcv.ConfigDict`, optional): The Config for
            initialization. Default: None.
    """

    def __init__(
        self,
        in_channels=3,
        embed_dims=768,
        conv_type='Conv2d',
        kernel_size=16,
        stride=16,
        padding='corner',
        dilation=1,
        bias=True,
        norm_cfg=None,
        input_size=None,
        init_cfg=None,
    ):
        super(PatchEmbed, self).__init__()
        self._is_init = False

        self.init_cfg = copy.deepcopy(init_cfg)
        self.embed_dims = embed_dims
        if stride is None:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of conv
            padding = 0
        else:
            self.adap_padding = None
        padding = to_2tuple(padding)

        # `conv_type` is not used to build the layer; nn.Conv2d is hardcoded.
        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        if norm_cfg is not None:
            self.norm = nn.LayerNorm(embed_dims)
        else:
            self.norm = None

        if input_size:
            input_size = to_2tuple(input_size)
            # `init_out_size` would be used outside to
            # calculate the num_patches
            # when `use_abs_pos_embed` outside
            self.init_input_size = input_size
            if self.adap_padding:
                pad_h, pad_w = self.adap_padding.get_pad_shape(input_size)
                input_h, input_w = input_size
                input_h = input_h + pad_h
                input_w = input_w + pad_w
                input_size = (input_h, input_w)

            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
                     (kernel_size[0] - 1) - 1) // stride[0] + 1
            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
                     (kernel_size[1] - 1) - 1) // stride[1] + 1
            self.init_out_size = (h_out, w_out)
        else:
            self.init_input_size = None
            self.init_out_size = None

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.

        Returns:
            tuple: Contains merged results and their spatial shape.

            - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (out_h, out_w).
        """

        if self.adap_padding:
            x = self.adap_padding(x)

        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        x = x.flatten(2).transpose(1, 2)
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size

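# Illustrative usage (sketch): embed a 224x224 RGB image into non-overlapping
# 16x16 patches, yielding 14 * 14 = 196 tokens.
# >>> patch_embed = PatchEmbed(in_channels=3, embed_dims=768,
# >>>                          kernel_size=16, stride=16)
# >>> tokens, hw = patch_embed(torch.rand(1, 3, 224, 224))
# >>> tokens.shape, hw
# (torch.Size([1, 196, 768]), (14, 14))
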
class PatchMerging(nn.Module):
    """Merge patch feature map.

    This layer groups the feature map by kernel_size, and applies norm and
    linear layers to the grouped feature map. Our implementation uses
    `nn.Unfold` to merge patches, which is about 25% faster than the
    original implementation. However, pretrained models need to be modified
    for compatibility (see ``swin_converter`` below).

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Default: 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Default: None (would be set as `kernel_size`).
        padding (int | tuple | str): The padding length of the unfold
            operation. When it is a string, it means the mode of adaptive
            padding; "same" and "corner" are supported. Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in the linear layer or
            not. Default: False.
        norm_cfg (dict, optional): Config dict for the normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=2,
                 stride=None,
                 padding='corner',
                 dilation=1,
                 bias=False,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__()
        self._is_init = False
        self.init_cfg = copy.deepcopy(init_cfg)

        self.in_channels = in_channels
        self.out_channels = out_channels
        if not stride:
            stride = kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                padding=padding)
            # disable the padding of unfold
            padding = 0
        else:
            self.adap_padding = None

        padding = to_2tuple(padding)
        self.sampler = nn.Unfold(
            kernel_size=kernel_size,
            dilation=dilation,
            padding=padding,
            stride=stride)

        sample_dim = kernel_size[0] * kernel_size[1] * in_channels

        if norm_cfg is not None:
            self.norm = nn.LayerNorm(sample_dim)
        else:
            self.norm = None

        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arranged as
                (H, W).

        Returns:
            tuple: Contains merged results and their spatial shape.

            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (Merged_H, Merged_W).
        """
        B, L, C = x.shape
        assert isinstance(input_size, Sequence), f'Expect ' \
            f'input_size is `Sequence` but get {input_size}'

        H, W = input_size
        assert L == H * W, 'input feature has wrong size'

        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
        # Use nn.Unfold to merge patches. About 25% faster than the original
        # method, but pretrained models need to be modified for compatibility

        if self.adap_padding:
            x = self.adap_padding(x)
            H, W = x.shape[-2:]

        x = self.sampler(x)
        # if kernel_size=2 and stride=2, x has shape (B, 4*C, H/2*W/2)

        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
                 (self.sampler.kernel_size[0] - 1) -
                 1) // self.sampler.stride[0] + 1
        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
                 (self.sampler.kernel_size[1] - 1) -
                 1) // self.sampler.stride[1] + 1

        output_size = (out_h, out_w)
        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
        x = self.norm(x) if self.norm else x
        x = self.reduction(x)
        return x, output_size

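# Illustrative usage (sketch): halve the spatial resolution of a Swin-style
# token map while increasing the channel dimension, as done between stages.
# >>> merge = PatchMerging(in_channels=96, out_channels=192)
# >>> x = torch.rand(1, 56 * 56, 96)
# >>> out, out_size = merge(x, (56, 56))
# >>> out.shape, out_size
# (torch.Size([1, 784, 192]), (28, 28))
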
def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid (the logit function).

    Args:
        x (Tensor): The tensor to apply the inverse to.
        eps (float): A small value used to clamp the input away from 0 and
            1, avoiding numerical overflow in the log. Defaults 1e-5.

    Returns:
        Tensor: The inverse sigmoid of x, with the same shape as the input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

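# Illustrative check (sketch): inverse_sigmoid(sigmoid(y)) recovers y within
# the clamped range, since log(p / (1 - p)) is the logit function.
# >>> y = torch.tensor([-3., 0., 3.])
# >>> torch.allclose(inverse_sigmoid(torch.sigmoid(y)), y, atol=1e-4)
# True
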
def swin_converter(ckpt):
    """Convert an official Swin Transformer checkpoint to the key naming
    used by this implementation: drop the classification head, rename
    modules, reorder PatchMerging weights for the `nn.Unfold`-based layer,
    and prefix all keys with 'backbone.'.
    """

    new_ckpt = OrderedDict()

    def correct_unfold_reduction_order(x):
        # nn.Unfold concatenates the four merged patches in a different
        # order than the original slicing implementation, so permute the
        # input channels of the reduction linear layer accordingly.
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(
            1, 2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        # Apply the same permutation to the norm layer's per-channel
        # parameters.
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    for k, v in ckpt.items():
        if k.startswith('head'):
            # Drop the classification head; only the backbone is kept.
            continue
        elif k.startswith('layers'):
            new_v = v
            if 'attn.' in k:
                new_k = k.replace('attn.', 'attn.w_msa.')
            elif 'mlp.' in k:
                if 'mlp.fc1.' in k:
                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
                elif 'mlp.fc2.' in k:
                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
                else:
                    new_k = k.replace('mlp.', 'ffn.')
            elif 'downsample' in k:
                new_k = k
                if 'reduction.' in k:
                    new_v = correct_unfold_reduction_order(v)
                elif 'norm.' in k:
                    new_v = correct_unfold_norm_order(v)
            else:
                new_k = k
            new_k = new_k.replace('layers', 'stages', 1)
        elif k.startswith('patch_embed'):
            new_v = v
            if 'proj' in k:
                new_k = k.replace('proj', 'projection')
            else:
                new_k = k
        else:
            new_v = v
            new_k = k

        new_ckpt['backbone.' + new_k] = new_v

    return new_ckpt

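# Illustrative usage (sketch): remap an official Swin checkpoint before
# loading it into the backbone. The checkpoint filename is hypothetical,
# and `.get('model', ...)` assumes the official 'model' wrapper key.
# >>> ckpt = torch.load('swin_tiny_patch4_window7_224.pth',
# >>>                   map_location='cpu')
# >>> state_dict = ckpt.get('model', ckpt)
# >>> new_state_dict = swin_converter(state_dict)
# >>> list(new_state_dict)[0].startswith('backbone.')
# True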