Visualizing big images with OME-Zarr and Napari

In this notebook, we will see how to

  • convert a TIF image and its corresponding segmentation to the OME-Zarr format.

  • open and view an OME-Zarr dataset in multiscale mode in Napari (in 2D mode).

  • browse an OME-Zarr dataset interactively in Napari.

Acknowledgements

This notebook was adapted from the Napari Lazy Image Browser project by the EPFL Center for Imaging.

Setup

You will need the ome-zarr-py library (pip install ome-zarr).

import os
import skimage.io
import dask.array as da
from dask import delayed
from ome_zarr.io import parse_url
from ome_zarr.writer import write_image
from ome_zarr.scale import Scaler
import zarr

Acknowledgements

We thank the Lemaitre lab at EPFL for providing the data used in this notebook!

from shared_data import DATASET

# Get a local path to the example volume from the shared dataset registry;
# the path is echoed as the cell output.
image_file = DATASET.fetch("drosophila_trachea.tif")
image_file
'/home/wittwer/.cache/field-guide/drosophila_trachea.tif'

Save a Dask array to OME-Zarr format

[…]

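We write the image as a multiscale pyramid: each level is a copy of the previous one downscaled by a factor of 2, so a viewer can load only the resolution it needs.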

To learn more about writing OME-Zarr in Python, see the ome-zarr documentation.

img_shape = (30, 250, 250)

# Wrap imread in dask.delayed: the TIF is read only when a computation needs it.
# NOTE(review): shape and dtype are declared by hand and must match the file.
delayed_read = delayed(skimage.io.imread)
img = da.from_delayed(delayed_read(image_file), dtype=float, shape=img_shape)

### Compute the segmentation
seg = (img > 30).astype(int)

# Start from an empty output directory on every run.
zarr_output_path = "./test_img.zarr"
if os.path.isdir(zarr_output_path):
    import shutil
    shutil.rmtree(zarr_output_path)
os.mkdir(zarr_output_path)

chunk_shape = (10, 64, 64)

# Layout: image pyramid at the root, label pyramid under labels/<label_name>.
store = parse_url(zarr_output_path, mode="w").store
root = zarr.group(store=store)

label_name = "segmentation"
labels_grp = root.create_group("labels")
labels_grp.attrs["labels"] = [label_name]
label_grp = labels_grp.create_group(label_name)
label_grp.attrs["image-label"] = {
    "colors": [{"label-value": 1, "rgba": [255, 0, 0, 255]}],
}

# Pyramid levels are downscaled by a factor 2 with nearest-neighbour sampling.
scaler = Scaler(downscale=2, method="nearest")

img = img.rechunk(chunk_shape)
write_image(
    image=img,
    group=root,
    axes="zyx",
    storage_options={"chunks": chunk_shape},
    scaler=scaler,
)

seg = seg.rechunk(chunk_shape)
write_image(
    image=seg,
    group=label_grp,
    axes="zyx",
    storage_options={"chunks": chunk_shape},
    scaler=scaler,
)
[]

Open and view an OME-Zarr dataset in Napari

from ome_zarr.reader import Label, Reader
import napari

reader = Reader(parse_url(zarr_output_path))
nodes = list(reader())

# Select nodes by their spec rather than by position in ``nodes`` so the cell
# does not depend on the reader's traversal order: the label image is the node
# carrying a ``Label`` spec, the image is the first node with data and no such spec.
dask_image_data = next(
    node.data
    for node in nodes
    if node.data and not any(isinstance(spec, Label) for spec in node.specs)
)
dask_labels_data = next(
    node.data
    for node in nodes
    if any(isinstance(spec, Label) for spec in node.specs)
)

# ``napari.view_image`` is deprecated in recent napari releases; create the
# viewer explicitly, as the interactive browser cell does. Each node's data is
# a list of pyramid levels, hence ``multiscale=True``.
viewer = napari.Viewer()
viewer.add_image(dask_image_data, contrast_limits=[0, 255], multiscale=True)
viewer.add_labels(dask_labels_data, multiscale=True)
no parent found for <ome_zarr.reader.Label object at 0x74fce6f56bb0>: None
<Labels layer 'dask_labels_data' at 0x74fce7ffdd90>

Browse an OME-Zarr dataset interactively in Napari

import napari.layers
from napari.components.viewer_model import ViewerModel
from napari.qt import QtViewer
import numpy as np
import dask.array as da
from PyQt5.QtCore import Qt
from qtpy.QtWidgets import (
    QWidget,
    QLabel,
    QGridLayout,
    QSpinBox,
    QSplitter,
)

def get_image_chunk(img: da.Array, center_loc, chunk_shape) -> da.Array:
    """Return the sub-volume of a 3D array centred on a given location.

    Parameters
    ----------
    img : dask.array.Array
        3D (Z, Y, X) array to crop; any array supporting basic slicing works.
    center_loc : sequence of 3 numbers
        (Z, Y, X) centre of the crop.
    chunk_shape : sequence of 3 ints
        Requested (Z, Y, X) size of the crop.

    Returns
    -------
    The slice ``img[z0:z1, y0:y1, x0:x1]`` whose bounds come from
    ``get_bbox_location``; it can be smaller than ``chunk_shape`` where that
    function clips the box at the image border.
    """
    (start_z, stop_z), (start_y, stop_y), (start_x, stop_x) = get_bbox_location(
        img.shape, chunk_shape, center_loc
    )
    return img[start_z:stop_z, start_y:stop_y, start_x:stop_x]

def get_bbox_location(img_shape, chunk_shape, center_loc):
    """Return the 3D bounding box of a chunk centred on a location, clipped to the image.

    Parameters
    ----------
    img_shape : sequence of 3 ints
        (Z, Y, X) shape of the full image.
    chunk_shape : sequence of 3 ints
        Requested (Z, Y, X) size of the chunk.
    center_loc : sequence of 3 numbers
        (Z, Y, X) centre of the chunk; values are truncated to integers.

    Returns
    -------
    numpy.ndarray
        Integer array of shape (3, 2)::

            [[start_z, stop_z],
             [start_y, stop_y],
             [start_x, stop_x]]

        usable as half-open slice bounds. Along each axis the box spans exactly
        the requested size, except where it is clipped to ``[0, img_shape]``.
    """
    center = np.asarray(center_loc).astype(int)
    sizes = np.asarray(chunk_shape).astype(int)
    limits = np.asarray(img_shape).astype(int)

    starts = center - sizes // 2
    # ``start + size`` (not ``center + size // 2``) so odd sizes keep their full
    # extent and a size of 1 does not produce an empty slice.
    stops = starts + sizes

    # Clip both ends into [0, limit]. A negative stop must not survive: used as a
    # slice bound it would count from the end of the axis and select almost the
    # whole volume instead of nothing.
    starts = np.clip(starts, 0, limits)
    stops = np.clip(stops, 0, limits)

    return np.stack([starts, stops], axis=1)

class QtViewerWrap(QtViewer):
    """Secondary QtViewer that hands drag-and-drop file opening to the main viewer.

    ``main_viewer`` is the primary napari viewer; files dropped on this
    canvas are opened through ``main_viewer.window._qt_viewer._qt_open``.
    """

    def __init__(self, main_viewer, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.main_viewer = main_viewer

    def _qt_open(
        self,
        filenames: list,
        stack: bool,
        plugin: str = None,
        layer_type: str = None,
        **kwargs,
    ):
        """for drag and drop open files"""
        target = self.main_viewer.window._qt_viewer
        target._qt_open(filenames, stack, plugin, layer_type, **kwargs)

class ExtraViewerWidget(QSplitter):
    """The main widget of the example.

    A splitter that hosts ``extra_viewer`` inside a vertical inner splitter
    with zero contents margins.
    """

    def __init__(self, extra_viewer) -> None:
        super().__init__()
        self.extra_viewer = extra_viewer

        inner_splitter = QSplitter()
        inner_splitter.setOrientation(Qt.Vertical)
        inner_splitter.addWidget(extra_viewer)
        inner_splitter.setContentsMargins(0, 0, 0, 0)
        self.addWidget(inner_splitter)

class BigBrowserWidget(QWidget):
    """Dock widget to browse a large 3D volume through a movable sub-volume.

    Every Image/Labels layer of ``napari_viewer`` is replaced by a crop of its
    original data, centred on ``center_loc`` with a size taken from three spin
    boxes (Z, Y, X). The crop's (y, x) footprint is drawn as a red rectangle in
    ``minimap_viewer``; moving or resizing that rectangle, clicking in the
    minimap, or pressing Up/Down in either viewer moves the crop.

    Parameters
    ----------
    napari_viewer : napari.Viewer
        Viewer whose Image/Labels layers are cropped.
    minimap_viewer :
        Viewer (model) in which the bounding-box rectangle is drawn.
    image_shape : tuple of int
        (Z, Y, X) shape of the full volume.
    """

    def __init__(self, napari_viewer, minimap_viewer, image_shape) -> None:
        super().__init__()

        self.image_shape = image_shape
        self.center_loc = np.array(image_shape) // 2

        self.viewer = napari_viewer
        self.viewer.text_overlay.visible = True

        self.minimap_viewer = minimap_viewer

        ### QT Layout
        grid_layout = QGridLayout()
        grid_layout.setAlignment(Qt.AlignTop)
        self.setLayout(grid_layout)

        # Chunk size in X / Y / Z
        grid_layout.addWidget(QLabel("Z"), 3, 0)
        self.z_chunk_spinbox = QSpinBox()
        self.z_chunk_spinbox.setMinimum(1)
        self.z_chunk_spinbox.setMaximum(2000)
        self.z_chunk_spinbox.setValue(20)
        grid_layout.addWidget(self.z_chunk_spinbox, 3, 1)

        grid_layout.addWidget(QLabel("Y"), 4, 0)
        self.y_chunk_spinbox = QSpinBox()
        self.y_chunk_spinbox.setMinimum(1)
        self.y_chunk_spinbox.setMaximum(2000)
        self.y_chunk_spinbox.setValue(100)
        grid_layout.addWidget(self.y_chunk_spinbox, 4, 1)

        grid_layout.addWidget(QLabel("X"), 5, 0)
        self.x_chunk_spinbox = QSpinBox()
        self.x_chunk_spinbox.setMinimum(1)
        self.x_chunk_spinbox.setMaximum(2000)
        self.x_chunk_spinbox.setValue(100)
        grid_layout.addWidget(self.x_chunk_spinbox, 5, 1)

        # Update the view when the values change in the spinboxes
        self.z_chunk_spinbox.valueChanged.connect(self._update_view)
        self.y_chunk_spinbox.valueChanged.connect(self._update_view)
        self.x_chunk_spinbox.valueChanged.connect(self._update_view)

        # Bounding box in the minimap viewer: the (y, x) rows of the 3D box,
        # rearranged into the two opposite corners of a rectangle.
        initial_bbox_loc = get_bbox_location(
            image_shape, center_loc=self.center_loc, chunk_shape=self.chunk_shape
        )[1:].T[::-1]

        self.minimap_shapes_layer = self.minimap_viewer.add_shapes(
            data=initial_bbox_loc,
            shape_type="rectangle",
            edge_color="red",
            edge_width=5,
            face_color="transparent",
            name="Current location",
        )
        self.minimap_shapes_layer.mode = "SELECT"

        # Moving the bounding box updates the 3D view
        self.minimap_shapes_layer.events.set_data.connect(self._handle_minimap_moved)

        # Dragging the cursor updates the 3D view
        self.minimap_viewer.mouse_drag_callbacks.append(self._handle_cursor_drag)

        # Key bindings
        self.viewer.bind_key("Up", self._move_up)
        self.viewer.bind_key("Down", self._move_down)
        self.minimap_viewer.bind_key("Up", self._move_up)
        self.minimap_viewer.bind_key("Down", self._move_down)

        # Setup layer callbacks. Each entry is (layer, full-resolution data).
        self.subscribed_layers = []
        self.minimap_viewer.layers.events.inserted.connect(
            lambda e: e.value.events.name.connect(self._on_layer_change)
        )
        self.viewer.layers.events.inserted.connect(self._on_layer_change)
        self.viewer.layers.events.removed.connect(self._on_layer_change)
        self._on_layer_change(None)

        self._update_view()
        self.viewer.reset_view()

    def _on_layer_change(self, e):
        """Re-sync ``subscribed_layers`` with the main viewer's Image/Labels layers.

        The full-resolution data of a layer is captured only the first time the
        layer is seen. Afterwards ``layer.data`` holds a crop, and re-capturing
        it would make later crops be taken from a crop. Layers no longer in the
        viewer are dropped, and no layer is listed twice.
        """
        # Keys are ids of layers still referenced by the old list, so they cannot
        # be reused by a newly created layer.
        known_data = {id(layer): data for layer, data in self.subscribed_layers}
        self.subscribed_layers = [
            (layer, known_data.get(id(layer), layer.data))
            for layer in self.viewer.layers
            if isinstance(layer, (napari.layers.Image, napari.layers.Labels))
        ]

    def _shift_z(self, step):
        """Move the crop centre by ``step`` slices in Z, clamped to the volume."""
        last_slice = self.image_shape[0] - 1
        self.center_loc[0] = int(np.clip(self.center_loc[0] + step, 0, last_slice))
        self._update_subscribed_layers()

    def _move_up(self, *args, **kwargs):
        """Moves the 3D view up in Z by 1 pixel (stops at the last slice)."""
        self._shift_z(1)

    def _move_down(self, *args, **kwargs):
        """Moves the 3D view down in Z by 1 pixel (stops at slice 0)."""
        self._shift_z(-1)

    def _handle_cursor_drag(self, source, event):
        """Recentre the crop (and the minimap rectangle) on a click in the minimap.

        Clicks whose y or x coordinate is within ``tol`` pixels of one of the
        rectangle's edges are ignored, so the rectangle itself can still be
        grabbed and resized/moved.
        """
        shapes = self.minimap_shapes_layer.data
        if len(shapes) == 0:
            # The rectangle was deleted; there is nothing to move.
            return

        dy = self.y_chunk_spinbox.value()
        dx = self.x_chunk_spinbox.value()

        bbox_data = shapes[0]
        y0, x0 = bbox_data[0]  # Top left corner
        y1, x1 = bbox_data[2]  # Bottom right corner

        new_cy, new_cx = event.position

        tol = 3  # px
        near_edge = (
            np.abs(new_cy - y0) < tol
            or np.abs(new_cy - y1) < tol
            or np.abs(new_cx - x0) < tol
            or np.abs(new_cx - x1) < tol
        )
        if near_edge:
            return

        self.center_loc[1] = new_cy
        self.center_loc[2] = new_cx

        y0 = new_cy - dy // 2
        x0 = new_cx - dx // 2
        y1 = new_cy + dy // 2
        x1 = new_cx + dx // 2

        # Move the bounding box (this fires set_data -> _handle_minimap_moved)
        self.minimap_shapes_layer.data = np.array([[y0, x0], [y1, x1]])

    def _handle_minimap_moved(self, e):
        """Callback of the minimap set_data event: sync centre and Y/X size."""
        if len(e.source.data) == 0:
            # The rectangle was deleted; keep the current crop.
            return

        bbox_data = e.source.data[0]
        y0, x0 = bbox_data[0]  # Top left corner
        y1, x1 = bbox_data[2]  # Bottom right corner

        dx = int(x1 - x0)
        dy = int(y1 - y0)

        self.center_loc[1] = y0 + dy // 2
        self.center_loc[2] = x0 + dx // 2

        # Update the spin boxes with their signals blocked: their handler
        # (_update_view) redraws the minimap rectangle, which would re-enter
        # this callback. The crop is refreshed once, below.
        for spinbox, value in ((self.y_chunk_spinbox, dy), (self.x_chunk_spinbox, dx)):
            was_blocked = spinbox.blockSignals(True)
            try:
                spinbox.setValue(value)
            finally:
                spinbox.blockSignals(was_blocked)

        self._update_subscribed_layers()

    @property
    def chunk_shape(self):
        """Current (Z, Y, X) crop size read from the spin boxes."""
        cz = self.z_chunk_spinbox.value()
        cy = self.y_chunk_spinbox.value()
        cx = self.x_chunk_spinbox.value()
        return (cz, cy, cx)

    def _update_view(self):
        """Refresh the cropped layers and the minimap rectangle."""
        self._update_subscribed_layers()
        self._update_minimap_bbox()

    def _update_subscribed_layers(self):
        """Replace each subscribed layer's data by a crop of its full data."""
        for layer, layer_data in self.subscribed_layers:
            layer.data = get_image_chunk(
                layer_data, center_loc=self.center_loc, chunk_shape=self.chunk_shape
            )

    def _update_minimap_bbox(self):
        """Redraw the minimap rectangle from the current centre and crop size."""
        self.minimap_shapes_layer.data = get_bbox_location(
            self.image_shape, center_loc=self.center_loc, chunk_shape=self.chunk_shape
        )[1:].T[::-1]
# Browse the OME-Zarr written above rather than the in-memory ``img``. ``img``
# is backed by a single delayed ``imread`` call, so every crop the browser
# computes would re-read the whole TIF. The Zarr arrays are chunked on disk, so
# only the chunks overlapping the crop are loaded.
# ome-zarr-py stores pyramid levels as "0", "1", ...; "0" is full resolution.
browse_img = da.from_zarr(zarr_output_path, component="0")
browse_seg = da.from_zarr(zarr_output_path, component=f"labels/{label_name}/0")

z_max_proj = da.max(browse_img, axis=0).compute()
z_max_proj_seg = da.max(browse_seg, axis=0).compute()

viewer = napari.Viewer()
viewer.add_image(browse_img, contrast_limits=[0, 0.5], colormap="viridis", name="img")
viewer.add_labels(browse_seg, name="seg")

viewer_model = ViewerModel(title="Max projection (Z)")
extra_viewer = QtViewerWrap(viewer, viewer_model)
viewer_model.add_image(z_max_proj, contrast_limits=[0, 1], colormap="viridis")
viewer_model.add_labels(z_max_proj_seg, opacity=0.5)

extra_viewer_widget = ExtraViewerWidget(extra_viewer)
viewer.window.add_dock_widget(extra_viewer_widget, name="Max projection (Z)")

image_browser = BigBrowserWidget(
    napari_viewer=viewer,
    minimap_viewer=viewer_model,
    image_shape=browse_img.shape,
)
viewer.window.add_dock_widget(image_browser, name="Volume subset size", area="left")
viewer.dims.ndisplay = 3
/home/wittwer/miniconda3/envs/image-analysis-field-guide/lib/python3.9/site-packages/napari/plugins/_plugin_manager.py:555: UserWarning: Plugin 'napari_skimage_regionprops2' has already registered a function widget 'duplicate current frame' which has now been overwritten
  warn(message=warn_message)