# DIP encoder-decoder neural network with skip connections

"""
This code implements an encoder and decoder neural network structure with skip connections
1: Define a function skip that accepts a series of parameters to build the network
2: An assert checks that num_channels_down, num_channels_up and num_channels_skip have the same length; otherwise an AssertionError is raised.
Note: these three parameters give the number of channels at each scale of the network.
num_channels_down gives the number of output channels of the convolutional layers at each scale of the downsampling path. For example, num_channels_down=[16, 32, 64, 128, 128] means that the first scale of the downsampling path outputs 16 channels, the second outputs 32 channels, and so on, up to 128 channels at the last scale.
num_channels_up gives the number of output channels of the convolutional layers at each scale of the upsampling path.
num_channels_skip gives the number of channels of the skip connection at each scale (0 disables the skip connection at that scale).
In neural networks, channel refers to the dimension or depth of the feature map.
In a convolutional neural network, each convolutional layer outputs a certain number of feature maps, and each feature map corresponds to a channel. These channels can be viewed as representations of different features learned by the network. For example, the first convolutional layer may output 16 feature maps, each feature map corresponding to a channel, representing 16 different low-level features.
The number of channels determines the kind and complexity of features the network can capture and represent. Increasing the number of channels can improve the expressive ability of the network, but it will also increase the number of parameters and computational complexity of the model.
3: Create an empty nn.Sequential() container, model.
4: Build the network in a loop; each iteration adds one scale, consisting of a downsampling path and an upsampling path.
   - In the downsampling path, convolutional layers extract features, each followed by batch normalization and an activation function.
   - In the upsampling path, the upsampling module enlarges the features. These are then concatenated with the skip-connection features and passed through batch normalization, convolution, batch normalization and an activation function.
5: If need1x1_up is True, an extra 1×1 convolution (with batch normalization and activation) is added after each upsampling block.
6: Finally, a 1×1 convolution produces the output, optionally followed by a Sigmoid that maps it to the range (0, 1).
   The built model is returned.
Summary: This network structure can be used for tasks such as image reconstruction and denoising, and the performance and accuracy of the network can be improved through skip connections.
"""
import torch
import torch.nn as nn
from .common import *
def skip(
        num_input_channels=2, num_output_channels=3,
        num_channels_down=(16, 32, 64, 128, 128), num_channels_up=(16, 32, 64, 128, 128), num_channels_skip=(4, 4, 4, 4, 4),
        filter_size_down=3, filter_size_up=3, filter_skip_size=1,
        need_sigmoid=True, need_bias=True,
        pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU',
        need1x1_up=True):
    """Assembles encoder-decoder with skip connections.

    The network is built as nested containers, one scale per entry of
    ``num_channels_down``. Scale ``i`` applies two branches to the same input
    and concatenates their outputs along the channel dimension (``Concat``
    centre-crops them to a common spatial size if they differ):

    * skip branch (omitted when ``num_channels_skip[i] == 0``): a stride-1
      conv(filter_skip_size) -> BN -> act producing ``num_channels_skip[i]``
      channels;
    * deeper branch: a downsample-by-2 conv (see ``downsample_mode``) -> BN ->
      act -> conv -> BN -> act, then the whole of scale ``i + 1`` (if any),
      then a 2x ``nn.Upsample``.

    The concatenated features go through BN -> conv(filter_size_up[i]) -> BN
    -> act (plus an optional 1x1 conv -> BN -> act), giving
    ``num_channels_up[i]`` channels. A final 1x1 conv maps
    ``num_channels_up[0]`` channels to ``num_output_channels``, optionally
    followed by a sigmoid.

    Arguments:
        num_input_channels: number of channels of the network input.
        num_output_channels: number of channels of the network output.
        num_channels_down, num_channels_up, num_channels_skip: per-scale
            channel counts. All three must have the same length, which is
            the number of scales. A skip count of 0 disables that skip branch.
        filter_size_down, filter_size_up: kernel size, either a single int
            for every scale or a per-scale list/tuple.
        filter_skip_size: kernel size of the skip-branch convolution.
        need_sigmoid: append ``nn.Sigmoid`` after the output convolution.
        need_bias: give every convolution a bias term.
        need1x1_up: add an extra 1x1 conv -> BN -> act after each decoder conv.
        act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU)
        pad (string): zero|reflection (default: 'zero')
        upsample_mode (string): 'nearest|bilinear' (default: 'nearest'),
            or a per-scale list/tuple
        downsample_mode (string): 'stride|avg|max|lanczos2|lanczos3'
            (default: 'stride'), or a per-scale list/tuple

    Returns:
        nn.Sequential: the assembled network.
    """
    # Every scale needs exactly one entry in each of the three channel lists.
    assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip), \
        'num_channels_down, num_channels_up and num_channels_skip must have the same length'

    n_scales = len(num_channels_down)

    # Broadcast scalar options to one value per scale.
    if not isinstance(upsample_mode, (list, tuple)):
        upsample_mode = [upsample_mode] * n_scales
    if not isinstance(downsample_mode, (list, tuple)):
        downsample_mode = [downsample_mode] * n_scales
    if not isinstance(filter_size_down, (list, tuple)):
        filter_size_down = [filter_size_down] * n_scales
    if not isinstance(filter_size_up, (list, tuple)):
        filter_size_up = [filter_size_up] * n_scales

    # Index of the deepest scale (the one with no inner scale).
    last_scale = n_scales - 1

    model = nn.Sequential()
    # Container that receives the current scale's modules. At the end of each
    # iteration it moves down to that scale's still-empty inner slot.
    model_tmp = model
    input_depth = num_input_channels

    for i in range(n_scales):
        # NOTE: the order of the constructor and ``add`` calls below fixes the
        # random weight initialisation sequence and the auto-generated
        # submodule names (hence state_dict keys). Do not reorder.
        deeper = nn.Sequential()
        skip_branch = nn.Sequential()

        if num_channels_skip[i] != 0:
            # Both branches see the same input; outputs are stacked on dim 1.
            model_tmp.add(Concat(1, skip_branch, deeper))
        else:
            model_tmp.add(deeper)

        # Channels delivered by the upsampled deeper branch: the decoder
        # output of scale i + 1, or this scale's encoder output at the
        # deepest scale.
        if i < last_scale:
            k = num_channels_up[i + 1]
        else:
            k = num_channels_down[i]
        model_tmp.add(bn(num_channels_skip[i] + k))

        if num_channels_skip[i] != 0:
            skip_branch.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad))
            skip_branch.add(bn(num_channels_skip[i]))
            skip_branch.add(act(act_fun))

        # Encoder: a stride-2 (downsampling) conv followed by a stride-1 conv.
        deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2,
                        bias=need_bias, pad=pad, downsample_mode=downsample_mode[i]))
        deeper.add(bn(num_channels_down[i]))
        deeper.add(act(act_fun))
        deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad))
        deeper.add(bn(num_channels_down[i]))
        deeper.add(act(act_fun))

        # Empty slot that the next iteration fills with scale i + 1.
        deeper_main = nn.Sequential()
        if i < last_scale:
            deeper.add(deeper_main)

        deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i]))

        # Decoder for this scale, applied to the concatenated features.
        model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad))
        model_tmp.add(bn(num_channels_up[i]))
        model_tmp.add(act(act_fun))

        if need1x1_up:
            model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad))
            model_tmp.add(bn(num_channels_up[i]))
            model_tmp.add(act(act_fun))

        # The next scale consumes this scale's encoder output and is built
        # inside deeper_main.
        input_depth = num_channels_down[i]
        model_tmp = deeper_main

    # Output head on top of the shallowest decoder.
    model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
    if need_sigmoid:
        model.add(nn.Sigmoid())

    return model


