Coverage for qubalab/images/image_server.py: 91%
134 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-10-07 15:29 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-10-07 15:29 +0000
1import numpy as np
2import dask.array as da
3from dask.delayed import delayed
4from dask_image import ndinterp
5import warnings
6from typing import Union, Iterable, Optional, Tuple
7from abc import ABC, abstractmethod
8from PIL import Image
9from .region_2d import Region2D
10from .metadata.image_metadata import ImageMetadata
class ImageServer(ABC):
    """
    An abstract class to read pixels and metadata of an image.

    An image server must be closed (see the close() function) once no longer used.
    It can also be used as a context manager, in which case close() is called
    automatically when the ``with`` block exits.
    """

    def __init__(self, resize_method: Image.Resampling = Image.Resampling.BICUBIC):
        """
        :param resize_method: the resampling method to use when resizing the image for downsampling. Bicubic by default
        """
        super().__init__()
        self._metadata = None
        self._resize_method = resize_method

    def __enter__(self) -> "ImageServer":
        """
        Enter a ``with`` block; the server itself is returned.
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """
        Close the server when leaving a ``with`` block. Exceptions raised inside
        the block are not suppressed.
        """
        self.close()

    @property
    def metadata(self) -> ImageMetadata:
        """
        The image metadata.

        It is built with _build_metadata() on first access and cached afterwards.
        """
        if self._metadata is None:
            self._metadata = self._build_metadata()
        return self._metadata

    def read_region(
        self,
        downsample: float = 1.0,
        region: Optional[Union[Region2D, tuple[int, ...]]] = None,
        x: int = 0,
        y: int = 0,
        width: int = -1,
        height: int = -1,
        z: int = 0,
        t: int = 0,
    ) -> Union[np.ndarray, Image.Image]:
        """
        Read pixels from any arbitrary image region, at any resolution determined by the downsample.

        This method can be called in one of two ways: passing a region (as a Region2D object or a tuple of integers),
        or passing x, y, width, height, z and t parameters separately. The latter can be more convenient and readable
        when calling interactively, without the need to create a region object.
        If a region is passed, the other parameters (except for the downsample) are ignored.

        Important: coordinates and width/height values are given in the coordinate space of the full-resolution image,
        and the region is returned at the requested downsample. This means that, except when the downsample is 1.0,
        the width and height of the returned image will usually be different from the width and height passed as parameters.

        A negative width (or height) means "up to the right (or bottom) edge of the image".

        :param downsample: the downsample to use
        :param region: a Region2D object or a tuple (or list) of integers (x, y, width, height, z, t)
        :param x: the x coordinate of the region to read
        :param y: the y coordinate of the region to read
        :param width: the width of the region to read
        :param height: the height of the region to read
        :param z: the z index of the region to read
        :param t: the t index of the region to read
        :return: a 3-dimensional numpy array containing the requested pixels from the 2D region.
                 The [c, y, x] index of the returned array returns the channel of index c of the
                 pixel located at coordinates [x, y] on the image
        :raises ValueError: when the region to read is not specified
        """
        if region is None:
            region = Region2D(x=x, y=y, width=width, height=height, z=z, t=t)
        elif isinstance(region, (tuple, list)):
            # Positional values are interpreted as (x, y, width, height, z, t)
            region = Region2D(*region)
        if not isinstance(region, Region2D):
            raise ValueError("No valid region provided to read_region method")

        # Negative width/height: extend the region to the edge of the full-resolution image
        if region.width < 0 or region.height < 0:
            w = region.width if region.width >= 0 else self.metadata.width - region.x
            h = region.height if region.height >= 0 else self.metadata.height - region.y
            region = Region2D(
                x=region.x, y=region.y, width=w, height=h, z=region.z, t=region.t
            )

        # Read from the closest pyramid level that has at least the requested resolution,
        # then resize to the exact requested downsample if needed
        level = ImageServer._get_level(self.metadata.downsamples, downsample)
        level_downsample = self.metadata.downsamples[level]
        image = self._read_block(
            level, region.downsample_region(downsample=level_downsample)
        )

        if downsample == level_downsample:
            return image

        target_size = (
            round(region.width / downsample),
            round(region.height / downsample),
        )
        return self._resize(image, target_size, self._resize_method)

    def level_to_dask(
        self, level: int = 0, chunk_width: int = 1024, chunk_height: int = 1024
    ) -> da.Array:
        """
        Return a dask array representing a single resolution of the image.

        Pixels of the returned array can be accessed with the following order:
        (t, c, z, y, x). There may be less dimensions for simple images: for
        example, an image with a single timepoint and a single z-slice will
        return an array of dimensions (c, y, x). However, there will always be
        dimensions x and y, even if they have a size of 1.

        Subclasses of ImageServer may override this function if they can provide
        a faster implementation.

        :param level: the pyramid level (0 is full resolution). Must be less than the number
                      of resolutions of the image
        :param chunk_width: the image will be read chunk by chunk. This parameter specifies the
                            size of the chunks on the x-axis. Must be positive
        :param chunk_height: the size of the chunks on the y-axis. Must be positive
        :returns: a dask array containing all pixels of the provided level
        :raises ValueError: when level is not valid, or when a chunk dimension is not positive
        """
        metadata = self.metadata
        if level < 0 or level >= metadata.n_resolutions:
            raise ValueError(
                "The provided level ({0}) is outside of the valid range ([0, {1}])".format(
                    level, metadata.n_resolutions - 1
                )
            )
        if chunk_width <= 0 or chunk_height <= 0:
            raise ValueError(
                "Chunk dimensions must be positive (got chunk_width={0}, chunk_height={1})".format(
                    chunk_width, chunk_height
                )
            )

        level_width = metadata.shapes[level].x
        level_height = metadata.shapes[level].y

        ts = []
        for t in range(metadata.n_timepoints):
            zs = []
            for z in range(metadata.n_z_slices):
                columns = []
                for x in range(0, level_width, chunk_width):
                    width = min(chunk_width, level_width - x)
                    tiles = []
                    for y in range(0, level_height, chunk_height):
                        height = min(chunk_height, level_height - y)
                        # Each tile is read lazily, only when the dask array is computed
                        tiles.append(
                            da.from_delayed(
                                delayed(self._read_block)(
                                    level, Region2D(x, y, width, height, z, t)
                                ),
                                shape=(metadata.n_channels, height, width),
                                dtype=metadata.dtype,
                            )
                        )
                    # Tiles of one column are stacked along y (axis 1 of (c, y, x))
                    columns.append(da.concatenate(tiles, axis=1))
                # Columns are joined along x (axis 2 of (c, y, x))
                zs.append(da.concatenate(columns, axis=2))
            ts.append(da.stack(zs))
        # Current axis order: (t, z, c, y, x)
        image = da.stack(ts)

        # Swap channels and z-stacks axis to get (t, c, z, y, x)
        image = da.swapaxes(image, 1, 2)

        # Remove t, c and z axes of length 1 (x and y are always kept)
        axes_to_squeeze = []
        if metadata.n_timepoints == 1:
            axes_to_squeeze.append(0)
        if metadata.n_channels == 1:
            axes_to_squeeze.append(1)
        if metadata.n_z_slices == 1:
            axes_to_squeeze.append(2)
        image = da.squeeze(image, tuple(axes_to_squeeze))

        return image

    def to_dask(
        self, downsample: Optional[Union[float, Iterable[float]]] = None
    ) -> Union[da.Array, tuple[da.Array, ...]]:
        """
        Convert this image to one or more dask arrays, at any arbitrary downsample factor.

        It turns out that requesting at an arbitrary downsample level is very slow - currently, all
        pixels are requested upon first compute (even for a small region), and then resized.
        Prefer using ImageServer.level_to_dask() instead.

        :param downsample: the downsample factor to use, or a list of downsample factors to use.
                           If None, all available resolutions will be used
        :return: a dask array or tuple of dask arrays, depending upon whether one or more downsample factors are required
        """
        if isinstance(downsample, Iterable):
            return tuple(self._to_dask_impl(downsample=d) for d in downsample)

        if downsample is None:
            if self.metadata.n_resolutions == 1:
                return self.level_to_dask(level=0)
            return tuple(
                self.level_to_dask(level=level)
                for level in range(self.metadata.n_resolutions)
            )
        return self._to_dask_impl(downsample=downsample)

    def _to_dask_impl(self, downsample: float) -> da.Array:
        """
        Create a dask array of the image at a single downsample factor.

        The closest suitable pyramid level is read with level_to_dask() and, if its
        downsample does not match the requested one, the last two axes (y, x) are
        rescaled with an affine transform.

        :param downsample: the downsample factor to use
        :return: a dask array of the image at the requested downsample
        """
        level = ImageServer._get_level(self.metadata.downsamples, downsample)
        array = self.level_to_dask(level=level)

        rescale = downsample / self.metadata.downsamples[level]
        input_width = array.shape[-1]
        input_height = array.shape[-2]
        output_width = int(round(input_width / rescale))
        output_height = int(round(input_height / rescale))
        if input_width == output_width and input_height == output_height:
            return array

        # Couldn't find an easy resizing method for dask arrays... so we try this instead
        # TODO: Urgently need something better! Performance is terrible for large images - all pixels requested
        # upon first compute (even for a small region), and then resized. This is not scalable.
        if array.size > 10000:
            warnings.warn(
                "Warning - calling affine_transform on a large dask array can be *very* slow"
            )

        # The matrix maps output coordinates to input coordinates, so scaling the
        # y and x axes by `rescale` samples the input every `rescale` pixels
        transform = np.eye(array.ndim)
        transform[array.ndim - 1, array.ndim - 1] = rescale
        transform[array.ndim - 2, array.ndim - 2] = rescale
        output_shape = list(array.shape)
        output_shape[-1] = output_width
        output_shape[-2] = output_height

        return ndinterp.affine_transform(
            array, transform, output_shape=tuple(output_shape)
        )

    @abstractmethod
    def close(self):
        """
        Close this image server.

        This should be called whenever this server is not used anymore.
        """
        pass

    @abstractmethod
    def _build_metadata(self) -> ImageMetadata:
        """
        Create metadata for the current image.

        :return: the metadata of the image
        """
        pass

    @abstractmethod
    def _read_block(self, level: int, region: Region2D) -> np.ndarray:
        """
        Read a block of pixels from a specific level.

        Coordinates are provided in the coordinate space of the level, NOT the full-resolution image.
        This means that the returned image should have the width and height specified.

        :param level: the pyramidal level to read from
        :param region: the region to read
        :return: a 3-dimensional numpy array containing the requested pixels from the 2D region.
                 The [c, y, x] index of the returned array returns the channel of index c of the
                 pixel located at coordinates [x, y] on the image
        """
        pass

    @staticmethod
    def _get_level(
        all_downsamples: Tuple[float, ...], downsample: float, abs_tol=1e-3
    ) -> int:
        """
        Get the level (index) from the image downsamples that is best for fulfilling an image region request.

        This is the index of the entry in all_downsamples that either (almost) matches the requested downsample,
        or relates to the next highest resolution image (so that any required scaling is to reduce resolution).

        :param all_downsamples: the downsamples of every level of the image, in increasing order
        :param downsample: the requested downsample value
        :param abs_tol: absolute tolerance when comparing downsample values; this allows for a stored downsample
                        value to be slightly off due to rounding
                        (e.g. requesting 4.0 would match a level 4 +/- abs_tol)
        :return: the level that is best for fulfilling an image region request at the specified downsample
        """
        if len(all_downsamples) == 1 or downsample <= all_downsamples[0]:
            return 0
        if downsample >= all_downsamples[-1]:
            return len(all_downsamples) - 1

        # Allow a little bit of a tolerance because downsamples are often calculated
        # by rounding the ratio of image dimensions... and can end up a little bit off
        for level in range(len(all_downsamples) - 1, -1, -1):
            if downsample >= all_downsamples[level] - abs_tol:
                return level
        # Only reachable when the downsamples are not sorted in increasing order
        return 0

    @staticmethod
    def _resize(
        image: Union[np.ndarray, Image.Image],
        target_size: tuple[int, int],
        resample: int = Image.Resampling.BICUBIC,
    ) -> Union[np.ndarray, Image.Image]:
        """
        Resize an image to a target size.

        This uses the implementation from PIL. Numpy arrays are resized one channel at a time,
        and the result keeps the dtype of the input (integer results are clipped to the range
        of that dtype, so interpolation overshoot does not wrap around).

        :param image: the image to resize. Either a 3-dimensional numpy array with dimensions (c, y, x)
                      or a PIL image
        :param target_size: target size in (width, height) format
        :param resample: resampling mode to use, by default bicubic
        :return: the resized image, either a 3-dimensional numpy array with dimensions (c, y, x) or a PIL image
        """
        target_size = tuple(target_size)
        if ImageServer._get_size(image) == target_size:
            return image

        # If we have a PIL image, just resize normally
        if isinstance(image, Image.Image):
            return image.resize(size=target_size, resample=resample)

        # If we have NumPy with channels, do one channel at a time
        if image.ndim != 2:
            return np.stack(
                [
                    ImageServer._resize(
                        image[c, :, :], target_size=target_size, resample=resample
                    )
                    for c in range(image.shape[0])
                ]
            )

        # Single 2D plane: let PIL infer the mode from the array dtype. Passing an explicit mode
        # would change how the raw buffer is decoded (e.g. mode "1" expects packed bits, which
        # corrupts one-byte-per-pixel bool arrays).
        if image.dtype in [np.uint8, np.float32]:
            pil_image = Image.fromarray(image)
        elif np.issubdtype(image.dtype, np.integer):
            pil_image = Image.fromarray(image.astype(np.int32))
        elif np.issubdtype(image.dtype, np.bool_):
            pil_image = Image.fromarray(image)
        else:
            pil_image = Image.fromarray(image.astype(np.float32))

        resized = np.asarray(pil_image.resize(size=target_size, resample=resample))

        if np.issubdtype(image.dtype, np.integer):
            # Interpolation (e.g. bicubic) can overshoot the input range; clip before casting
            # so values do not wrap around. Bounds are limited to int32, the type PIL worked in.
            dtype_info = np.iinfo(image.dtype)
            int32_info = np.iinfo(np.int32)
            resized = np.clip(
                resized,
                max(int(dtype_info.min), int(int32_info.min)),
                min(int(dtype_info.max), int(int32_info.max)),
            )
        return resized.astype(image.dtype)

    @staticmethod
    def _get_size(image: Union[np.ndarray, Image.Image]):
        """
        Get the size of an image as a two-element tuple (width, height).

        :param image: the image whose size should be computed. Either a numpy array whose last two
                      dimensions are (y, x) — e.g. (c, y, x) or (y, x) — or a PIL image
        """
        if isinstance(image, Image.Image):
            return image.size
        return tuple(image.shape[-2:][::-1])