Coverage for qubalab/images/image_server.py: 89%
130 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-31 11:24 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-31 11:24 +0000
1import numpy as np
2import dask.array as da
3import dask
4from dask_image import ndinterp
5import warnings
6from typing import Union, Iterable
7from abc import ABC, abstractmethod
8from PIL import Image
9from .region_2d import Region2D
10from .metadata.image_metadata import ImageMetadata
class ImageServer(ABC):
    """
    An abstract class to read pixels and metadata of an image.

    An image server must be closed (see the close() function) once no longer used.

    Subclasses implement close(), _build_metadata() and _read_block(); region reading at
    arbitrary downsamples and conversion to dask arrays are built on top of those.
    """

    def __init__(self, resize_method: Image.Resampling = Image.Resampling.BICUBIC):
        """
        :param resize_method: the resampling method to use when resizing the image for downsampling. Bicubic by default
        """
        super().__init__()
        self._metadata = None
        self._resize_method = resize_method

    @property
    def metadata(self) -> ImageMetadata:
        """
        The image metadata.

        Built by _build_metadata() on first access, then cached for subsequent accesses.
        """
        if self._metadata is None:
            self._metadata = self._build_metadata()
        return self._metadata

    def read_region(
        self,
        downsample: float = 1.0,
        region: Union[Region2D, tuple[int, ...]] = None,
        x: int = 0,
        y: int = 0,
        width: int = -1,
        height: int = -1,
        z: int = 0,
        t: int = 0
    ) -> np.ndarray:
        """
        Read pixels from any arbitrary image region, at any resolution determined by the downsample.

        This method can be called in one of two ways: passing a region (as a Region2D object or a tuple of integers),
        or passing x, y, width, height, z and t parameters separately. The latter can be more convenient and readable
        when calling interactively, without the need to create a region object.
        If a region is passed, the other parameters (except for the downsample) are ignored.

        Important: coordinates and width/height values are given in the coordinate space of the full-resolution image,
        and the downsample is applied before reading the region. This means that, except when the downsample is 1.0,
        the width and height of the returned image will usually be different from the width and height passed as parameters.

        :param downsample: the downsample to use
        :param region: a Region2D object or a tuple (or list) of integers (x, y, width, height, z, t)
        :param x: the x coordinate of the region to read
        :param y: the y coordinate of the region to read
        :param width: the width of the region to read. A negative value means "up to the right edge of the image"
        :param height: the height of the region to read. A negative value means "up to the bottom edge of the image"
        :param z: the z index of the region to read
        :param t: the t index of the region to read
        :return: a 3-dimensional numpy array containing the requested pixels from the 2D region.
                 The [c, y, x] index of the returned array returns the channel of index c of the
                 pixel located at coordinates [x, y] on the image
        :raises ValueError: when the provided region is neither a Region2D nor a tuple/list of integers
        """
        if region is None:
            region = Region2D(x=x, y=y, width=width, height=height, z=z, t=t)
        elif isinstance(region, (tuple, list)):
            # Positional order is (x, y, width, height, z, t), as documented above
            region = Region2D(*region)
        if not isinstance(region, Region2D):
            raise ValueError('No valid region provided to read_region method')

        # A negative width/height is expanded so that the region reaches the image border
        if region.width < 0 or region.height < 0:
            w = region.width if region.width >= 0 else self.metadata.width - region.x
            h = region.height if region.height >= 0 else self.metadata.height - region.y
            region = Region2D(x=region.x, y=region.y, width=w, height=h, z=region.z, t=region.t)

        # Read from the pyramid level chosen by _get_level, converting the region to that level's
        # coordinate space, then resize only if that level does not exactly match the request
        level = ImageServer._get_level(self.metadata.downsamples, downsample)
        level_downsample = self.metadata.downsamples[level]
        image = self._read_block(level, region.downsample_region(downsample=level_downsample))

        if downsample == level_downsample:
            return image
        target_size = (round(region.width / downsample), round(region.height / downsample))
        return self._resize(image, target_size, self._resize_method)

    def level_to_dask(self, level: int = 0, chunk_width: int = 1024, chunk_height: int = 1024) -> da.Array:
        """
        Return a dask array representing a single resolution of the image.

        Pixels of the returned array can be accessed with the following order:
        (t, c, z, y, x). There may be fewer dimensions for simple images: for
        example, an image with a single timepoint and a single z-slice will
        return an array of dimensions (c, y, x). However, there will always be
        dimensions x and y, even if they have a size of 1.

        Subclasses of ImageServer may override this function if they can provide
        a faster implementation.

        :param level: the pyramid level (0 is full resolution). Must be less than the number
                      of resolutions of the image
        :param chunk_width: the image will be read chunk by chunk. This parameter specifies the
                            size of the chunks on the x-axis. Must be strictly positive
        :param chunk_height: the size of the chunks on the y-axis. Must be strictly positive
        :returns: a dask array containing all pixels of the provided level
        :raises ValueError: when level is not valid, or when a chunk size is not strictly positive
        """
        metadata = self.metadata
        if level < 0 or level >= metadata.n_resolutions:
            raise ValueError(
                "The provided level ({0}) is outside of the valid range ([0, {1}])".format(level, metadata.n_resolutions - 1)
            )
        if chunk_width <= 0 or chunk_height <= 0:
            # Without this check, range() fails with an unrelated message (step 0) or produces
            # no chunks at all (negative step), which then breaks da.concatenate
            raise ValueError(
                "Chunk sizes must be strictly positive (got chunk_width={0}, chunk_height={1})".format(chunk_width, chunk_height)
            )

        level_width = metadata.shapes[level].x
        level_height = metadata.shapes[level].y

        ts = []
        for t in range(metadata.n_timepoints):
            zs = []
            for z in range(metadata.n_z_slices):
                xs = []
                for x in range(0, level_width, chunk_width):
                    # The last chunk of a row/column may be smaller than the requested chunk size
                    width = min(chunk_width, level_width - x)
                    ys = []
                    for y in range(0, level_height, chunk_height):
                        height = min(chunk_height, level_height - y)
                        ys.append(da.from_delayed(
                            dask.delayed(self._read_block)(level, Region2D(x, y, width, height, z, t)),
                            shape=(metadata.n_channels, height, width),
                            dtype=metadata.dtype
                        ))
                    # Join the chunks of one column along the y axis of (c, y, x)
                    xs.append(da.concatenate(ys, axis=1))
                # Join the columns along the x axis of (c, y, x)
                zs.append(da.concatenate(xs, axis=2))
            ts.append(da.stack(zs))
        image = da.stack(ts)  # (t, z, c, y, x)

        # Swap channels and z-stacks axis -> (t, c, z, y, x)
        image = da.swapaxes(image, 1, 2)

        # Remove axis of length 1 (but never the y and x axes)
        axes_to_squeeze = []
        if metadata.n_timepoints == 1:
            axes_to_squeeze.append(0)
        if metadata.n_channels == 1:
            axes_to_squeeze.append(1)
        if metadata.n_z_slices == 1:
            axes_to_squeeze.append(2)
        image = da.squeeze(image, tuple(axes_to_squeeze))

        return image

    def to_dask(self, downsample: Union[float, Iterable[float]] = None) -> Union[da.Array, tuple[da.Array, ...]]:
        """
        Convert this image to one or more dask arrays, at any arbitrary downsample factor.

        It turns out that requesting at an arbitrary downsample level is very slow - currently, all
        pixels are requested upon first compute (even for a small region), and then resized.
        Prefer using ImageServer.level_to_dask() instead.

        :param downsample: the downsample factor to use, or a list of downsample factors to use. If None, all available resolutions will be used
        :return: a dask array or tuple of dask arrays, depending upon whether one or more downsample factors are required
        """
        if downsample is None:
            n_resolutions = self.metadata.n_resolutions
            if n_resolutions == 1:
                return self.level_to_dask(level=0)
            return tuple(self.level_to_dask(level=level) for level in range(n_resolutions))

        if isinstance(downsample, Iterable):
            return tuple(self.to_dask(downsample=float(d)) for d in downsample)

        level = ImageServer._get_level(self.metadata.downsamples, downsample)
        array = self.level_to_dask(level=level)

        # Extra scaling still needed on top of the chosen pyramid level
        rescale = downsample / self.metadata.downsamples[level]
        input_width = array.shape[-1]
        input_height = array.shape[-2]
        output_width = int(round(input_width / rescale))
        output_height = int(round(input_height / rescale))
        if input_width == output_width and input_height == output_height:
            return array

        # Couldn't find an easy resizing method for dask arrays... so we try this instead
        # TODO: Urgently need something better! Performance is terrible for large images - all pixels requested
        # upon first compute (even for a small region), and then resized. This is not scalable.
        if array.size > 10000:
            warnings.warn('Warning - calling affine_transform on a large dask array can be *very* slow', stacklevel=2)

        # Identity transform except for the y and x axes, which are scaled by `rescale`
        # (output coordinates are multiplied by the matrix to find the input coordinates)
        transform = np.eye(array.ndim)
        transform[array.ndim - 1, array.ndim - 1] = rescale
        transform[array.ndim - 2, array.ndim - 2] = rescale
        output_shape = list(array.shape)
        output_shape[-1] = output_width
        output_shape[-2] = output_height

        return ndinterp.affine_transform(array, transform, output_shape=tuple(output_shape))

    @abstractmethod
    def close(self):
        """
        Close this image server.

        This should be called whenever this server is not used anymore.
        """
        pass

    @abstractmethod
    def _build_metadata(self) -> ImageMetadata:
        """
        Create metadata for the current image.

        :return: the metadata of the image
        """
        pass

    @abstractmethod
    def _read_block(self, level: int, region: Region2D) -> np.ndarray:
        """
        Read a block of pixels from a specific level.

        Coordinates are provided in the coordinate space of the level, NOT the full-resolution image.
        This means that the returned image should have the width and height specified.

        :param level: the pyramidal level to read from
        :param region: the region to read
        :return: a 3-dimensional numpy array containing the requested pixels from the 2D region.
                 The [c, y, x] index of the returned array returns the channel of index c of the
                 pixel located at coordinates [x, y] on the image
        """
        pass

    @staticmethod
    def _get_level(all_downsamples: tuple[float, ...], downsample: float, abs_tol=1e-3) -> int:
        """
        Get the level (index) from the image downsamples that is best for fulfilling an image region request.

        This is the index of the entry in all_downsamples that either (almost) matches the requested downsample,
        or relates to the next highest resolution image (so that any required scaling is to reduce resolution).

        :param all_downsamples: the downsamples of every level of the image, assumed sorted in increasing order
                                (level 0 first)
        :param downsample: the requested downsample value
        :param abs_tol: absolute tolerance when comparing downsample values; this allows for a stored downsample
                        value to be slightly off due to rounding
                        (e.g. requesting 4.0 would match a level 4 +/- abs_tol)
        :return: the level that is best for fulfilling an image region request at the specified downsample
        """
        if len(all_downsamples) == 1 or downsample <= all_downsamples[0]:
            return 0
        elif downsample >= all_downsamples[-1]:
            return len(all_downsamples) - 1
        else:
            # Allow a little bit of a tolerance because downsamples are often calculated
            # by rounding the ratio of image dimensions... and can end up a little bit off
            for level, d in reversed(list(enumerate(all_downsamples))):
                if downsample >= d - abs_tol:
                    return level
            # Only reachable if all_downsamples is not sorted in increasing order
            return 0

    @staticmethod
    def _resize(image: Union[np.ndarray, Image.Image], target_size: tuple[int, int], resample: int = Image.Resampling.BICUBIC) -> Union[np.ndarray, Image.Image]:
        """
        Resize an image to a target size.

        This uses the implementation from PIL.

        :param image: the image to resize. Either a numpy array with dimensions (c, y, x) or (y, x),
                      or a PIL image
        :param target_size: target size in (width, height) format
        :param resample: resampling mode to use, by default bicubic
        :return: the resized image, with the same type (and, for numpy arrays, the same dtype and
                 number of dimensions) as the input
        """
        if ImageServer._get_size(image) == target_size:
            return image

        # If we have a PIL image, just resize normally
        if isinstance(image, Image.Image):
            return image.resize(size=target_size, resample=resample)

        # If we have NumPy, do one channel at a time
        if image.ndim == 2:
            if image.dtype in [np.uint8, np.float32]:
                pil_image = Image.fromarray(image)
            elif np.issubdtype(image.dtype, np.integer):
                pil_image = Image.fromarray(image.astype(np.int32), mode='I')
            elif np.issubdtype(image.dtype, np.bool_):
                # Let PIL infer the mode: passing mode "1" explicitly makes PIL read the
                # one-byte-per-pixel boolean buffer as packed bits, which scrambles the image
                pil_image = Image.fromarray(image)
            else:
                pil_image = Image.fromarray(image.astype(np.float32), mode='F')
            pil_image = ImageServer._resize(pil_image, target_size=target_size, resample=resample)
            return np.asarray(pil_image).astype(image.dtype)

        return np.stack([
            ImageServer._resize(image[c, :, :], target_size=target_size, resample=resample) for c in range(image.shape[0])
        ])

    @staticmethod
    def _get_size(image: Union[np.ndarray, Image.Image]):
        """
        Get the size of an image as a two-element tuple (width, height).

        :param image: the image whose size should be computed. Either a numpy array with dimensions (c, y, x)
                      or (y, x), or a PIL image
        :return: the (width, height) of the image
        """
        if isinstance(image, Image.Image):
            return image.size
        # The last two axes are always (y, x), for both 2D and 3D arrays
        return image.shape[-2:][::-1]