Coverage for qubalab/images/image_server.py: 91%

134 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-10-07 15:29 +0000

1import numpy as np 

2import dask.array as da 

3from dask.delayed import delayed 

4from dask_image import ndinterp 

5import warnings 

6from typing import Union, Iterable, Optional, Tuple 

7from abc import ABC, abstractmethod 

8from PIL import Image 

9from .region_2d import Region2D 

10from .metadata.image_metadata import ImageMetadata 

11 

12 

class ImageServer(ABC):
    """
    An abstract class to read pixels and metadata of an image.

    An image server must be closed (see the close() function) once no longer used.
    """

    def __init__(self, resize_method: Image.Resampling = Image.Resampling.BICUBIC):
        """
        :param resize_method: the resampling method to use when resizing the image for downsampling. Bicubic by default
        """
        super().__init__()
        self._metadata = None
        self._resize_method = resize_method

    @property
    def metadata(self) -> ImageMetadata:
        """
        The image metadata.

        It is built by calling _build_metadata() on first access and cached afterwards.
        """
        if self._metadata is None:
            self._metadata = self._build_metadata()
        return self._metadata

    def read_region(
        self,
        downsample: float = 1.0,
        region: Optional[Union[Region2D, tuple[int, ...]]] = None,
        x: int = 0,
        y: int = 0,
        width: int = -1,
        height: int = -1,
        z: int = 0,
        t: int = 0,
    ) -> Union[np.ndarray, Image.Image]:
        """
        Read pixels from any arbitrary image region, at any resolution determined by the downsample.

        This method can be called in one of two ways: passing a region (as a Region2D object or a tuple of integers),
        or passing x, y, width, height, z and t parameters separately. The latter can be more convenient and readable
        when calling interactively, without the need to create a region object.
        If a region is passed, the other parameters (except for the downsample) are ignored.

        Important: coordinates and width/height values are given in the coordinate space of the full-resolution image.
        Pixels are read from the most suitable pyramid level and then resized to match the requested downsample.
        This means that, except when the downsample is 1.0, the width and height of the returned image will usually
        be different from the width and height passed as parameters.

        A negative width (or height) means "up to the right (or bottom) edge of the image".

        :param downsample: the downsample to use. Must be strictly positive
        :param region: a Region2D object or a tuple of integers (x, y, width, height, z, t)
        :param x: the x coordinate of the region to read
        :param y: the y coordinate of the region to read
        :param width: the width of the region to read
        :param height: the height of the region to read
        :param z: the z index of the region to read
        :param t: the t index of the region to read
        :return: a 3-dimensional numpy array containing the requested pixels from the 2D region.
                 The [c, y, x] index of the returned array returns the channel of index c of the
                 pixel located at coordinates [x, y] on the image
        :raises ValueError: when the region to read is not specified, or when the downsample is not strictly positive
        """
        # Written as "not > 0" so that NaN is rejected too. Without this check, a zero downsample
        # fails later with a ZeroDivisionError and a negative one produces a negative target size.
        if not downsample > 0:
            raise ValueError(
                "The downsample must be strictly positive (got {0})".format(downsample)
            )

        if region is None:
            region = Region2D(x=x, y=y, width=width, height=height, z=z, t=t)
        elif isinstance(region, tuple):
            # A tuple is unpacked as the positional arguments of Region2D: (x, y, width, height, z, t)
            region = Region2D(*region)
        if not isinstance(region, Region2D):
            raise ValueError("No valid region provided to read_region method")

        # Negative width/height values extend the region up to the image border
        if region.width < 0 or region.height < 0:
            w = region.width if region.width >= 0 else self.metadata.width - region.x
            h = region.height if region.height >= 0 else self.metadata.height - region.y
            region = Region2D(
                x=region.x, y=region.y, width=w, height=h, z=region.z, t=region.t
            )

        # Read from the pyramid level closest to (and not coarser than) the requested downsample...
        level = ImageServer._get_level(self.metadata.downsamples, downsample)
        level_downsample = self.metadata.downsamples[level]
        image = self._read_block(
            level, region.downsample_region(downsample=level_downsample)
        )

        # ...then resize only if that level does not exactly match the requested downsample
        if downsample == level_downsample:
            return image
        target_size = (
            round(region.width / downsample),
            round(region.height / downsample),
        )
        return self._resize(image, target_size, self._resize_method)

    def level_to_dask(
        self, level: int = 0, chunk_width: int = 1024, chunk_height: int = 1024
    ) -> da.Array:
        """
        Return a dask array representing a single resolution of the image.

        Pixels of the returned array can be accessed with the following order:
        (t, c, z, y, x). There may be fewer dimensions for simple images: for
        example, an image with a single timepoint and a single z-slice will
        return an array of dimensions (c, y, x). However, there will always be
        dimensions x and y, even if they have a size of 1.

        Subclasses of ImageServer may override this function if they can provide
        a faster implementation.

        :param level: the pyramid level (0 is full resolution). Must be less than the number
                      of resolutions of the image
        :param chunk_width: the image will be read chunk by chunk. This parameter specifies the
                            size of the chunks on the x-axis. Must be at least 1
        :param chunk_height: the size of the chunks on the y-axis. Must be at least 1
        :returns: a dask array containing all pixels of the provided level
        :raises ValueError: when level is not valid, or when a chunk size is smaller than 1
        """
        if level < 0 or level >= self.metadata.n_resolutions:
            raise ValueError(
                "The provided level ({0}) is outside of the valid range ([0, {1}])".format(
                    level, self.metadata.n_resolutions - 1
                )
            )
        if chunk_width < 1 or chunk_height < 1:
            # A zero step would make range() fail, and a negative one would yield no chunk at all
            raise ValueError(
                "The chunk width and height must be at least 1 (got {0} and {1})".format(
                    chunk_width, chunk_height
                )
            )

        n_channels = self.metadata.n_channels
        dtype = self.metadata.dtype
        level_width = self.metadata.shapes[level].x
        level_height = self.metadata.shapes[level].y

        ts = []
        for t in range(self.metadata.n_timepoints):
            zs = []
            for z in range(self.metadata.n_z_slices):
                # Each chunk is a lazily-read (c, y, x) block. The chunks of one column are
                # concatenated along y, and the columns are then concatenated along x.
                xs = []
                for x in range(0, level_width, chunk_width):
                    width = min(chunk_width, level_width - x)
                    ys = []
                    for y in range(0, level_height, chunk_height):
                        height = min(chunk_height, level_height - y)
                        ys.append(
                            da.from_delayed(
                                delayed(self._read_block)(
                                    level, Region2D(x, y, width, height, z, t)
                                ),
                                shape=(n_channels, height, width),
                                dtype=dtype,
                            )
                        )
                    xs.append(da.concatenate(ys, axis=1))
                zs.append(da.concatenate(xs, axis=2))
            ts.append(da.stack(zs))
        # Dimensions are now (t, z, c, y, x)
        image = da.stack(ts)

        # Swap channels and z-stacks axis to get (t, c, z, y, x)
        image = da.swapaxes(image, 1, 2)

        # Remove axes of length 1 (never x or y)
        axes_to_squeeze = []
        if self.metadata.n_timepoints == 1:
            axes_to_squeeze.append(0)
        if self.metadata.n_channels == 1:
            axes_to_squeeze.append(1)
        if self.metadata.n_z_slices == 1:
            axes_to_squeeze.append(2)
        image = da.squeeze(image, tuple(axes_to_squeeze))

        return image

    def to_dask(
        self, downsample: Optional[Union[float, Iterable[float]]] = None
    ) -> Union[da.Array, tuple[da.Array, ...]]:
        """
        Convert this image to one or more dask arrays, at any arbitrary downsample factor.

        It turns out that requesting at an arbitrary downsample level is very slow - currently, all
        pixels are requested upon first compute (even for a small region), and then resized.
        Prefer using ImageServer.level_to_dask() instead.

        :param downsample: the downsample factor to use, or a list of downsample factors to use. If None, all available resolutions will be used
        :return: a dask array or tuple of dask arrays, depending upon whether one or more downsample factors are required
        """
        if isinstance(downsample, Iterable):
            return tuple(self._to_dask_impl(downsample=d) for d in downsample)

        if downsample is None:
            if self.metadata.n_resolutions == 1:
                return self.level_to_dask(level=0)
            return tuple(
                self.level_to_dask(level=level)
                for level in range(self.metadata.n_resolutions)
            )
        return self._to_dask_impl(downsample=downsample)

    def _to_dask_impl(self, downsample: float) -> da.Array:
        """
        Create a dask array of the image at a single, arbitrary downsample.

        The closest suitable pyramid level is used. If its size does not match the requested
        downsample, the last two (y, x) axes are rescaled with an affine transform.

        :param downsample: the downsample factor to use
        :return: a dask array of the image at the requested downsample
        """
        level = ImageServer._get_level(self.metadata.downsamples, downsample)
        array = self.level_to_dask(level=level)

        rescale = downsample / self.metadata.downsamples[level]
        input_width = array.shape[-1]
        input_height = array.shape[-2]
        output_width = int(round(input_width / rescale))
        output_height = int(round(input_height / rescale))
        if input_width == output_width and input_height == output_height:
            return array

        # Couldn't find an easy resizing method for dask arrays... so we try this instead
        # TODO: Urgently need something better! Performance is terrible for large images - all pixels requested
        # upon first compute (even for a small region), and then resized. This is not scalable.
        if array.size > 10000:
            warnings.warn(
                "Warning - calling affine_transform on a large dask array can be *very* slow"
            )

        # The matrix maps output coordinates to input coordinates, so scaling the
        # last two axes by `rescale` samples the input every `rescale` pixels
        transform = np.eye(array.ndim)
        transform[array.ndim - 1, array.ndim - 1] = rescale
        transform[array.ndim - 2, array.ndim - 2] = rescale
        output_shape = list(array.shape)
        output_shape[-1] = output_width
        output_shape[-2] = output_height

        return ndinterp.affine_transform(
            array, transform, output_shape=tuple(output_shape)
        )

    @abstractmethod
    def close(self):
        """
        Close this image server.

        This should be called whenever this server is not used anymore.
        """
        pass

    @abstractmethod
    def _build_metadata(self) -> ImageMetadata:
        """
        Create metadata for the current image.

        :return: the metadata of the image
        """
        pass

    @abstractmethod
    def _read_block(self, level: int, region: Region2D) -> np.ndarray:
        """
        Read a block of pixels from a specific level.

        Coordinates are provided in the coordinate space of the level, NOT the full-resolution image.
        This means that the returned image should have the width and height specified.

        :param level: the pyramidal level to read from
        :param region: the region to read
        :return: a 3-dimensional numpy array containing the requested pixels from the 2D region.
                 The [c, y, x] index of the returned array returns the channel of index c of the
                 pixel located at coordinates [x, y] on the image
        """
        pass

    @staticmethod
    def _get_level(
        all_downsamples: Tuple[float, ...], downsample: float, abs_tol=1e-3
    ) -> int:
        """
        Get the level (index) from the image downsamples that is best for fulfilling an image region request.

        This is the index of the entry in all_downsamples that either (almost) matches the requested downsample,
        or relates to the next highest resolution image (so that any required scaling is to reduce resolution).

        :param all_downsamples: the downsamples of every level, in increasing order (level 0 first)
        :param downsample: the requested downsample value
        :param abs_tol: absolute tolerance when comparing downsample values; this allows for a stored downsample
                        value to be slightly off due to rounding
                        (e.g. requesting 4.0 would match a level 4 +/- abs_tol)
        :return: the level that is best for fulfilling an image region request at the specified downsample
        """
        if len(all_downsamples) == 1 or downsample <= all_downsamples[0]:
            return 0
        elif downsample >= all_downsamples[-1]:
            return len(all_downsamples) - 1
        else:
            # Allow a little bit of a tolerance because downsamples are often calculated
            # by rounding the ratio of image dimensions... and can end up a little bit off
            for level, d in reversed(list(enumerate(all_downsamples))):
                if downsample >= d - abs_tol:
                    return level
            return 0

    @staticmethod
    def _resize(
        image: Union[np.ndarray, Image.Image],
        target_size: tuple[int, int],
        resample: int = Image.Resampling.BICUBIC,
    ) -> Union[np.ndarray, Image.Image]:
        """
        Resize an image to a target size.

        This uses the implementation from PIL. NumPy arrays are resized one 2D plane at a time
        and keep their dtype. Integer values are rounded and clipped to the range of that dtype,
        so interpolation overshoot cannot wrap around.

        :param image: the image to resize. Either a numpy array whose last two dimensions are (y, x),
                      e.g. (c, y, x) or (y, x), or a PIL image
        :param target_size: target size in (width, height) format
        :param resample: resampling mode to use, by default bicubic
        :return: the resized image: a numpy array with the same number of dimensions and dtype as the input,
                 or a PIL image
        """
        target_size = tuple(target_size)
        if ImageServer._get_size(image) == target_size:
            return image

        # If we have a PIL image, just resize normally
        if isinstance(image, Image.Image):
            return image.resize(size=target_size, resample=resample)

        if image.ndim == 2:
            return ImageServer._resize_plane(image, target_size, resample)

        # Higher-dimensional arrays: resize each sub-array along the first axis independently
        if image.shape[0] == 0:
            # np.stack() cannot handle an empty list, so build the (empty) result directly
            target_width, target_height = target_size
            return np.empty(
                image.shape[:-2] + (target_height, target_width), dtype=image.dtype
            )
        return np.stack(
            [
                ImageServer._resize(plane, target_size=target_size, resample=resample)
                for plane in image
            ]
        )

    @staticmethod
    def _resize_plane(
        plane: np.ndarray, target_size: tuple[int, int], resample: int
    ) -> np.ndarray:
        """
        Resize a single 2-dimensional (y, x) numpy array with PIL, keeping its dtype.

        :param plane: the 2-dimensional array to resize
        :param target_size: target size in (width, height) format
        :param resample: resampling mode to use
        :return: the resized array, with the same dtype as the input
        """
        dtype = plane.dtype
        if dtype == np.bool_:
            # Let PIL infer mode "1" from the boolean array type. Passing mode="1" explicitly
            # makes PIL read the buffer as bit-packed data, which scrambles boolean arrays
            # (one byte per pixel). PIL always uses nearest-neighbour resampling for mode "1".
            pil_image = Image.fromarray(plane)
        elif dtype in (np.uint8, np.float32):
            pil_image = Image.fromarray(plane)
        elif np.issubdtype(dtype, np.integer) and ImageServer._fits_in_int32(plane):
            pil_image = Image.fromarray(plane.astype(np.int32), mode="I")
        else:
            # Other floating point types, plus integer values that cannot be represented by
            # PIL's 32-bit integer mode (casting them to int32 would silently wrap around)
            pil_image = Image.fromarray(plane.astype(np.float32), mode="F")

        resized = np.asarray(pil_image.resize(size=target_size, resample=resample))

        if np.issubdtype(dtype, np.integer):
            # Interpolation (e.g. bicubic) can overshoot the input range. Without clipping,
            # casting back would wrap (e.g. -1 becoming 65535 for uint16).
            info = np.iinfo(dtype)
            resized = np.clip(np.rint(resized.astype(np.float64)), info.min, info.max)
        return resized.astype(dtype)

    @staticmethod
    def _fits_in_int32(array: np.ndarray) -> bool:
        """
        Check whether all values of an integer array can be represented as int32.

        :param array: the integer array to check
        :return: True if the array is empty or all its values lie within the int32 range
        """
        if array.size == 0:
            return True
        int32_info = np.iinfo(np.int32)
        # Compare as Python ints to avoid any mixed signed/unsigned numpy promotion
        return (
            int(array.min()) >= int32_info.min and int(array.max()) <= int32_info.max
        )

    @staticmethod
    def _get_size(image: Union[np.ndarray, Image.Image]) -> tuple[int, ...]:
        """
        Get the size of an image as a two-element tuple (width, height).

        :param image: the image whose size should be computed. Either a numpy array whose last two
                      dimensions are (y, x), e.g. (c, y, x) or (y, x), or a PIL image
        :return: the (width, height) of the image
        """
        if isinstance(image, Image.Image):
            return image.size
        return tuple(image.shape[-2:][::-1])