Semantic-Guided Zero-Shot Learning for Low-Light Image/Video Enhancement

Paper reading: Semantic-Guided Zero-Shot Learning for Low-Light Image/Video Enhancement

Code: https://github.com/ShenZheng2000/SemantiGuided-Low-Light-Image-Enhancement

One possible way to increase brightness in low-light conditions is to use a higher ISO or a longer exposure time. However, these strategies exacerbate noise and introduce motion blur, respectively [2]. Another reasonable approach is to adjust the lighting with modern software such as Photoshop or Lightroom. However, such software requires artistic skill and is inefficient for large-scale datasets with varied lighting conditions.

In recent years, deep-learning-based low-light image enhancement methods have attracted widespread attention for their impressive efficiency, accuracy, and robustness. Supervised methods [33, 48, 52, 45] score highest on several benchmark datasets [47, 13, 35, 25] thanks to their excellent image-to-image mapping capabilities. However, they require paired training images (i.e., low/normal-light pairs), which demand either expensive retouching or impractical capture of the same scene under different lighting conditions. Unsupervised methods [20], on the other hand, only require an unpaired dataset for training. Nonetheless, data bias in manually selected datasets limits their generalization ability. Zero-shot learning methods [11, 27] eliminate the need for both paired images and unpaired datasets. However, they ignore semantic information, which [37, 10, 31] showed is crucial for high-level vision tasks, and therefore produce enhanced images of suboptimal visual quality. Figure 1 illustrates the limitations of previous studies.

To address the above limitations, we propose a semantically guided zero-shot framework for low-light image enhancement (Fig. 2). Since we focus on low-light image/video enhancement, we first design a lightweight Enhancement Factor Extraction (EFE) network built from depthwise separable convolutions [15] and symmetric skip connections. EFE has strong adaptive capability and exploits the spatial information of the low-light image to enhance images and videos. To perform image enhancement at an affordable model size, we then introduce the Recurrent Image Enhancement (RIE) network, which enhances the image stage by stage, using the output of the previous stage as the input to the next recursive stage. Finally, to preserve semantic information during enhancement, we propose an Unsupervised Semantic Segmentation (USS) network that eliminates the need for expensive segmentation annotations. The USS network receives the enhanced image from the RIE and uses a Feature Pyramid Network [29] to compute a segmentation loss. The segmentation loss is combined with the other no-reference loss functions into a total loss, which updates the parameters of the EFE during training. The contributions of the proposed work are summarized as follows:

  • We propose a new semantic-guided zero-shot low-light image enhancement network. To the best of our knowledge, this is the first framework to fuse high-level semantic information with a low-light image enhancement network.
  • We develop a lightweight convolutional neural network to automatically extract enhancement factors that record pixel-level illumination deficiency in low-light images.
  • A recurrent image enhancement strategy with five no-reference loss functions is designed to improve the model's generalization to images under different lighting conditions.
  • Extensive experiments demonstrate the superiority of our model in both qualitative and quantitative metrics. Our model is well suited for low-light video enhancement, as it can process 1,000 images of size 1200 × 900 within one second on a single GPU.
3. Proposed Method

3.1. Enhancement Factor Extraction Network

Enhancement Factor Extraction (EFE) aims to learn the pixel-level light deficiency of a low-light image and record this information in the enhancement factor. Inspired by the U-Net [42] architecture, EFE is a fully convolutional neural network with symmetric skip connections, which means it can handle input images of arbitrary size. Batch normalization and up/down-sampling are not employed, as they would destroy the spatial coherence of the enhanced image [43, 21, 18]. Each convolutional block in EFE consists of a 3 × 3 depthwise separable convolutional layer followed by a ReLU [38] activation. The last convolutional block reduces the number of channels from 32 to 3 and outputs the enhancement factor x_r via a Tanh activation. Figure 3 visualizes the enhancement factors extracted from two low-light images: brighter areas in the low-light image correspond to lower enhancement-factor values, and vice versa.
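Each EFE block is a 3 × 3 depthwise separable convolution followed by ReLU. Below is a minimal PyTorch sketch of such a block; the class name DSCBlock and the example sizes are our assumptions, not the authors' code.

import torch
import torch.nn as nn

class DSCBlock(nn.Module):
    """3x3 depthwise separable convolution + ReLU (illustrative sketch, not the repo code)."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # depthwise: one 3x3 filter per input channel keeps the block lightweight
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3, padding=1, groups=in_ch)
        # pointwise 1x1 convolution mixes information across channels
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.pointwise(self.depthwise(x)))

# first EFE block: 3 input channels -> 32 feature channels
block = DSCBlock(3, 32)
print(block(torch.rand(1, 3, 256, 256)).shape)  # torch.Size([1, 32, 256, 256])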

3.2. Recurrent Image Enhancement Network

Inspired by the success of recursion [41, 55, 28] and light-enhancement curves [50, 11] in low-light image enhancement, we construct a Recurrent Image Enhancement (RIE) network that enhances the low-light image according to the enhancement factor and then outputs the enhanced image. Each recursion takes as its input the output of the previous stage together with the enhancement factor. The recurrent enhancement process is

x_t = x_{t-1} + x_r (x_{t-1}^n − x_{t-1})

where x_t is the output at recursion step t, x_r is the enhancement factor, and n is the order of the curve. The next step is to decide the best order with which to enhance the image. Since the recurrent network should be simple, differentiable, and effective for progressive enhancement, we only consider positive integer orders. With this in mind, we plot the recursive image enhancement with respect to different x_r and orders in Figure 4. When the order is 1, the pixel value is insensitive to x_r and remains the same as in the previous stage. When the order is 3 or 4, the pixel value approaches or even exceeds 1.0, making the image look too bright. In contrast, an order of 2 gives the most robust recursive enhancement.
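A short NumPy sketch (the pixel values and enhancement factor below are arbitrary assumptions) showing how the order changes a single recursion step of x_t = x_{t-1} + x_r (x_{t-1}^n − x_{t-1}):

import numpy as np

def recursion_step(x, x_r, order):
    """One recurrent enhancement step for a given curve order (illustrative only)."""
    return x + x_r * (np.power(x, order) - x)

x = np.array([0.1, 0.3, 0.5])  # example pixel values in [0, 1]
x_r = -0.8                     # hypothetical enhancement factor
for order in (1, 2, 3, 4):
    print(order, recursion_step(x, x_r, order))
# order 1 leaves the pixels unchanged; for this factor, orders 3 and 4
# brighten noticeably more per step than order 2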

3.3. Unsupervised Semantic Segmentation Network

The Unsupervised Semantic Segmentation (USS) network aims to perform accurate pixel-level segmentation of the enhanced image and to preserve semantic information during progressive image enhancement. Similar to [7, 32, 46, 12], we freeze all layers of the segmentation network during training. The network has two pathways: a bottom-up path, which uses ResNet-50 [14] with ImageNet [5] weights, and a top-down path, which uses Gaussian initialization with mean 0 and standard deviation 0.01. Both pathways have four convolutional blocks connected to each other through lateral connections. The choice of weight initialization is explained in the ablation study.

The enhanced image from the RIE first enters the bottom-up path for feature extraction. A top-down path then converts the semantically strong layers into high-resolution layers for spatially aware semantic segmentation. Each convolutional block in the top-down path bilinearly upsamples its input and concatenates it with the corresponding lateral result. For better perceptual quality, two smoothing 3 × 3 convolutional layers are applied after concatenation. Finally, we concatenate the results of each block and compute the segmentation output.
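A minimal sketch of one top-down merge step as described above (bilinear upsampling, concatenation with the lateral feature, then two 3 × 3 smoothing convolutions); the module name and channel arguments are our assumptions:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopDownBlock(nn.Module):
    """One top-down FPN step: upsample, merge with the lateral feature, smooth (sketch)."""
    def __init__(self, top_ch, lateral_ch, out_ch):
        super().__init__()
        # two 3x3 smoothing convolutions applied after the concatenation
        self.smooth = nn.Sequential(
            nn.Conv2d(top_ch + lateral_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.ReLU(inplace=True),
        )

    def forward(self, top, lateral):
        # bilinearly upsample the semantically strong top feature to the lateral resolution
        top = F.interpolate(top, size=lateral.shape[-2:], mode='bilinear', align_corners=False)
        return self.smooth(torch.cat([top, lateral], dim=1))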

3.4. Loss Functions

Five no-reference loss functions are used: Lspa, Lrgb, Lbri, Ltv, and Lsem. We do not use content loss or perceptual loss because paired training images are unavailable [35].

**Spatial Consistency Loss** The spatial consistency loss helps maintain spatial coherence between the low-light and enhanced images by preserving the differences between neighboring regions during enhancement. Unlike [11, 27], which only consider adjacent neighbors, we also include spatial consistency with non-adjacent neighbors (see Figure 5). The spatial consistency loss is:

[Equation: spatial consistency loss Lspa]

Here, Y and I are the average pixel values of an a × a local region in the enhanced image and the low-light image, respectively. a is the side length of the local region, which is set to 4 according to the ablation study. ψ(i) denotes the four adjacent neighbors (top, bottom, left, right), and φ(i) denotes the four non-adjacent neighbors (top-left, top-right, bottom-left, bottom-right). α is set to 0.5 because the non-adjacent neighbors are weighted as less important.
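A hedged sketch of this idea (not the authors' exact implementation): average pooling over a × a regions gives the region means Y and I, and the difference between neighboring-region contrasts of the enhanced and low-light images is penalized, with the diagonal neighbors down-weighted by α. Border handling via torch.roll is a simplification.

import torch
import torch.nn.functional as F

def spatial_consistency_loss(enhanced, low, a=4, alpha=0.5):
    """Spatial consistency with adjacent and diagonal neighbors (illustrative sketch)."""
    Y = F.avg_pool2d(enhanced.mean(1, keepdim=True), a)  # a x a region means, enhanced image
    I = F.avg_pool2d(low.mean(1, keepdim=True), a)       # a x a region means, low-light image
    adjacent = [(0, 1), (0, -1), (1, 0), (-1, 0)]        # top, bottom, left, right
    diagonal = [(1, 1), (1, -1), (-1, 1), (-1, -1)]      # the four non-adjacent neighbors
    loss = 0.0
    for shifts, w in ((adjacent, 1.0), (diagonal, alpha)):
        for dy, dx in shifts:
            Y_n = torch.roll(Y, shifts=(dy, dx), dims=(2, 3))
            I_n = torch.roll(I, shifts=(dy, dx), dims=(2, 3))
            loss = loss + w * ((torch.abs(Y - Y_n) - torch.abs(I - I_n)) ** 2).mean()
    return loss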

**RGB Loss** Color loss [45, 52, 11] reduces color inaccuracies in the enhanced image by bridging the different color channels. We adopt the Charbonnier loss, which facilitates high-quality image reconstruction [23, 19]. The RGB loss is:

[Equation: RGB loss Lrgb]

Here, ε is a penalty term, empirically set to 10⁻⁶ to improve training stability.
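A minimal sketch of a Charbonnier-style RGB term with the stated ε; penalizing pairwise differences between the mean R, G and B values is our assumption about the exact form:

import torch

def rgb_loss(enhanced, eps=1e-6):
    """Charbonnier penalty on pairwise differences of channel means (illustrative sketch)."""
    r = enhanced[:, 0].mean(dim=(1, 2))
    g = enhanced[:, 1].mean(dim=(1, 2))
    b = enhanced[:, 2].mean(dim=(1, 2))
    pairs = [(r, g), (r, b), (g, b)]
    # sqrt(d^2 + eps^2) is the Charbonnier (smooth L1) penalty
    return sum(torch.sqrt((p - q) ** 2 + eps ** 2) for p, q in pairs).mean()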

**Brightness Loss** Inspired by [34, 45, 11], we design a brightness loss to limit under- and over-exposure in the image. The loss measures the L1 difference between the average pixel value of a local region and a predefined exposure level E. The brightness loss is:

[Equation: brightness loss Lbri]

where E is the ideal image exposure level, set to 0.60 according to the ablation study.

**Total Variation Loss** The total variation loss [3] measures the difference between adjacent pixels in an image. We use it here to reduce noise and increase the smoothness of the image. Unlike previous low-light image enhancement work [48, 52, 45, 11], we additionally consider the inter-channel (R, G and B) relationship in the loss to improve color brightness. Our total variation loss is:

[Equation: total variation loss Ltv]

Here, C, H, and W denote the channel, height, and width of the image, respectively. ∇x and ∇y are the horizontal and vertical gradient operators, respectively.

**Semantic Loss** The semantic loss helps preserve the semantic information of the image during enhancement. We build our cost function on the focal loss [30]. The semantic loss does not require segmentation labels, only a pre-initialized model. The semantic loss is:

[Equation: semantic loss Lsem]

Here, p is the segmentation network's estimated class probability for a pixel. Inspired by [7], we set the focusing coefficients β and γ to 1 and 2, respectively.
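To make the brightness and total-variation terms concrete, here is a hedged sketch; the region size, the tensor the TV term is applied to, and the way inter-channel differences enter are our assumptions rather than the paper's exact formulas:

import torch
import torch.nn.functional as F

def brightness_loss(enhanced, E=0.60, region=16):
    """L1 distance between local mean intensity and the target exposure level E (sketch)."""
    mean_intensity = F.avg_pool2d(enhanced.mean(1, keepdim=True), region)
    return torch.abs(mean_intensity - E).mean()

def total_variation_loss(x_r):
    """Smoothness of the enhancement factor, with a simple inter-channel term (sketch)."""
    dx = (x_r[:, :, :, 1:] - x_r[:, :, :, :-1]).pow(2).mean()  # horizontal gradients
    dy = (x_r[:, :, 1:, :] - x_r[:, :, :-1, :]).pow(2).mean()  # vertical gradients
    dc = (x_r[:, 1:, :, :] - x_r[:, :-1, :, :]).pow(2).mean()  # R/G/B inter-channel differences
    return dx + dy + dc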

Code

Two network components are required: the recurrent enhancement network and the segmentation network (which does not need to be optimized).

[Figure: the two network components of the framework]

import torch
import torch.nn as nn
import torch.nn.functional as F

# DSC, DC and TC are the convolution block variants defined in the author's repository

class enhance_net_nopool(nn.Module):
    def __init__(self, scale_factor, conv_type='dsc'):
        super(enhance_net_nopool, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=self.scale_factor)
        number_f = 32
        # Define Conv type
        if conv_type == 'dsc':
            self.conv = DSC
        elif conv_type == 'dc':
            self.conv = DC
        elif conv_type == 'tc':
            self.conv = TC
        else:
            raise ValueError("conv type is not available")
        # zerodce DWC + p-shared
        self.e_conv1 = self.conv(3, number_f)
        self.e_conv2 = self.conv(number_f, number_f)
        self.e_conv3 = self.conv(number_f, number_f)
        self.e_conv4 = self.conv(number_f, number_f)
        self.e_conv5 = self.conv(number_f * 2, number_f)
        self.e_conv6 = self.conv(number_f * 2, number_f)
        self.e_conv7 = self.conv(number_f * 2, 3)

    def enhance(self, x, x_r):
        x = x + x_r * (torch.pow(x, 2) - x)
        x = x + x_r * (torch.pow(x, 2) - x)
        x = x + x_r * (torch.pow(x, 2) - x)
        enhance_image_1 = x + x_r * (torch.pow(x, 2) - x)
        x = enhance_image_1 + x_r * (torch.pow(enhance_image_1, 2) - enhance_image_1)
        x = x + x_r * (torch.pow(x, 2) - x)
        x = x + x_r * (torch.pow(x, 2) - x)
        enhance_image = x + x_r * (torch.pow(x, 2) - x)
        return enhance_image
    def forward(self, x):
        if self.scale_factor == 1:
            x_down = x
        else:
            x_down = F.interpolate(x, scale_factor=1 / self.scale_factor, mode='bilinear')
        #extraction
        x1 = self.relu(self.e_conv1(x_down))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))

        # The skip connections above yield the enhancement factor x_r
        if self.scale_factor == 1:
            x_r = x_r
        else:
            # upsample x_r so its size matches the input image
            x_r = self.upsample(x_r)

        # enhancement
        # Unlike Zero-DCE, which predicts a different curve map for every recursion,
        # x_r is computed once here and reused in every recursion inside self.enhance
        enhance_image = self.enhance(x, x_r)
        return enhance_image, x_r
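A quick usage sketch for the network above (the input size, the scale factor, and the availability of the repo's DSC block are assumptions):

net = enhance_net_nopool(scale_factor=1, conv_type='dsc')
low = torch.rand(1, 3, 256, 256)   # dummy low-light image in [0, 1]
enhanced, x_r = net(low)           # enhanced image and the enhancement factor
print(enhanced.shape, x_r.shape)   # both torch.Size([1, 3, 256, 256])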

For comparison, the structure of the Zero-DCE network (a MindSpore implementation):

"""
Model File
"""

from mindspore import nn
from mindspore import ops
from mindspore.common.initializer import Normal

class ZeroDCE(nn.Cell):
    """
    Main Zero DCE Model
    """
    def __init__(self, *, sigma=0.02, mean=0.0):
        super().__init__()

        self.relu = nn.ReLU()

        number_f = 32
        self.e_conv1 = nn.Conv2d(3, number_f, 3, 1, pad_mode='pad', padding=1, has_bias=True,
                                 weight_init=Normal(sigma, mean))
        self.e_conv2 = nn.Conv2d(number_f, number_f, 3, 1, pad_mode='pad', padding=1, has_bias=True,
                                 weight_init=Normal(sigma, mean))
        self.e_conv3 = nn.Conv2d(number_f, number_f, 3, 1, pad_mode='pad', padding=1, has_bias=True,
                                 weight_init=Normal(sigma, mean))
        self.e_conv4 = nn.Conv2d(number_f, number_f, 3, 1, pad_mode='pad', padding=1, has_bias=True,
                                 weight_init=Normal(sigma, mean))
        self.e_conv5 = nn.Conv2d(number_f * 2, number_f, 3, 1, pad_mode='pad', padding=1,
                                 has_bias=True, weight_init=Normal(sigma, mean))
        self.e_conv6 = nn.Conv2d(number_f * 2, number_f, 3, 1, pad_mode='pad', padding=1,
                                 has_bias=True, weight_init=Normal(sigma, mean))
        self.e_conv7 = nn.Conv2d(number_f * 2, 24, 3, 1, pad_mode='pad', padding=1, has_bias=True,
                                 weight_init=Normal(sigma, mean))

        self.split = ops.Split(axis=1, output_num=8)
        self.cat = ops.Concat(axis=1)

    def construct(self, x):
        """
        ZeroDCE inference
        """
        x1 = self.relu(self.e_conv1(x))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))

        x5 = self.relu(self.e_conv5(self.cat([x3, x4])))
        x6 = self.relu(self.e_conv6(self.cat([x2, x5])))

        x_r = ops.tanh(self.e_conv7(self.cat([x1, x6])))
        r1, r2, r3, r4, r5, r6, r7, r8 = self.split(x_r)

        x = x + r1 * (ops.pow(x, 2) - x)
        x = x + r2 * (ops.pow(x, 2) - x)
        x = x + r3 * (ops.pow(x, 2) - x)
        enhance_image_1 = x + r4 * (ops.pow(x, 2) - x)
        x = enhance_image_1 + r5 * (ops.pow(enhance_image_1, 2) - enhance_image_1)
        x = x + r6 * (ops.pow(x, 2) - x)
        x = x + r7 * (ops.pow(x, 2) - x)
        enhance_image = x + r8 * (ops.pow(x, 2) - x)
        r = self.cat([r1, r2, r3, r4, r5, r6, r7, r8])
        return enhance_image, r

Semantic Segmentation Network

import torch.nn as nn
# resnet50 and fpn_module are defined in the author's repository

# The default number of classes (numClass) is 21
class fpn(nn.Module):
    def __init__(self, numClass):
        super(fpn, self).__init__()
        # Res net
        self.resnet = resnet50(True)

        # fpn module
        self.fpn = fpn_module(numClass)  # see the author's public code for the detailed implementation

        # init fpn
        for m in self.fpn.children():
            nn.init.normal_(m.weight, mean=0, std=0.01)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Bottom-up pathway: multi-scale features from the pretrained ResNet-50
        c2, c3, c4, c5 = self.resnet.forward(x)
        
        return self.fpn.forward(c2, c3, c4, c5)
    # The output is a per-pixel class-score map with numClass channels
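The paper freezes all layers of the segmentation network during training; a minimal sketch of how that could look (the variable names are ours):

seg_net = fpn(numClass=21)
for p in seg_net.parameters():
    p.requires_grad = False  # the segmentation weights are never updated
seg_net.eval()               # only the enhancement network's parameters are optimized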

Network training:

    # methods of the trainer class in the author's repository
    def train(self):
        self.net.train()

        for epoch in range(self.num_epochs):

            for iteration, img_lowlight in enumerate(self.train_loader):

                img_lowlight = img_lowlight.to(device)
                enhanced_image, A = self.net(img_lowlight)  # the enhancement network being optimized; A is the enhancement factor
                loss = self.get_loss(A, enhanced_image, img_lowlight, self.E)  # total no-reference loss to minimize

                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), self.grad_clip_norm)
                self.optimizer.step()

                if ((iteration + 1) % self.display_iter) == 0:
                    print("Loss at iteration", iteration + 1, ":", loss.item())
                if ((iteration + 1) % self.snapshot_iter) == 0:
                    torch.save(self.net.state_dict(), self.snapshots_folder + "Epoch" + str(epoch) + '.pth')
    def get_loss(self, A, enhanced_image, img_lowlight, E):
        Loss_TV = 1600 * self.L_TV(A)
        loss_spa = torch.mean(self.L_spa(enhanced_image, img_lowlight))
        loss_col = 5 * torch.mean(self.L_color(enhanced_image))
        loss_exp = 10 * torch.mean(self.L_exp(enhanced_image, E))
        loss_seg = self.get_seg_loss(enhanced_image)

        loss = Loss_TV + loss_spa + loss_col + loss_exp + 0.1 * loss_seg

        return loss

The semantic loss is the main new loss proposed in this article: the frozen segmentation network's own argmax prediction serves as the pseudo target for a focal loss.

def get_NoGT_target(inputs):
    # log-probabilities over the class dimension
    sfmx_inputs = F.log_softmax(inputs, dim=1)
    # the per-pixel argmax over the classes serves as the pseudo ground-truth label map
    target = torch.argmax(sfmx_inputs, dim=1)
    return target

    # method of the trainer class (same class as train/get_loss above)
    def get_seg_loss(self, enhanced_image):
        # segment the enhanced image
        seg_input = enhanced_image.to(device)
        seg_output = self.seg(seg_input).to(device)

        # build segment output
        target = (get_NoGT_target(seg_output)).data.to(device)

        # calculate seg. loss
        seg_loss = self.seg_criterion(seg_output, target)

        return seg_loss
# the segmentation criterion used by get_seg_loss above:
self.seg_criterion = FocalLoss(gamma=2).to(device)
class FocalLoss(nn.Module):

    # def __init__(self, device, gamma=0, eps=1e-7, size_average=True):
    def __init__(self, gamma=0, eps=1e-7, size_average=True, reduce=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.size_average = size_average
        self.reduce = reduce
        # self.device = device

    def forward(self, input, target):
        # y = one_hot(target, input.size(1), self.device)
        y = one_hot(target, input.size(1))
        probs = F.softmax(input, dim=1)
        probs = (probs * y).sum(1)                    # keep the probability of the target class at each pixel
        probs = probs.clamp(self.eps, 1. - self.eps)  # clamp for numerical stability

        log_p = probs.log()
        # print('probs size= {}'.format(probs.size()))
        # print(probs)

        batch_loss = -(torch.pow((1 - probs), self.gamma)) * log_p
        # print('-----bacth_loss------')
        # print(batch_loss)

        if self.reduce:
            if self.size_average:
                loss = batch_loss.mean()
            else:
                loss = batch_loss.sum()
        else:
            loss = batch_loss
        return loss
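FocalLoss relies on a one_hot helper that is not shown in this excerpt; a minimal sketch of what it could look like (ours, not necessarily the repo's version):

import torch
import torch.nn.functional as F

def one_hot(target, num_classes):
    """Turn integer labels of shape (N, H, W) into one-hot maps of shape (N, C, H, W)."""
    # F.one_hot places the class dimension last, so move it next to the batch dimension
    return F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()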

Published experimental results:

[Table/Figure: quantitative results reported in the paper]

