Total-Variation based Bayesian Image Deblurring#

In this notebook, we will perform image deblurring via a Bayesian Maximum a Posteriori (MAP) approach. Our image prior will be composed of a total variation (TV) term and a positivity constraint.

Acknowledgements

This notebook is part of the official Pyxu Example Gallery. We kindly acknowledge the Pyxu developers for sharing their tutorial with us!

The forward model describing the blurring process is defined as (a minimal simulation of this model is sketched right after the definitions below):

\[\mathbf{y}=\mathbf{H}\mathbf{x}+\mathbf{n}\]

where:

  • \(\mathbf{y} \in \mathbb{R}^{d}\) is the observed blurred and noisy image,

  • \(\mathbf{H} \in \mathbb{R}^{d\times d}\) is the blurring operator, which consists of a convolution with a Gaussian point-spread function (PSF),

  • \(\mathbf{x} \in \mathbb{R}^{d}\) is the original clean image we want to recover,

  • \(\mathbf{n} \in \mathbb{R}^{d}\) is independent and identically distributed Gaussian noise.
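Before constructing \(\mathbf{H}\) explicitly with Pyxu below, here is a minimal NumPy/SciPy sketch of the same forward model, using SciPy's Gaussian filter as a stand-in for \(\mathbf{H}\); the blur width and noise level are placeholder values, not the ones used later in this notebook.

# Minimal sketch of the forward model y = H x + n (illustrative only; the notebook
# builds the blurring operator H explicitly with Pyxu's Convolve operator below)
import numpy as np
from scipy.ndimage import gaussian_filter

x_true = np.zeros((64, 64)); x_true[16:48, 16:48] = 1.0                  # simple synthetic clean image
y_sim = gaussian_filter(x_true, sigma=3.0)                               # H x: Gaussian blur
y_sim += np.random.default_rng(0).normal(scale=0.05, size=y_sim.shape)   # + n: i.i.d. Gaussian noise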

# Importing necessary libraries and modules
import numpy as np
import matplotlib.pyplot as plt
import skimage
from pyxu.operator import Convolve, L21Norm, Gradient, SquaredL2Norm, PositiveOrthant
from pyxu.opt.solver import PD3O
from pyxu.opt.stop import RelError, MaxIter

# Setting up GPU support
GPU = False
if GPU:
    import cupy as xp
else:
    import numpy as xp

Loading and Preprocessing the Image#

We will use a sample image from the skimage.data module and preprocess it to be suitable for the deblurring process. The image is converted to a float type and normalized to have pixel values between 0 and 1.

# Loading and preprocessing the image
data = skimage.data.coffee()
skimage.io.imshow(data)
data = xp.asarray(data.astype("float32") / 255.0).transpose(2, 0, 1)
(Figure: the original coffee image from skimage.data.)

Creating the Blurring Kernel#

We will create a Gaussian blurring kernel (a.k.a. point spread function, or PSF) to simulate the blurring effect of the camera lens on the image. The kernel is defined by its standard deviation and width. The Gaussian function is given by:

\[G(x)=\frac{1}{2\pi\sigma^{2}} e^{-\frac{(x-\mu)^{2}}{2\sigma^{2}}}\]

where:

  • \(G(x)\) is the Gaussian function,

  • \(\sigma\) is the standard deviation,

  • \(\mu\) is the mean.

# Creating the Gaussian blurring kernel
sigma = 7
width = 13
mu = (width - 1) / 2
gauss = lambda x: (1 / (2 * np.pi * sigma**2)) * np.exp(
    -0.5 * ((x - mu) ** 2) / (sigma**2)
)

kernel_1d = np.fromfunction(gauss, (width,)).reshape(1, -1)
kernel_1d /= kernel_1d.sum()

kernel_1d = xp.asarray(kernel_1d)

Applying the Blurring and Adding Noise#

We will use the created Gaussian kernel to blur the image and then add Gaussian noise to simulate a real-world scenario where camera sensors are corrupted by thermal noise. Note that the 2D Gaussian kernel is defined in a separable fashion for efficiency reasons.

# Applying the blurring and adding noise
conv = Convolve(
    arg_shape=data.shape,
    kernel=[xp.array([1]), kernel_1d, kernel_1d], 
    center=[0, width // 2, width // 2],
    mode="reflect",
    enable_warnings=True,
)
y = conv(data.ravel()).reshape(data.shape)
y = xp.random.normal(loc=y, scale=0.05)
y = y.clip(0, 1)
skimage.io.imshow(y.transpose(1,2,0))
(Figure: the blurred and noisy image.)
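As a quick check of the separability claim above (an illustrative snippet, not part of the original notebook, and assuming the CPU path, i.e. GPU = False), convolving a single color channel with the full 2D kernel built as an outer product matches two successive 1D convolutions up to floating-point error:

# Separability check: 2D Gaussian convolution vs. two successive 1D convolutions
from scipy.ndimage import convolve

k1 = np.asarray(kernel_1d).ravel()      # 1D Gaussian kernel, shape (width,)
k2d = np.outer(k1, k1)                  # full 2D kernel, shape (width, width)

channel = np.asarray(data[0])           # one color channel, shape (H, W)
full_2d = convolve(channel, k2d, mode="reflect")
separable = convolve(convolve(channel, k1[None, :], mode="reflect"), k1[:, None], mode="reflect")
print(np.abs(full_2d - separable).max())   # expected: on the order of machine precision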

MAP Estimate with Composite Positivity + Total Variation Prior#

Maximum a Posteriori (MAP) estimation seeks the most credible image given the likelihood and the image prior, that is, a mode of the posterior distribution (not necessarily unique, but globally optimal in the log-concave case). The likelihood model is based on the noise distribution (here assumed Gaussian), and the prior model encodes our assumptions about the image. The total variation prior promotes “mostly flat” solutions, helping to preserve edges while smoothing out noise. The positivity constraint ensures that the pixel values of the deblurred image remain non-negative.

The MAP optimization problem can be written as:

\[\hat{\mathbf{x}}=\arg\min_{\mathbf{x} \geq 0}\ \frac{1}{2}\Vert \mathbf{y}-\mathbf{H}\mathbf{x}\Vert^{2}_{2}+\lambda\Vert\nabla\mathbf{x}\Vert_{2,1}\]

where:

  • \(\Vert\mathbf{y}-\mathbf{H}\mathbf{x}\Vert_2^2\) is the squared \(L_2\)-norm representing the data-fidelity term (likelihood),

  • \(\Vert \nabla \mathbf{x}\Vert_{2, 1}=\sum_{i}\sqrt{(\nabla_{x} \mathbf{x})_{i}^{2} + (\nabla_{y} \mathbf{x})_{i}^{2}}\) is the isotropic total variation norm, i.e. the sum over pixels of the Euclidean norm of the local gradient (illustrated in the short sketch after this list),

  • \(\lambda\) is the regularization parameter,

  • \(\mathbf{x}\geq 0\) is the positivity constraint.
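To see why this prior favors piecewise-constant images, the short NumPy sketch below (an illustration, not part of the original notebook) evaluates the isotropic TV of a perfectly flat patch, a patch with a single sharp edge, and a noisy patch: the flat patch costs nothing, the sharp edge remains cheap, and the noise is penalized heavily.

# Illustrative evaluation of the isotropic TV term with forward finite differences
def isotropic_tv(img):
    dx = np.diff(img, axis=1, append=img[:, -1:])   # horizontal differences
    dy = np.diff(img, axis=0, append=img[-1:, :])   # vertical differences
    return np.sqrt(dx**2 + dy**2).sum()             # L2 across directions, L1 over pixels

flat = np.zeros((32, 32))                            # constant patch -> TV = 0
edge = np.zeros((32, 32)); edge[:, 16:] = 1.0        # one sharp edge -> TV = 32
noisy = np.random.default_rng(0).normal(scale=0.1, size=(32, 32))  # noise -> large TV
print(isotropic_tv(flat), isotropic_tv(edge), isotropic_tv(noisy))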

We solve this problem with the PD3O (Primal-Dual Three-Operator splitting) solver, with

  • \(\mathcal{F}(\mathbf{x})=\frac{1}{2}\Vert \mathbf{y}-\mathbf{H}\mathbf{x}\Vert^{2}_{2}\),

  • \(\mathcal{G}(\mathbf{x})=\iota_+(\mathbf{x})\),

  • \(\mathcal{H}(\mathbf{z})=\lambda \Vert \mathbf{z}\Vert_{2, 1}\),

  • \(\mathcal{K}=\nabla\).

PD3O handles the composite, non-smooth, and non-proximable term \(\mathcal{H}(\mathcal{K} \mathbf{x})\) through its Fenchel biconjugate. This means that, while the functionals \(\mathcal{F}\) and \(\mathcal{G}\) are minimized directly on the primal variable of interest, the term \(\mathcal{H}(\mathcal{K} \mathbf{x})\) is minimized indirectly on the dual variable, and its effect only transfers to the primal variable as the primal-dual gap shrinks to zero. It is therefore crucial to set the relative-improvement stopping threshold of PD3O to a sufficiently small value; otherwise the solution will not exhibit the mostly flat behavior expected from a total variation (TV) prior. One caveat is that PD3O, being generically designed, can converge slowly when such high accuracy is requested. The GPU implementation mitigates this issue, being significantly faster than the CPU version (e.g., this example runs in approximately 30 seconds on GPU vs. approximately 4 minutes on CPU).
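To make the dual-variable mechanics more tangible, the sketch below is an illustration of the underlying mathematics rather than Pyxu's internal code: it shows the two ingredients a primal-dual solver needs for \(\mathcal{H}(\mathbf{z})=\lambda \Vert \mathbf{z}\Vert_{2, 1}\), namely its proximal operator (per-pixel block soft-thresholding) and the proximal operator of its Fenchel conjugate, obtained through Moreau's decomposition \(\mathrm{prox}_{\sigma \mathcal{H}^{\ast}}(\mathbf{z}) = \mathbf{z} - \sigma\, \mathrm{prox}_{\mathcal{H}/\sigma}(\mathbf{z}/\sigma)\).

# Illustrative proximal operators for H(z) = lam * ||z||_{2,1}
# (a sketch of the mathematics used by primal-dual solvers, not Pyxu's implementation)
def prox_l21(z, tau):
    # Block soft-thresholding: z has shape (2, ...) and the L2 norm is taken over axis 0.
    norms = np.sqrt((z**2).sum(axis=0, keepdims=True))
    return np.maximum(1.0 - tau / np.maximum(norms, 1e-12), 0.0) * z

def prox_l21_conj(z, sigma, lam):
    # Moreau decomposition; equivalent to projecting each pixel's gradient vector
    # onto the L2 ball of radius lam (this is what the dual update uses).
    return z - sigma * prox_l21(z / sigma, lam / sigma)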

# Setting up the MAP approach with total variation prior and positivity constraint
sl2 = SquaredL2Norm(dim=y.size).asloss(y.ravel())
loss = sl2 * conv

l21 = L21Norm(arg_shape=(2, *y.shape), l2_axis=(0, 1))

grad = Gradient(
    arg_shape=y.shape,
    directions=(1, 2),
    gpu=GPU,
    diff_method="fd",
    scheme="central",
    accuracy=3,
)

stop_crit = RelError(eps=1e-6, var="x", f=None, norm=2, satisfy_all=True) | MaxIter(5000)

positivity = PositiveOrthant(dim=y.size)
solver = PD3O(f=loss, g=positivity, h= 3e-2 * l21, K=grad, verbosity=500)
solver.fit(x0=y.ravel(), stop_crit=stop_crit)

# Getting the deblurred image
recons = solver.solution().reshape(y.shape)
recons /= recons.max()
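As a convergence sanity check, one can evaluate the two terms of the objective at the recovered solution by applying the operators defined above to the flattened estimate. This is a convenience snippet, not part of the original notebook, and assumes Pyxu operators are evaluated by calling them on flattened arrays, as done with conv earlier.

# Optional sanity check: data-fidelity and TV terms of the objective at the solution
x_hat = solver.solution()
data_fidelity = loss(x_hat)             # ||y - H x_hat||_2^2 term
tv_term = 3e-2 * l21(grad(x_hat))       # lambda * ||grad x_hat||_{2,1}
print(data_fidelity, tv_term)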
skimage.io.imshow(recons.transpose(1,2,0))
(Figure: the deblurred image.)
# Evaluating the deblurred image
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr

if GPU:
    data = data.get()
    y = y.get()
    recons = recons.get()

data = data.transpose(1, 2, 0)
y = y.transpose(1, 2, 0)
recons = recons.clip(0,1)
recons = recons.transpose(1, 2, 0)

mse_y = mse(data, y)
ssim_y = ssim(data, y, channel_axis=2, data_range=1.)
psnr_y = psnr(data, y, data_range=1.)
mse_recons = mse(data, recons)
ssim_recons = ssim(data, recons, channel_axis=2, data_range=1.)
psnr_recons = psnr(data, recons, data_range=1.)
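For a quick numerical comparison before plotting, the metrics computed above can also be printed side by side (a small convenience snippet, not part of the original notebook):

# Print the evaluation metrics for the blurred and deblurred images
print(f"{'':12s}{'MSE':>10s}{'SSIM':>10s}{'PSNR (dB)':>12s}")
print(f"{'blurred':12s}{mse_y:>10.4f}{ssim_y:>10.3f}{psnr_y:>12.2f}")
print(f"{'deblurred':12s}{mse_recons:>10.4f}{ssim_recons:>10.3f}{psnr_recons:>12.2f}")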

Visualizing the Results#

Finally, let’s visualize the original image, the blurred and noisy image, and the deblurred image obtained using the MAP approach with a total variation prior and positivity constraint. We will also display the evaluation metrics for a comprehensive comparison.

# Visualizing the results
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(15, 11))
axes[0, 0].imshow(data.clip(0,1))
axes[0, 0].set_title("Original Image")
axes[0, 0].axis('off')

axes[0, 1].imshow(y.clip(0, 1))
axes[0, 1].set_title(f"Blurred and Noisy Image\nMSE: {mse_y:.4f}, SSIM: {ssim_y:.2f}, PSNR: {psnr_y:.2f}")
axes[0, 1].axis('off')

axes[1, 0].imshow(recons)
axes[1, 0].set_title(f"Deblurred Image\nMSE: {mse_recons:.4f}, SSIM: {ssim_recons:.2f}, PSNR: {psnr_recons:.2f}")
axes[1, 0].axis('off')

# Hide the unused fourth panel of the 2x2 grid
axes[1, 1].axis('off')

plt.tight_layout()
plt.show()
(Figure: the original image, the blurred and noisy image, and the deblurred image, with their MSE/SSIM/PSNR values.)