# common
import torch
import torch.nn as nn
import numpy as np
from .downsampler import Downsampler

def add_module(self, module):
    """Append ``module`` to ``self`` under the name ``str(len(self) + 1)``.

    Installed below as ``torch.nn.Module.add`` so containers can be grown with
    ``container.add(layer)``. For a container that starts empty, children are
    named "1", "2", ... in insertion order. ``self`` must support ``len()``
    (e.g. ``nn.Sequential`` or ``Concat``).
    """
    position = len(self) + 1
    self.add_module(str(position), module)


# Monkey-patch: expose ``.add`` on every nn.Module.
torch.nn.Module.add = add_module

class Concat(nn.Module):
    """Run every child module on the same input and concatenate the results.

    Children are registered under the names "0", "1", ... in argument order.
    If the children's outputs differ in dimensions 2 or 3, each output is
    centre-cropped to the smallest size found in that dimension before
    concatenation along ``dim``.
    """

    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim
        for position, child in enumerate(args):
            self.add_module(str(position), child)

    def forward(self, input):
        outputs = [child(input) for child in self._modules.values()]

        heights = [out.shape[2] for out in outputs]
        widths = [out.shape[3] for out in outputs]
        min_h = min(heights)
        min_w = min(widths)

        same_size = all(h == min_h for h in heights) and all(w == min_w for w in widths)
        if not same_size:
            cropped = []
            for out in outputs:
                # Offsets that centre the crop window (rounded down).
                top = (out.size(2) - min_h) // 2
                left = (out.size(3) - min_w) // 2
                cropped.append(out[:, :, top:top + min_h, left:left + min_w])
            outputs = cropped

        return torch.cat(outputs, dim=self.dim)

    def __len__(self):
        return len(self._modules)


