In [1]:
import os
import numpy as np
from PIL import Image, ImageOps, ImageEnhance
import random
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torch.optim as optim
from torchsummary import summary
import matplotlib.pyplot as plt
from piq import MultiScaleSSIMLoss
from skimage.metrics import peak_signal_noise_ratio

# Setting the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
Using device: cuda
In [2]:
class HDRDataset(Dataset):
    def __init__(
        self,
        underexposed_dir: str,
        overexposed_dir: str,
        ground_truth_dir: str,
        indices: list = None,
        augment: bool = True
    ):
        # File lists
        self.underexposed_list  = sorted(os.listdir(underexposed_dir))
        self.overexposed_list   = sorted(os.listdir(overexposed_dir))
        self.ground_truth_list  = sorted(os.listdir(ground_truth_dir))
        
        if indices is not None:
            self.underexposed_list = [self.underexposed_list[i] for i in indices]
            self.overexposed_list  = [self.overexposed_list[i]  for i in indices]
            self.ground_truth_list = [self.ground_truth_list[i] for i in indices]

        self.underexposed_dir  = underexposed_dir
        self.overexposed_dir   = overexposed_dir
        self.ground_truth_dir  = ground_truth_dir
        self.augment = augment        # whether to apply the random augmentations in __getitem__

        # Helper for converting sRGB values to linear RGB
        def srgb_to_linear(x: torch.Tensor) -> torch.Tensor:
            return torch.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)

        # To-tensor, linearization, then normalization (ImageNet stats)
        self.normalize = transforms.Compose([
            transforms.ToTensor(),                                 # [0,255] → [0,1] sRGB
            transforms.Lambda(lambda x: srgb_to_linear(x)),        # sRGB → linear RGB
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std =[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.underexposed_list)

    def __getitem__(self, idx):

        underexposed_img = Image.open(os.path.join(self.underexposed_dir, self.underexposed_list[idx])).convert("RGB")
        overexposed_img = Image.open(os.path.join(self.overexposed_dir, self.overexposed_list[idx])).convert("RGB")
        ground_truth_img = Image.open(os.path.join(self.ground_truth_dir, self.ground_truth_list[idx])).convert("RGB")

        # Applying augmentations (train split only; the test split passes augment=False)
        if self.augment:
            # 1) Random horizontal flip (applied jointly so the triplet stays aligned)
            if random.random() < 0.5:
                underexposed_img, overexposed_img, ground_truth_img = \
                    TF.hflip(underexposed_img), TF.hflip(overexposed_img), TF.hflip(ground_truth_img)

            # 2) Random rotation (same angle for all three images)
            angle = random.uniform(-15, 15)
            underexposed_img, overexposed_img, ground_truth_img = \
                TF.rotate(underexposed_img, angle), TF.rotate(overexposed_img, angle), TF.rotate(ground_truth_img, angle)

            # 3) Random affine (translation, scaling and shearing), shared across the triplet
            translate = (
                int(random.uniform(-0.05, 0.05) * underexposed_img.size[0]),
                int(random.uniform(-0.05, 0.05) * underexposed_img.size[1]))
            scale = random.uniform(0.95, 1.05)
            shear = random.uniform(-5, 5)

            underexposed_img = TF.affine(underexposed_img, angle=0, translate=translate, scale=scale, shear=shear)
            overexposed_img = TF.affine(overexposed_img, angle=0, translate=translate, scale=scale, shear=shear)
            ground_truth_img = TF.affine(ground_truth_img, angle=0, translate=translate, scale=scale, shear=shear)

        # 4) Converting To Tensor + Normalization
        underexposed_img = self.normalize(underexposed_img)
        overexposed_img = self.normalize(overexposed_img)
        ground_truth_img = self.normalize(ground_truth_img)

        return underexposed_img, overexposed_img, ground_truth_img
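Since the sRGB-to-linear helper above is inverted later for display (In [13] defines linear_to_srgb), a quick round-trip check can catch transfer-function mistakes early. A minimal standalone sketch (not part of the original pipeline; both helpers are restated here so the check runs on its own):

import torch

def srgb_to_linear(x: torch.Tensor) -> torch.Tensor:
    # sRGB electro-optical transfer function (piecewise: linear toe + 2.4 gamma)
    return torch.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)

def linear_to_srgb(x: torch.Tensor) -> torch.Tensor:
    # Inverse transfer function, used later for display
    return torch.where(x <= 0.0031308, 12.92 * x, 1.055 * (x ** (1 / 2.4)) - 0.055)

x = torch.rand(3, 8, 8)                      # values in [0,1], as produced by ToTensor()
assert torch.allclose(linear_to_srgb(srgb_to_linear(x)), x, atol=1e-5)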
In [3]:
# Folder paths
ground_truth_folder = "Dataset_HDR/Size 384 Images/Ground Truth Images_384"
underexposed_folder = "Dataset_HDR/Size 384 Images/UnderExposed Images_384"
overexposed_folder  = "Dataset_HDR/Size 384 Images/OverExposed Images_384"
In [4]:
# Building the index list and the train/test split
# (no random_state is fixed, so the split differs from run to run)
all_indices = list(range(len(os.listdir(underexposed_folder))))
train_indices, test_indices = train_test_split(all_indices, test_size=0.1)
In [5]:
# Creating the train (augmented) and test (non-augmented) datasets
train_dataset = HDRDataset(underexposed_dir = underexposed_folder, overexposed_dir = overexposed_folder,
                           ground_truth_dir = ground_truth_folder, indices = train_indices, augment = True)

test_dataset = HDRDataset(underexposed_dir = underexposed_folder, overexposed_dir = overexposed_folder,
                          ground_truth_dir = ground_truth_folder, indices = test_indices, augment = False)

BATCH_SIZE  = 8
NUM_WORKERS = 4
PIN_MEMORY  = True

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, 
                          num_workers = NUM_WORKERS, pin_memory = PIN_MEMORY)

test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False,   # evaluation order does not matter
                         num_workers = NUM_WORKERS, pin_memory = PIN_MEMORY)
In [6]:
# Total number of images in each folder
total_underexposed = len(os.listdir(underexposed_folder))
total_overexposed  = len(os.listdir(overexposed_folder))

print(f"Total number of UnderExposed images: {total_underexposed}")
print(f"Total number of OverExposed images : {total_overexposed}\n")

print(f"Number of images in Train split of UnderExposed Folder: {len(train_dataset.underexposed_list)}")
print(f"Number of images in Test  split of UnderExposed Folder: {len(test_dataset.underexposed_list)}\n")

print(f"Number of images in Train split of OverExposed Folder: {len(train_dataset.overexposed_list)}")
print(f"Number of images in Test  split of OverExposed Folder: {len(test_dataset.overexposed_list)}")
Total number of UnderExposed images: 400
Total number of OverExposed images : 400

Number of images in Train split of UnderExposed Folder: 360
Number of images in Test  split of UnderExposed Folder: 40

Number of images in Train split of OverExposed Folder: 360
Number of images in Test  split of OverExposed Folder: 40
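HDRDataset pairs files purely by sorted order, so the three folders must stay index-aligned. A small sanity check (sketch; the stem comparison is hypothetical and assumes filenames share a common prefix):

u_files = sorted(os.listdir(underexposed_folder))
o_files = sorted(os.listdir(overexposed_folder))
g_files = sorted(os.listdir(ground_truth_folder))

# The counts must agree for index-based pairing to be meaningful
assert len(u_files) == len(o_files) == len(g_files), "folder sizes differ"

# Optional stricter check if filenames share a stem (hypothetical naming scheme):
# for u, o, g in zip(u_files, o_files, g_files):
#     assert u.split('_')[0] == o.split('_')[0] == g.split('_')[0]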
In [7]:
class HDRResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, dropout_prob=0.3):
        super(HDRResidualBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.Leakyrelu = nn.LeakyReLU(0.1, inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout2d(p=dropout_prob)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=True)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.dropout2 = nn.Dropout2d(p=dropout_prob)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=True),
                nn.BatchNorm2d(out_channels)
            )
 
    def forward(self, x):
        identity = self.shortcut(x)
        identity = self.pool(identity)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.Leakyrelu(out)
        out = self.pool(out)
        out = self.dropout1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        
        out = out + identity
        out = self.Leakyrelu(out)
        out = self.dropout2(out)

        return out
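Both the main path and the 1×1 shortcut of HDRResidualBlock pass through a stride-2 max pool, so the block halves the spatial resolution while changing the channel count. A quick shape check (sketch):

block = HDRResidualBlock(32, 64, dropout_prob=0.3)
x = torch.randn(2, 32, 192, 192)
print(block(x).shape)   # expected: torch.Size([2, 64, 96, 96]) (channels doubled, H and W halved)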
In [8]:
class HDRBranchUNet(nn.Module):
    def __init__(self, dropout_prob=0.3):
        super().__init__()
        
        # Encoder
        self.initial_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, padding=2, bias=True),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),
            nn.MaxPool2d(2, 2),              # → (B,32,192,192)
            nn.Dropout2d(dropout_prob)
        )
        self.downlayer1 = HDRResidualBlock( 32,  64, dropout_prob=dropout_prob)  # → (B,64,96,96)
        self.downlayer2 = HDRResidualBlock( 64, 128, dropout_prob=dropout_prob)  # → (B,128,48,48)
        self.downlayer3 = HDRResidualBlock(128, 256, dropout_prob=dropout_prob)  # → (B,256,24,24)
        self.downlayer4 = HDRResidualBlock(256, 512, dropout_prob=dropout_prob)  # → (B,512,12,12)
        self.downlayer5 = HDRResidualBlock(512,1024,dropout_prob=dropout_prob)   # → (B,1024,6,6)

        # Decoder (upsample)
        self.uplayer5 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)             # → (B,512,12,12)
        self.skip_conn5 = nn.Sequential(nn.Conv2d(512+512, 512, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(512), nn.LeakyReLU(0.1, inplace=True))

        self.uplayer4 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)              # → (B,256,24,24)
        self.skip_conn4 = nn.Sequential(nn.Conv2d(256+256, 256, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(256), nn.LeakyReLU(0.1, inplace=True))

        self.uplayer3 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)              # → (B,128,48,48)
        self.skip_conn3 = nn.Sequential(nn.Conv2d(128+128, 128, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(128), nn.LeakyReLU(0.1, inplace=True))

        self.uplayer2 = nn.ConvTranspose2d(128,  64, kernel_size=4, stride=2, padding=1)              # → (B,64,96,96)
        self.skip_conn2 = nn.Sequential(nn.Conv2d(64+64, 64, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(64), nn.LeakyReLU(0.1,inplace=True))

        self.uplayer1 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)               # → (B,32,192,192)
        self.skip_conn1 = nn.Sequential(nn.Conv2d(32+32, 32, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(32), nn.LeakyReLU(0.1,inplace=True))

        self.uplayer0 = nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1)              # → (B,32,384,384)
        self.final_conv = nn.Conv2d(32, 3, kernel_size=1)                                           # → (B,3,384,384)

    def forward(self, x):
        # --- Encoder ---
        e0 = self.initial_conv(x)       # 32×192
        e1 = self.downlayer1(e0)        # 64×96
        e2 = self.downlayer2(e1)        # 128×48
        e3 = self.downlayer3(e2)        # 256×24
        e4 = self.downlayer4(e3)        # 512×12
        e5 = self.downlayer5(e4)        # 1024×6

        # --- Decoder with skip connections ---
        d5 = self.uplayer5(e5)                                   # 512×12
        d5 = self.skip_conn5(torch.cat([d5, e4], dim=1))         # fuse with e4

        d4 = self.uplayer4(d5)                                   # 256×24
        d4 = self.skip_conn4(torch.cat([d4, e3], dim=1))         # fuse with e3

        d3 = self.uplayer3(d4)                                   # 128×48
        d3 = self.skip_conn3(torch.cat([d3, e2], dim=1))         # fuse with e2

        d2 = self.uplayer2(d3)                                   # 64×96
        d2 = self.skip_conn2(torch.cat([d2, e1], dim=1))         # fuse with e1

        d1 = self.uplayer1(d2)                                   # 32×192
        d1 = self.skip_conn1(torch.cat([d1, e0], dim=1))         # fuse with e0

        d0 = self.uplayer0(d1)                                   # 32×384
        
        return self.final_conv(d0)
In [9]:
class HDRConcat(nn.Module):
    def __init__(self, dropout_prob=0.3):
        super().__init__()
        
        # Two parallel U-Net-style branches (each outputs a 3-channel 384×384 image)
        self.branch_underexposed = HDRBranchUNet(dropout_prob=dropout_prob)
        self.branch_overexposed  = HDRBranchUNet(dropout_prob=dropout_prob)

        # Merging and refining the concatenated HDR maps
        self.concat_conv = nn.Sequential(
            nn.Conv2d(6, 32, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(16,  3, kernel_size=1))

    def forward(self, underexposed_img, overexposed_img):
        
        # Each branch runs its full U-Net (with internal skip connections)
        hdr_underexposed = self.branch_underexposed(underexposed_img)     # → (B,3,384,384)
        hdr_overexposed = self.branch_overexposed(overexposed_img)        # → (B,3,384,384)

        # Concatenating and refining
        x = torch.cat([hdr_underexposed, hdr_overexposed], dim=1)         # → (B,6,384,384)
        
        return self.concat_conv(x)                                        # → (B,3,384,384)
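As a cheap complement to the torchsummary call below, a smoke test (sketch) confirms that the two-branch model preserves the 384×384 resolution end to end:

with torch.no_grad():
    m = HDRConcat(dropout_prob=0.3).eval()
    dummy_under = torch.randn(1, 3, 384, 384)
    dummy_over  = torch.randn(1, 3, 384, 384)
    print(m(dummy_under, dummy_over).shape)   # expected: torch.Size([1, 3, 384, 384])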
In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = HDRConcat(dropout_prob=0.3).to(device)

# Printing the summary of the model
print("The Model Summary is:")
summary(model, [(3,  384, 384), (3,  384, 384)])
The Model Summary is:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 32, 384, 384]           2,432
       BatchNorm2d-2         [-1, 32, 384, 384]              64
         LeakyReLU-3         [-1, 32, 384, 384]               0
         MaxPool2d-4         [-1, 32, 192, 192]               0
         Dropout2d-5         [-1, 32, 192, 192]               0
            Conv2d-6         [-1, 64, 192, 192]           2,112
       BatchNorm2d-7         [-1, 64, 192, 192]             128
         MaxPool2d-8           [-1, 64, 96, 96]               0
            Conv2d-9         [-1, 64, 192, 192]          51,264
      BatchNorm2d-10         [-1, 64, 192, 192]             128
        LeakyReLU-11         [-1, 64, 192, 192]               0
        MaxPool2d-12           [-1, 64, 96, 96]               0
        Dropout2d-13           [-1, 64, 96, 96]               0
           Conv2d-14           [-1, 64, 96, 96]         102,464
      BatchNorm2d-15           [-1, 64, 96, 96]             128
        LeakyReLU-16           [-1, 64, 96, 96]               0
        Dropout2d-17           [-1, 64, 96, 96]               0
 HDRResidualBlock-18           [-1, 64, 96, 96]               0
           Conv2d-19          [-1, 128, 96, 96]           8,320
      BatchNorm2d-20          [-1, 128, 96, 96]             256
        MaxPool2d-21          [-1, 128, 48, 48]               0
           Conv2d-22          [-1, 128, 96, 96]         204,928
      BatchNorm2d-23          [-1, 128, 96, 96]             256
        LeakyReLU-24          [-1, 128, 96, 96]               0
        MaxPool2d-25          [-1, 128, 48, 48]               0
        Dropout2d-26          [-1, 128, 48, 48]               0
           Conv2d-27          [-1, 128, 48, 48]         409,728
      BatchNorm2d-28          [-1, 128, 48, 48]             256
        LeakyReLU-29          [-1, 128, 48, 48]               0
        Dropout2d-30          [-1, 128, 48, 48]               0
 HDRResidualBlock-31          [-1, 128, 48, 48]               0
           Conv2d-32          [-1, 256, 48, 48]          33,024
      BatchNorm2d-33          [-1, 256, 48, 48]             512
        MaxPool2d-34          [-1, 256, 24, 24]               0
           Conv2d-35          [-1, 256, 48, 48]         819,456
      BatchNorm2d-36          [-1, 256, 48, 48]             512
        LeakyReLU-37          [-1, 256, 48, 48]               0
        MaxPool2d-38          [-1, 256, 24, 24]               0
        Dropout2d-39          [-1, 256, 24, 24]               0
           Conv2d-40          [-1, 256, 24, 24]       1,638,656
      BatchNorm2d-41          [-1, 256, 24, 24]             512
        LeakyReLU-42          [-1, 256, 24, 24]               0
        Dropout2d-43          [-1, 256, 24, 24]               0
 HDRResidualBlock-44          [-1, 256, 24, 24]               0
           Conv2d-45          [-1, 512, 24, 24]         131,584
      BatchNorm2d-46          [-1, 512, 24, 24]           1,024
        MaxPool2d-47          [-1, 512, 12, 12]               0
           Conv2d-48          [-1, 512, 24, 24]       3,277,312
      BatchNorm2d-49          [-1, 512, 24, 24]           1,024
        LeakyReLU-50          [-1, 512, 24, 24]               0
        MaxPool2d-51          [-1, 512, 12, 12]               0
        Dropout2d-52          [-1, 512, 12, 12]               0
           Conv2d-53          [-1, 512, 12, 12]       6,554,112
      BatchNorm2d-54          [-1, 512, 12, 12]           1,024
        LeakyReLU-55          [-1, 512, 12, 12]               0
        Dropout2d-56          [-1, 512, 12, 12]               0
 HDRResidualBlock-57          [-1, 512, 12, 12]               0
           Conv2d-58         [-1, 1024, 12, 12]         525,312
      BatchNorm2d-59         [-1, 1024, 12, 12]           2,048
        MaxPool2d-60           [-1, 1024, 6, 6]               0
           Conv2d-61         [-1, 1024, 12, 12]      13,108,224
      BatchNorm2d-62         [-1, 1024, 12, 12]           2,048
        LeakyReLU-63         [-1, 1024, 12, 12]               0
        MaxPool2d-64           [-1, 1024, 6, 6]               0
        Dropout2d-65           [-1, 1024, 6, 6]               0
           Conv2d-66           [-1, 1024, 6, 6]      26,215,424
      BatchNorm2d-67           [-1, 1024, 6, 6]           2,048
        LeakyReLU-68           [-1, 1024, 6, 6]               0
        Dropout2d-69           [-1, 1024, 6, 6]               0
 HDRResidualBlock-70           [-1, 1024, 6, 6]               0
  ConvTranspose2d-71          [-1, 512, 12, 12]       8,389,120
           Conv2d-72          [-1, 512, 12, 12]       4,719,104
      BatchNorm2d-73          [-1, 512, 12, 12]           1,024
        LeakyReLU-74          [-1, 512, 12, 12]               0
  ConvTranspose2d-75          [-1, 256, 24, 24]       2,097,408
           Conv2d-76          [-1, 256, 24, 24]       1,179,904
      BatchNorm2d-77          [-1, 256, 24, 24]             512
        LeakyReLU-78          [-1, 256, 24, 24]               0
  ConvTranspose2d-79          [-1, 128, 48, 48]         524,416
           Conv2d-80          [-1, 128, 48, 48]         295,040
      BatchNorm2d-81          [-1, 128, 48, 48]             256
        LeakyReLU-82          [-1, 128, 48, 48]               0
  ConvTranspose2d-83           [-1, 64, 96, 96]         131,136
           Conv2d-84           [-1, 64, 96, 96]          73,792
      BatchNorm2d-85           [-1, 64, 96, 96]             128
        LeakyReLU-86           [-1, 64, 96, 96]               0
  ConvTranspose2d-87         [-1, 32, 192, 192]          32,800
           Conv2d-88         [-1, 32, 192, 192]          18,464
      BatchNorm2d-89         [-1, 32, 192, 192]              64
        LeakyReLU-90         [-1, 32, 192, 192]               0
  ConvTranspose2d-91         [-1, 32, 384, 384]          16,416
           Conv2d-92          [-1, 3, 384, 384]              99
    HDRBranchUNet-93          [-1, 3, 384, 384]               0
           Conv2d-94         [-1, 32, 384, 384]           2,432
      BatchNorm2d-95         [-1, 32, 384, 384]              64
        LeakyReLU-96         [-1, 32, 384, 384]               0
        MaxPool2d-97         [-1, 32, 192, 192]               0
        Dropout2d-98         [-1, 32, 192, 192]               0
           Conv2d-99         [-1, 64, 192, 192]           2,112
     BatchNorm2d-100         [-1, 64, 192, 192]             128
       MaxPool2d-101           [-1, 64, 96, 96]               0
          Conv2d-102         [-1, 64, 192, 192]          51,264
     BatchNorm2d-103         [-1, 64, 192, 192]             128
       LeakyReLU-104         [-1, 64, 192, 192]               0
       MaxPool2d-105           [-1, 64, 96, 96]               0
       Dropout2d-106           [-1, 64, 96, 96]               0
          Conv2d-107           [-1, 64, 96, 96]         102,464
     BatchNorm2d-108           [-1, 64, 96, 96]             128
       LeakyReLU-109           [-1, 64, 96, 96]               0
       Dropout2d-110           [-1, 64, 96, 96]               0
HDRResidualBlock-111           [-1, 64, 96, 96]               0
          Conv2d-112          [-1, 128, 96, 96]           8,320
     BatchNorm2d-113          [-1, 128, 96, 96]             256
       MaxPool2d-114          [-1, 128, 48, 48]               0
          Conv2d-115          [-1, 128, 96, 96]         204,928
     BatchNorm2d-116          [-1, 128, 96, 96]             256
       LeakyReLU-117          [-1, 128, 96, 96]               0
       MaxPool2d-118          [-1, 128, 48, 48]               0
       Dropout2d-119          [-1, 128, 48, 48]               0
          Conv2d-120          [-1, 128, 48, 48]         409,728
     BatchNorm2d-121          [-1, 128, 48, 48]             256
       LeakyReLU-122          [-1, 128, 48, 48]               0
       Dropout2d-123          [-1, 128, 48, 48]               0
HDRResidualBlock-124          [-1, 128, 48, 48]               0
          Conv2d-125          [-1, 256, 48, 48]          33,024
     BatchNorm2d-126          [-1, 256, 48, 48]             512
       MaxPool2d-127          [-1, 256, 24, 24]               0
          Conv2d-128          [-1, 256, 48, 48]         819,456
     BatchNorm2d-129          [-1, 256, 48, 48]             512
       LeakyReLU-130          [-1, 256, 48, 48]               0
       MaxPool2d-131          [-1, 256, 24, 24]               0
       Dropout2d-132          [-1, 256, 24, 24]               0
          Conv2d-133          [-1, 256, 24, 24]       1,638,656
     BatchNorm2d-134          [-1, 256, 24, 24]             512
       LeakyReLU-135          [-1, 256, 24, 24]               0
       Dropout2d-136          [-1, 256, 24, 24]               0
HDRResidualBlock-137          [-1, 256, 24, 24]               0
          Conv2d-138          [-1, 512, 24, 24]         131,584
     BatchNorm2d-139          [-1, 512, 24, 24]           1,024
       MaxPool2d-140          [-1, 512, 12, 12]               0
          Conv2d-141          [-1, 512, 24, 24]       3,277,312
     BatchNorm2d-142          [-1, 512, 24, 24]           1,024
       LeakyReLU-143          [-1, 512, 24, 24]               0
       MaxPool2d-144          [-1, 512, 12, 12]               0
       Dropout2d-145          [-1, 512, 12, 12]               0
          Conv2d-146          [-1, 512, 12, 12]       6,554,112
     BatchNorm2d-147          [-1, 512, 12, 12]           1,024
       LeakyReLU-148          [-1, 512, 12, 12]               0
       Dropout2d-149          [-1, 512, 12, 12]               0
HDRResidualBlock-150          [-1, 512, 12, 12]               0
          Conv2d-151         [-1, 1024, 12, 12]         525,312
     BatchNorm2d-152         [-1, 1024, 12, 12]           2,048
       MaxPool2d-153           [-1, 1024, 6, 6]               0
          Conv2d-154         [-1, 1024, 12, 12]      13,108,224
     BatchNorm2d-155         [-1, 1024, 12, 12]           2,048
       LeakyReLU-156         [-1, 1024, 12, 12]               0
       MaxPool2d-157           [-1, 1024, 6, 6]               0
       Dropout2d-158           [-1, 1024, 6, 6]               0
          Conv2d-159           [-1, 1024, 6, 6]      26,215,424
     BatchNorm2d-160           [-1, 1024, 6, 6]           2,048
       LeakyReLU-161           [-1, 1024, 6, 6]               0
       Dropout2d-162           [-1, 1024, 6, 6]               0
HDRResidualBlock-163           [-1, 1024, 6, 6]               0
 ConvTranspose2d-164          [-1, 512, 12, 12]       8,389,120
          Conv2d-165          [-1, 512, 12, 12]       4,719,104
     BatchNorm2d-166          [-1, 512, 12, 12]           1,024
       LeakyReLU-167          [-1, 512, 12, 12]               0
 ConvTranspose2d-168          [-1, 256, 24, 24]       2,097,408
          Conv2d-169          [-1, 256, 24, 24]       1,179,904
     BatchNorm2d-170          [-1, 256, 24, 24]             512
       LeakyReLU-171          [-1, 256, 24, 24]               0
 ConvTranspose2d-172          [-1, 128, 48, 48]         524,416
          Conv2d-173          [-1, 128, 48, 48]         295,040
     BatchNorm2d-174          [-1, 128, 48, 48]             256
       LeakyReLU-175          [-1, 128, 48, 48]               0
 ConvTranspose2d-176           [-1, 64, 96, 96]         131,136
          Conv2d-177           [-1, 64, 96, 96]          73,792
     BatchNorm2d-178           [-1, 64, 96, 96]             128
       LeakyReLU-179           [-1, 64, 96, 96]               0
 ConvTranspose2d-180         [-1, 32, 192, 192]          32,800
          Conv2d-181         [-1, 32, 192, 192]          18,464
     BatchNorm2d-182         [-1, 32, 192, 192]              64
       LeakyReLU-183         [-1, 32, 192, 192]               0
 ConvTranspose2d-184         [-1, 32, 384, 384]          16,416
          Conv2d-185          [-1, 3, 384, 384]              99
   HDRBranchUNet-186          [-1, 3, 384, 384]               0
          Conv2d-187         [-1, 32, 384, 384]           1,760
     BatchNorm2d-188         [-1, 32, 384, 384]              64
       LeakyReLU-189         [-1, 32, 384, 384]               0
          Conv2d-190         [-1, 16, 384, 384]           4,624
     BatchNorm2d-191         [-1, 16, 384, 384]              32
       LeakyReLU-192         [-1, 16, 384, 384]               0
          Conv2d-193          [-1, 3, 384, 384]              51
================================================================
Total params: 141,158,537
Trainable params: 141,158,537
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 746496.00
Forward/backward pass size (MB): 1130.62
Params size (MB): 538.48
Estimated Total Size (MB): 748165.10
----------------------------------------------------------------
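torchsummary's total can be cross-checked with a direct sum over model.parameters(); the two figures should agree since no weights are shared between the branches. A one-line sketch:

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params:,} parameters")   # expected to roughly match the ~141.2M reported above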
In [11]:
# Hyperparameters & Setup
BATCH_SIZE = 8
LR         = 1e-4
MAX_EPOCHS = 50
DEVICE     = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Denormalization
mean = torch.tensor([0.485,0.456,0.406], device=DEVICE)[None,:,None,None]
std  = torch.tensor([0.229,0.224,0.225], device=DEVICE)[None,:,None,None]

def denormalize(x):
    return x * std + mean

# Loss Functions
class HDRLoss(nn.Module):
    def __init__(self,
                 alpha,                                    # weight: L1 loss on RGB (RGB-L1)
                 beta,                                     # weight: L1 loss on luminance (Lum-L1)
                 gamma,                                    # weight: Multi-Scale Structural Similarity loss (MS-SSIM)
                 delta,                                    # weight: color-preservation loss (chroma, i.e. RGB minus luminance)
                 eps_hist, hist_bins, hist_sigma):         # weight and parameters of the soft histogram-matching loss
                 
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta
        self.eps_hist = eps_hist
        self.hist_bins = hist_bins
        self.hist_sigma = hist_sigma

        self.ms_ssim_loss = MultiScaleSSIMLoss(data_range=1.0, reduction='mean')

    def get_luminance(self, img):
        R, G, B = img[:,0:1], img[:,1:2], img[:,2:3]
        return 0.2126*R + 0.7152*G + 0.0722*B

    def color_preservation_loss(self, pred, target):
        lum_p = self.get_luminance(pred)
        lum_t = self.get_luminance(target)
        return F.l1_loss(pred - lum_p, target - lum_t)

    def _soft_histogram(self, x):

        # x: B×3×H×W in [0,1]
        bins = self.hist_bins
        sigma = self.hist_sigma
        centers = torch.linspace(0.0, 1.0, bins, device=x.device)  # [bins]
        x_expanded = x.unsqueeze(-1)                               # B×3×H×W×1
        c = centers.view(1,1,1,1,bins)                             # 1×1×1×1×bins
        diff = x_expanded - c                                      # B×3×H×W×bins
        weights = torch.exp(-0.5 * (diff / sigma)**2)              # Gaussian
        hist = weights.sum(dim=[2,3])                              # B×3×bins
        hist = hist / (hist.sum(dim=-1, keepdim=True) + 1e-6)      # normalize
        return hist                                                # B×3×bins

    def histogram_loss(self, pred, target):
        Hp = self._soft_histogram(pred)
        Ht = self._soft_histogram(target)
        return F.l1_loss(Hp, Ht)

    def forward(self, pred, target):
        rgb_l1_loss = F.l1_loss(pred, target)
        lum_l1_loss = F.l1_loss(self.get_luminance(pred), self.get_luminance(target))
        ms_ssim_loss = self.ms_ssim_loss(pred, target)
        color_pre_l1_loss = self.color_preservation_loss(pred, target)
        hist_loss = self.histogram_loss(pred, target)

        return (self.alpha * rgb_l1_loss + self.beta * lum_l1_loss + self.gamma * ms_ssim_loss + self.delta * color_pre_l1_loss +
            self.eps_hist * hist_loss)

loss_fn = HDRLoss( alpha = 0.8, beta = 0.6, gamma = 0.3, delta = 2.0, eps_hist = 0.1, hist_bins = 32, hist_sigma = 0.01).to(DEVICE)


# Model and Optimizer
model     = HDRConcat(dropout_prob=0.3).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
scaler    = torch.amp.GradScaler()

# Training + Testing Loops
train_losses, test_losses = [], []
print("********** Training and Testing for Each Epochs wrt. Losses **********\n")

for epoch in range(1, MAX_EPOCHS+1):
    
    # Training
    model.train()
    total_train_loss = 0.0
    for u, o, gt in train_loader:
        u, o, gt = u.to(DEVICE), o.to(DEVICE), gt.to(DEVICE)
        optimizer.zero_grad()
        with torch.amp.autocast(device_type=DEVICE.type):
            pred = model(u, o)
            p, g = denormalize(pred).clamp(0,1), denormalize(gt).clamp(0,1)
            loss = loss_fn(p, g)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_train_loss += loss.item()

    avg_train = total_train_loss / len(train_loader)
    train_losses.append(avg_train)

    # Testing
    model.eval()
    total_test_loss = 0.0
    with torch.no_grad():
        for u, o, gt in test_loader:
            u, o, gt = u.to(DEVICE), o.to(DEVICE), gt.to(DEVICE)
            pred = model(u, o)
            p, g = denormalize(pred).clamp(0,1), denormalize(gt).clamp(0,1)
            total_test_loss += loss_fn(p, g).item()

    avg_test = total_test_loss / len(test_loader)
    test_losses.append(avg_test)

    print(f"Epochs [{epoch}/{MAX_EPOCHS}]    "
          f"Training Loss: {avg_train:.4f}    |    Testing Loss: {avg_test:.4f}")

# Saving the final model
print("\nTraining & Testing completed. Saving final model weights…")
torch.save(model.state_dict(), "HDRModel.pth")
print("Model weights are saved to HDRModel.pth")
********** Training and Testing Loss per Epoch **********

Epochs [1/50]    Training Loss: 0.7261    |    Testing Loss: 0.5983
Epochs [2/50]    Training Loss: 0.5816    |    Testing Loss: 0.5327
Epochs [3/50]    Training Loss: 0.5282    |    Testing Loss: 0.4798
Epochs [4/50]    Training Loss: 0.4953    |    Testing Loss: 0.4542
Epochs [5/50]    Training Loss: 0.4678    |    Testing Loss: 0.4126
Epochs [6/50]    Training Loss: 0.4455    |    Testing Loss: 0.4076
Epochs [7/50]    Training Loss: 0.4243    |    Testing Loss: 0.4010
Epochs [8/50]    Training Loss: 0.4058    |    Testing Loss: 0.3985
Epochs [9/50]    Training Loss: 0.3890    |    Testing Loss: 0.3534
Epochs [10/50]    Training Loss: 0.3696    |    Testing Loss: 0.3442
Epochs [11/50]    Training Loss: 0.3550    |    Testing Loss: 0.3257
Epochs [12/50]    Training Loss: 0.3397    |    Testing Loss: 0.3106
Epochs [13/50]    Training Loss: 0.3255    |    Testing Loss: 0.2870
Epochs [14/50]    Training Loss: 0.3074    |    Testing Loss: 0.2883
Epochs [15/50]    Training Loss: 0.2925    |    Testing Loss: 0.2568
Epochs [16/50]    Training Loss: 0.2756    |    Testing Loss: 0.2456
Epochs [17/50]    Training Loss: 0.2620    |    Testing Loss: 0.2430
Epochs [18/50]    Training Loss: 0.2483    |    Testing Loss: 0.2309
Epochs [19/50]    Training Loss: 0.2372    |    Testing Loss: 0.2537
Epochs [20/50]    Training Loss: 0.2293    |    Testing Loss: 0.2101
Epochs [21/50]    Training Loss: 0.2197    |    Testing Loss: 0.2101
Epochs [22/50]    Training Loss: 0.2059    |    Testing Loss: 0.1757
Epochs [23/50]    Training Loss: 0.1972    |    Testing Loss: 0.1861
Epochs [24/50]    Training Loss: 0.1916    |    Testing Loss: 0.1934
Epochs [25/50]    Training Loss: 0.1883    |    Testing Loss: 0.1762
Epochs [26/50]    Training Loss: 0.1845    |    Testing Loss: 0.1638
Epochs [27/50]    Training Loss: 0.1785    |    Testing Loss: 0.1508
Epochs [28/50]    Training Loss: 0.1695    |    Testing Loss: 0.1500
Epochs [29/50]    Training Loss: 0.1698    |    Testing Loss: 0.1425
Epochs [30/50]    Training Loss: 0.1672    |    Testing Loss: 0.1492
Epochs [31/50]    Training Loss: 0.1571    |    Testing Loss: 0.1412
Epochs [32/50]    Training Loss: 0.1630    |    Testing Loss: 0.1365
Epochs [33/50]    Training Loss: 0.1566    |    Testing Loss: 0.1348
Epochs [34/50]    Training Loss: 0.1548    |    Testing Loss: 0.1401
Epochs [35/50]    Training Loss: 0.1521    |    Testing Loss: 0.1284
Epochs [36/50]    Training Loss: 0.1448    |    Testing Loss: 0.1236
Epochs [37/50]    Training Loss: 0.1486    |    Testing Loss: 0.1296
Epochs [38/50]    Training Loss: 0.1405    |    Testing Loss: 0.1123
Epochs [39/50]    Training Loss: 0.1325    |    Testing Loss: 0.1239
Epochs [40/50]    Training Loss: 0.1350    |    Testing Loss: 0.1148
Epochs [41/50]    Training Loss: 0.1339    |    Testing Loss: 0.1070
Epochs [42/50]    Training Loss: 0.1343    |    Testing Loss: 0.1126
Epochs [43/50]    Training Loss: 0.1221    |    Testing Loss: 0.1130
Epochs [44/50]    Training Loss: 0.1215    |    Testing Loss: 0.1227
Epochs [45/50]    Training Loss: 0.1257    |    Testing Loss: 0.0994
Epochs [46/50]    Training Loss: 0.1228    |    Testing Loss: 0.1044
Epochs [47/50]    Training Loss: 0.1249    |    Testing Loss: 0.1105
Epochs [48/50]    Training Loss: 0.1171    |    Testing Loss: 0.0957
Epochs [49/50]    Training Loss: 0.1079    |    Testing Loss: 0.0952
Epochs [50/50]    Training Loss: 0.1105    |    Testing Loss: 0.0881

Training & Testing completed. Saving final model weights…
Model weights are saved to HDRModel.pth
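The soft-histogram term is the least standard piece of HDRLoss: each channel is soft-binned with a Gaussian kernel and normalized, so every per-channel histogram should sum to approximately one, and the loss should vanish for identical inputs. A quick sketch reusing the loss_fn instance defined above:

x = torch.rand(2, 3, 64, 64, device=DEVICE)
hist = loss_fn._soft_histogram(x)
print(hist.shape)                            # torch.Size([2, 3, 32]), i.e. B×3×bins
print(hist.sum(dim=-1))                      # ≈ 1.0 for every channel
print(loss_fn.histogram_loss(x, x).item())   # 0.0 by construction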
In [12]:
# Plotting Training and Testing Loss vs Epochs
epochs = range(1, len(train_losses) + 1)

plt.figure()
plt.plot(epochs, train_losses, label='Training Loss', color='teal')
plt.plot(epochs, test_losses,  label='Testing Loss', color='brown')
plt.xlabel('\nEpochs')
plt.ylabel('Loss\n')
plt.title('Training and Testing Loss vs Epochs\n')
plt.legend()
plt.tight_layout()
plt.show()
[Figure: Training and Testing Loss vs Epochs]
In [13]:
# Displaying test outputs interactively
display_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

ms_ssim_metric = MultiScaleSSIMLoss(data_range=1.0, reduction='mean').to(DEVICE)

def linear_to_srgb(x: torch.Tensor) -> torch.Tensor:
    return torch.where(x <= 0.0031308, 12.92 * x, 1.055 * (x ** (1/2.4)) - 0.055)

print("\nPress Enter to Display Next Test Image; Type 'quit' + Enter to Quit.")
model.eval()
with torch.no_grad():
    for idx, (underexposed_t, overexposed_t, ground_truth_t) in enumerate(display_loader):
        
        # Loading images & normalizing
        f_underexposed_img, f_overexposed_img, f_ground_truth_img = (lst[idx] for lst in
            (test_dataset.underexposed_list, test_dataset.overexposed_list, test_dataset.ground_truth_list))
        
        underexposed_img = Image.open(os.path.join(underexposed_folder, f_underexposed_img)).convert("RGB")
        overexposed_img = Image.open(os.path.join(overexposed_folder,  f_overexposed_img)).convert("RGB")
        ground_truth_img = Image.open(os.path.join(ground_truth_folder, f_ground_truth_img)).convert("RGB")

        underexposed_in = test_dataset.normalize(underexposed_img).unsqueeze(0).to(DEVICE)
        overexposed_in = test_dataset.normalize(overexposed_img).unsqueeze(0).to(DEVICE)

        # Predicting and getting the full-range linear radiance
        P_out = model(underexposed_in, overexposed_in)
        
        P_lin = denormalize(P_out)[0]
        ground_truth_lin = denormalize(ground_truth_t.to(DEVICE))[0].clamp(0,1)

        # Applying a small fixed per-channel gain
        exps = torch.tensor([2.5, 2.5, 2.5], device=P_lin.device).view(3,1,1)
        P_gain = P_lin * exps

        # Reinhard Tone Mapping
        P_tone = P_gain / (P_gain + 1.0)

        # Gamma correction to sRGB
        P_srgb = linear_to_srgb(P_tone.clamp(0,1))
        P_np   = P_srgb.permute(1,2,0).cpu().numpy()

        # Prep reference images
        underexposed_np = np.array(underexposed_img)/255.0
        overexposed_np = np.array(overexposed_img)/255.0
        ground_truth_np = np.array(ground_truth_img)/255.0

        # Metrics in the linear (denormalized) domain
        P_metric = denormalize(P_out).clamp(0,1)
        ground_truth_metric = denormalize(ground_truth_t.to(DEVICE)).clamp(0,1)
        P_arr = P_metric[0].permute(1,2,0).cpu().numpy()
        ground_truth_arr = ground_truth_metric[0].permute(1,2,0).cpu().numpy()

        psnr_val = peak_signal_noise_ratio(ground_truth_arr, P_arr, data_range=1.0)
        
        ms_ssim_val = ms_ssim_metric(P_metric, ground_truth_metric).item()

        # Plotting of Images
        fig, axs = plt.subplots(1,4,figsize=(18,5))
        axs[0].imshow(underexposed_np); axs[0].set_title("Underexposed Image");    axs[0].axis('off')
        axs[1].imshow(overexposed_np); axs[1].set_title("Overexposed Image");     axs[1].axis('off')
        axs[2].imshow(ground_truth_np); axs[2].set_title("Ground Truth Image");    axs[2].axis('off')
        axs[3].imshow(P_np)
        axs[3].set_title(f"Pred. Output HDR Image\nPSNR: {psnr_val:.2f}dB   Multi-Scale SSIM: {ms_ssim_val:.3f}")
        axs[3].axis('off')

        plt.tight_layout()
        plt.show()

        if input().strip().lower().startswith('quit'):
            break
Press Enter to Display Next Test Image; Type 'quit' + Enter to Quit.
[Figures: for each displayed test image, four panels: Underexposed Image, Overexposed Image, Ground Truth Image, and the predicted HDR output annotated with PSNR and Multi-Scale SSIM]
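For a single quantitative summary rather than the interactive browse above, the same PSNR / MS-SSIM computation can be averaged over the whole test split. A sketch reusing the model, display_loader, denormalize and ms_ssim_metric defined earlier (note that ms_ssim_metric is a loss, so lower is better):

psnrs, msssim_losses = [], []
model.eval()
with torch.no_grad():
    for u, o, gt in display_loader:
        u, o, gt = u.to(DEVICE), o.to(DEVICE), gt.to(DEVICE)
        pred   = denormalize(model(u, o)).clamp(0, 1)
        gt_lin = denormalize(gt).clamp(0, 1)
        psnrs.append(peak_signal_noise_ratio(
            gt_lin[0].permute(1, 2, 0).cpu().numpy(),
            pred[0].permute(1, 2, 0).cpu().numpy(), data_range=1.0))
        msssim_losses.append(ms_ssim_metric(pred, gt_lin).item())
print(f"Mean PSNR: {np.mean(psnrs):.2f} dB | Mean MS-SSIM loss: {np.mean(msssim_losses):.3f}")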
In [14]:
# END