class GenNoise(nn.Module):
    """Module that returns fresh standard-normal noise shaped like its input.

    The output has the input's shape except that dimension 1 is replaced by
    ``dim2``. It is converted to the input's tensor type (e.g. dtype / CUDA
    type). The input's values are ignored.
    """
    def __init__(self, dim2):
        super(GenNoise, self).__init__()
        self.dim2 = dim2

    def forward(self, input):
        shape = list(input.size())
        shape[1] = self.dim2

        # Allocate, convert to the input's tensor type, then fill in place
        # from N(0, 1). This keeps the same RNG call pattern as before.
        noise = torch.zeros(shape).type_as(input)
        noise.normal_()

        # Returned as a plain tensor (no deprecated ``Variable`` wrapper);
        # it does not require grad.
        return noise


class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    Reference: https://arxiv.org/abs/1710.05941
    """

    def __init__(self):
        super().__init__()
        # Kept as a submodule attribute named ``s``.
        self.s = nn.Sigmoid()

    def forward(self, x):
        gate = self.s(x)
        return x * gate


def act(act_fun='LeakyReLU'):
    """Create an activation module.

    Args:
        act_fun: either one of the strings 'LeakyReLU', 'Swish', 'ELU',
            'none', or a callable (e.g. the class ``nn.ReLU``) that is called
            with no arguments to build the module.

    Returns:
        A new module. 'LeakyReLU' uses negative slope 0.2 and ``inplace=True``.
        'none' yields an empty ``nn.Sequential``, which passes its input
        through unchanged.

    Raises:
        ValueError: if ``act_fun`` is an unrecognised string.
    """
    if not isinstance(act_fun, str):
        return act_fun()

    if act_fun == 'LeakyReLU':
        return nn.LeakyReLU(0.2, inplace=True)
    if act_fun == 'Swish':
        return Swish()
    if act_fun == 'ELU':
        return nn.ELU()
    if act_fun == 'none':
        return nn.Sequential()

    # Explicit raise: an ``assert False`` is stripped under ``python -O`` and
    # would make this function silently return None.
    raise ValueError(f"unknown activation {act_fun!r}; expected 'LeakyReLU', 'Swish', 'ELU' or 'none'")

def bn(num_features):
    """Return a new ``nn.BatchNorm2d`` layer over ``num_features`` channels.

    Batch normalisation normalises each channel with batch statistics, then
    applies a learned per-channel scale and shift. ``num_features`` must equal
    the channel count (dimension 1) of the layer's input.
    """
    layer = nn.BatchNorm2d(num_features)
    return layer


def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
    """Build a 2-D convolution block with optional padding and downsampling layers.

    Args:
        in_f: number of input channels.
        out_f: number of output channels.
        kernel_size: integer kernel size. ``(kernel_size - 1) // 2`` pixels of
            padding are added on every side.
        stride: downsampling factor. With ``downsample_mode == 'stride'`` it is
            the convolution stride. Otherwise the convolution runs with stride
            1 and a separate downsampler reduces resolution by ``stride``
            afterwards.
        bias: whether the convolution has a bias term.
        pad: ``'reflection'`` adds an ``nn.ReflectionPad2d`` layer. Any other
            value uses the convolution's own zero padding.
        downsample_mode: ``'stride'``, ``'avg'``, ``'max'``, ``'lanczos2'`` or
            ``'lanczos3'``. Only consulted when ``stride != 1``.

    Returns:
        ``nn.Sequential`` of [padder], convolution, [downsampler], in that order.

    Raises:
        ValueError: for an unknown ``downsample_mode`` when ``stride != 1``.
    """
    # Built first (before the convolution) to keep the original
    # parameter-initialisation order.
    downsampler = None
    if stride != 1 and downsample_mode != 'stride':
        if downsample_mode == 'avg':
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == 'max':
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ('lanczos2', 'lanczos3'):
            downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
        else:
            # Explicit raise: ``assert False`` is stripped under ``python -O``,
            # which would silently build a block that does not downsample.
            raise ValueError(f"unknown downsample_mode {downsample_mode!r}")
        # The downsampler now performs the resolution change.
        stride = 1

    padder = None
    to_pad = (kernel_size - 1) // 2
    if pad == 'reflection':
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)

    layers = [layer for layer in (padder, convolver, downsampler) if layer is not None]
    return nn.Sequential(*layers)
import numpy as np
import torch
import torch.nn as nn

class Downsampler(nn.Module):
    '''
        http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf

        Downsample every channel by an integer ``factor`` with a fixed kernel.

        The kernel from ``get_kernel`` is loaded into an ``nn.Conv2d`` with
        stride ``factor``. weight[i, i] is the kernel for every channel i, and
        all cross-channel weights and the bias are zero, so channels are
        filtered independently. These are ordinary Conv2d parameters: they
        appear in ``parameters()`` and an optimizer over them will update
        them. NOTE(review): confirm that is intended.

        Args:
            n_planes: number of channels of the input.
            factor: integer downsampling factor (the convolution stride).
            kernel_type: 'lanczos2', 'lanczos3', 'gauss12', 'gauss1sq2'
                (presets that set support/width/sigma), or 'lanczos', 'gauss',
                'box' (which use the caller's kernel_width/support/sigma).
            phase: 0 or 0.5.
            kernel_width, support, sigma: kernel parameters, see ``get_kernel``.
            preserve_size: replication-pad the input before filtering. The
                padding is (k - 1) / 2 for odd kernel size k, else
                (k - factor) / 2. Presumably this makes the output size
                input_size / factor; verify for the settings you use.

        Raises:
            ValueError: if ``phase`` or ``kernel_type`` is invalid.
    '''
    def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False):
        super(Downsampler, self).__init__()

        # Explicit raises rather than asserts so validation survives ``python -O``.
        if phase not in (0, 0.5):
            raise ValueError('phase should be 0 or 0.5')

        if kernel_type == 'lanczos2':
            support = 2
            kernel_width = 4 * factor + 1
            kernel_type_ = 'lanczos'
        elif kernel_type == 'lanczos3':
            support = 3
            kernel_width = 6 * factor + 1
            kernel_type_ = 'lanczos'
        elif kernel_type == 'gauss12':
            kernel_width = 7
            sigma = 1/2
            kernel_type_ = 'gauss'
        elif kernel_type == 'gauss1sq2':
            kernel_width = 9
            sigma = 1./np.sqrt(2)
            kernel_type_ = 'gauss'
        elif kernel_type in ('lanczos', 'gauss', 'box'):
            kernel_type_ = kernel_type
        else:
            raise ValueError('wrong name kernel')

        # note that `kernel width` will be different to actual size for phase = 1/2
        self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma)

        downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0)
        downsampler.weight.data[:] = 0
        downsampler.bias.data[:] = 0

        # Place the same kernel on the diagonal (channel i -> channel i).
        kernel_torch = torch.from_numpy(self.kernel)
        for i in range(n_planes):
            downsampler.weight.data[i, i] = kernel_torch

        self.downsampler_ = downsampler

        if preserve_size:
            if self.kernel.shape[0] % 2 == 1:
                pad = int((self.kernel.shape[0] - 1) / 2.)
            else:
                pad = int((self.kernel.shape[0] - factor) / 2.)
            self.padding = nn.ReplicationPad2d(pad)

        self.preserve_size = preserve_size

    def forward(self, input):
        # Intentionally not cached on ``self``: storing the padded input would
        # keep the last activation (and its autograd graph) alive between calls.
        x = self.padding(input) if self.preserve_size else input
        return self.downsampler_(x)
        
def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None):
    """Build a square, sum-normalised 2-D resampling kernel as a NumPy array.

    Args:
        factor: downsampling factor. Lanczos sample distances are divided by it.
        kernel_type: 'lanczos', 'gauss' or 'box'.
        phase: with 0.5 the lanczos kernel has ``kernel_width - 1`` taps per
            side. 'box' requires 0.5; 'gauss' does not support 0.5.
        kernel_width: nominal kernel width (taps per side, except as above).
        support: lanczos window parameter ``a`` (required for 'lanczos').
        sigma: gaussian sigma (required for 'gauss').

    Returns:
        A float64 array whose entries sum to 1.

    Raises:
        ValueError: on an unknown ``kernel_type`` or missing/invalid
            ``support``, ``sigma`` or ``phase``.
    """
    # Explicit raises rather than asserts so validation survives ``python -O``.
    if kernel_type not in ('lanczos', 'gauss', 'box'):
        raise ValueError('wrong method name')

    # Half-phase (non-box) kernels drop one tap, presumably so the
    # even-sized kernel is centred between pixels.
    if phase == 0.5 and kernel_type != 'box':
        kernel = np.zeros([kernel_width - 1, kernel_width - 1])
    else:
        kernel = np.zeros([kernel_width, kernel_width])

    if kernel_type == 'box':
        if phase != 0.5:
            raise ValueError('Box filter is always half-phased')
        kernel[:] = 1. / (kernel_width * kernel_width)

    elif kernel_type == 'gauss':
        if not sigma:
            raise ValueError('sigma is not specified')
        if phase == 0.5:
            raise ValueError('phase 1/2 for gauss not implemented')

        center = (kernel_width + 1.) / 2.
        sigma_sq = sigma * sigma

        for i in range(1, kernel.shape[0] + 1):
            for j in range(1, kernel.shape[1] + 1):
                # NOTE(review): distances are halved here rather than divided
                # by ``factor`` as in the lanczos branch; kept as-is, confirm intent.
                di = (i - center) / 2.
                dj = (j - center) / 2.
                kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj) / (2 * sigma_sq)) / (2. * np.pi * sigma_sq)

    elif kernel_type == 'lanczos':
        if not support:
            raise ValueError('support is not specified')
        center = (kernel_width + 1) / 2.

        for i in range(1, kernel.shape[0] + 1):
            for j in range(1, kernel.shape[1] + 1):
                if phase == 0.5:
                    di = abs(i + 0.5 - center) / factor
                    dj = abs(j + 0.5 - center) / factor
                else:
                    di = abs(i - center) / factor
                    dj = abs(j - center) / factor

                # Separable product of sinc(d) * sinc(d / support) in each axis,
                # written as support*sin(pi d)*sin(pi d/support) / (pi^2 d^2);
                # the factor is 1 at d == 0.
                val = 1
                if di != 0:
                    val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support)
                    val = val / (np.pi * np.pi * di * di)
                if dj != 0:
                    val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support)
                    val = val / (np.pi * np.pi * dj * dj)

                kernel[i - 1][j - 1] = val

    kernel /= kernel.sum()

    return kernel

#a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True)










#################
#downsampler
# Learnable downsampler

#KS = 32
# dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor))
    
# class Apply(nn.Module):
# def __init__(self, what, dim, *args):
# super(Apply, self).__init__()
# self.dim = dim
    
# self.what = what

# def forward(self, input):
# inputs = []
# for i in range(input.size(self.dim)):
# inputs.append(self.what(input.narrow(self.dim, i, 1)))

# return torch.cat(inputs, dim=self.dim)

# def __len__(self):
# return len(self._modules)
    
# downs = Apply(dow, 1)
# downs.type(dtype)(net_input.type(dtype)).size()