Coverage for qubalab/images/image_server.py: 89%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-31 11:24 +0000

1import numpy as np 

2import dask.array as da 

3import dask 

4from dask_image import ndinterp 

5import warnings 

6from typing import Union, Iterable 

7from abc import ABC, abstractmethod 

8from PIL import Image 

9from .region_2d import Region2D 

10from .metadata.image_metadata import ImageMetadata 

11 

12 

13class ImageServer(ABC): 

14 """ 

15 An abtract class to read pixels and metadata of an image. 

16 

17 An image server must be closed (see the close() function) once no longer used. 

18 """ 

19 

20 def __init__(self, resize_method: Image.Resampling = Image.Resampling.BICUBIC): 

21 """ 

22 :param resize_method: the resampling method to use when resizing the image for downsampling. Bicubic by default 

23 """ 

24 super().__init__() 

25 self._metadata = None 

26 self._resize_method = resize_method 

27 

28 @property 

29 def metadata(self) -> ImageMetadata: 

30 """ 

31 The image metadata. 

32 """ 

33 if self._metadata is None: 

34 self._metadata = self._build_metadata() 

35 return self._metadata 

36 

37 def read_region( 

38 self, 

39 downsample: float = 1.0, 

40 region: Union[Region2D, tuple[int, ...]] = None, 

41 x: int = 0, 

42 y: int = 0, 

43 width: int = -1, 

44 height: int = -1, 

45 z: int = 0, 

46 t: int = 0 

47 ) -> np.ndarray: 

48 """ 

49 Read pixels from any arbitrary image region, at any resolution determined by the downsample. 

50 

51 This method can be called in one of two ways: passing a region (as a Region2D object or a tuple of integers), 

52 or passing x, y, width, height, z and t parameters separately. The latter can be more convenient and readable 

53 when calling interactively, without the need to create a region object. 

54 If a region is passed, the other parameters (except for the downsample) are ignored. 

55 

56 Important: coordinates and width/height values are given in the coordinate space of the full-resolution image, 

57 and the downsample is applied before reading the region. This means that, except when the downsample is 1.0, 

58 the width and height of the returned image will usually be different from the width and height passed as parameters. 

59 

60 :param downsample: the downsample to use 

61 :param region: a Region2D object or a tuple of integers (x, y, width, height, z, t) 

62 :param x: the x coordinate of the region to read 

63 :param y: the y coordinate of the region to read 

64 :param width: the width of the region to read 

65 :param height: the height of the region to read 

66 :param z: the z index of the region to read 

67 :param t: the t index of the region to read 

68 :return: a 3-dimensional numpy array containing the requested pixels from the 2D region. 

69 The [c, y, x] index of the returned array returns the channel of index c of the 

70 pixel located at coordinates [x, y] on the image 

71 :raises ValueError: when the region to read is not specified 

72 """ 

73 if region is None: 

74 region = Region2D(x=x, y=y, width=width, height=height, z=z, t=t) 

75 elif isinstance(region, tuple): 

76 # If we have a tuple, use it along with the downsample if available 

77 region = Region2D(*region) 

78 if not isinstance(region, Region2D): 

79 raise ValueError('No valid region provided to read_region method') 

80 

81 # Fix negative values for width or height 

82 if region.width < 0 or region.height < 0: 

83 w = region.width if region.width >= 0 else self.metadata.width - region.x 

84 h = region.height if region.height >= 0 else self.metadata.height - region.y 

85 region = Region2D(x=region.x, y=region.y, width=w, height=h, z=region.z, t=region.t) 

86 

87 level = ImageServer._get_level(self.metadata.downsamples, downsample) 

88 level_downsample = self.metadata.downsamples[level] 

89 image = self._read_block(level, region.downsample_region(downsample=level_downsample)) 

90 

91 if downsample == level_downsample: 

92 return image 

93 else: 

94 target_size = (round(region.width / downsample), round(region.height / downsample)) 

95 return self._resize(image, target_size, self._resize_method) 

96 

97 def level_to_dask(self, level: int = 0, chunk_width: int = 1024, chunk_height: int = 1024) -> da.Array: 

98 """ 

99 Return a dask array representing a single resolution of the image. 

100 

101 Pixels of the returned array can be accessed with the following order: 

102 (t, c, z, y, x). There may be less dimensions for simple images: for 

103 example, an image with a single timepoint and a single z-slice will 

104 return an array of dimensions (c, y, x). However, there will always be 

105 dimensions x and y, even if they have a size of 1. 

106 

107 Subclasses of ImageServer may override this function if they can provide 

108 a faster implementation. 

109 

110 :param level: the pyramid level (0 is full resolution). Must be less than the number 

111 of resolutions of the image 

112 :param chunk_width: the image will be read chunk by chunk. This parameter specifies the 

113 size of the chunks on the x-axis 

114 :param chunk_height: the size of the chunks on the y-axis 

115 :returns: a dask array containing all pixels of the provided level 

116 :raises ValueError: when level is not valid 

117 """ 

118 if level < 0 or level >= self.metadata.n_resolutions: 

119 raise ValueError( 

120 "The provided level ({0}) is outside of the valid range ([0, {1}])".format(level, self.metadata.n_resolutions - 1) 

121 ) 

122 

123 ts = [] 

124 for t in range(self.metadata.n_timepoints): 

125 zs = [] 

126 for z in range(self.metadata.n_z_slices): 

127 xs = [] 

128 for x in range(0, self.metadata.shapes[level].x, chunk_width): 

129 ys = [] 

130 for y in range(0, self.metadata.shapes[level].y, chunk_height): 

131 width = min(chunk_width, self.metadata.shapes[level].x - x) 

132 height = min(chunk_height, self.metadata.shapes[level].y - y) 

133 

134 ys.append(da.from_delayed( 

135 dask.delayed(self._read_block)(level, Region2D(x, y, width, height, z, t)), 

136 shape=( 

137 self.metadata.n_channels, 

138 height, 

139 width 

140 ), 

141 dtype=self.metadata.dtype 

142 )) 

143 xs.append(da.concatenate(ys, axis=1)) 

144 zs.append(da.concatenate(xs, axis=2)) 

145 ts.append(da.stack(zs)) 

146 image = da.stack(ts) 

147 

148 # Swap channels and z-stacks axis 

149 image = da.swapaxes(image, 1, 2) 

150 

151 # Remove axis of length 1 

152 axes_to_squeeze = [] 

153 if self.metadata.n_timepoints == 1: 

154 axes_to_squeeze.append(0) 

155 if self.metadata.n_channels == 1: 

156 axes_to_squeeze.append(1) 

157 if self.metadata.n_z_slices == 1: 

158 axes_to_squeeze.append(2) 

159 image = da.squeeze(image, tuple(axes_to_squeeze)) 

160 

161 return image 

162 

163 def to_dask(self, downsample: Union[float, Iterable[float]] = None) -> Union[da.Array, tuple[da.Array, ...]]: 

164 """ 

165 Convert this image to one or more dask arrays, at any arbitary downsample factor. 

166 

167 It turns out that requesting at an arbitrary downsample level is very slow - currently, all 

168 pixels are requested upon first compute (even for a small region), and then resized. 

169 Prefer using ImageServer.level_to_dask() instead. 

170 

171 :param downsample: the downsample factor to use, or a list of downsample factors to use. If None, all available resolutions will be used 

172 :return: a dask array or tuple of dask arrays, depending upon whether one or more downsample factors are required 

173 """ 

174 

175 if downsample is None: 

176 if self.n_resolutions == 1: 

177 return self.level_to_dask(level=0) 

178 else: 

179 return tuple([self.level_to_dask(level=level) for level in range(self.metadata.n_resolutions)]) 

180 

181 if isinstance(downsample, Iterable): 

182 return tuple([self.to_dask(downsample=float(d)) for d in downsample]) 

183 

184 level = ImageServer._get_level(self.metadata.downsamples, downsample) 

185 array = self.level_to_dask(level=level) 

186 

187 rescale = downsample / self.metadata.downsamples[level] 

188 input_width = array.shape[-1] 

189 input_height = array.shape[-2] 

190 output_width = int(round(input_width / rescale)) 

191 output_height = int(round(input_height / rescale)) 

192 if input_width == output_width and input_height == output_height: 

193 return array 

194 

195 # Couldn't find an easy resizing method for dask arrays... so we try this instead 

196 # TODO: Urgently need something better! Performance is terrible for large images - all pixels requested 

197 # upon first compute (even for a small region), and then resized. This is not scalable. 

198 if array.size > 10000: 

199 warnings.warn('Warning - calling affine_transform on a large dask array can be *very* slow') 

200 

201 transform = np.eye(array.ndim) 

202 transform[array.ndim-1, array.ndim-1] = rescale 

203 transform[array.ndim-2, array.ndim-2] = rescale 

204 output_shape = list(array.shape) 

205 output_shape[-1] = output_width 

206 output_shape[-2] = output_height 

207 

208 return ndinterp.affine_transform(array, transform, output_shape=tuple(output_shape)) 

209 

210 @abstractmethod 

211 def close(self): 

212 """ 

213 Close this image server. 

214  

215 This should be called whenever this server is not used anymore. 

216 """ 

217 pass 

218 

219 @abstractmethod 

220 def _build_metadata(self) -> ImageMetadata: 

221 """ 

222 Create metadata for the current image. 

223 

224 :return: the metadata of the image 

225 """ 

226 pass 

227 

228 @abstractmethod 

229 def _read_block(self, level: int, region: Region2D) -> np.ndarray: 

230 """ 

231 Read a block of pixels from a specific level. 

232 

233 Coordinates are provided in the coordinate space of the level, NOT the full-resolution image. 

234 This means that the returned image should have the width and height specified. 

235  

236 :param level: the pyramidal level to read from 

237 :param region: the region to read 

238 :return: a 3-dimensional numpy array containing the requested pixels from the 2D region. 

239 The [c, y, x] index of the returned array returns the channel of index c of the 

240 pixel located at coordinates [x, y] on the image 

241 """ 

242 pass 

243 

244 @staticmethod 

245 def _get_level(all_downsamples: tuple[float], downsample: float, abs_tol=1e-3) -> int: 

246 """ 

247 Get the level (index) from the image downsamples that is best for fulfilling an image region request. 

248 

249 This is the index of the entry in self.downsamples that either (almost) matches the requested downsample, 

250 or relates to the next highest resolution image (so that any required scaling is to reduce resolution). 

251 

252 :param downsample: the requested downsample value 

253 :param abs_tol: absolute tolerance when comparing downsample values; this allows for a stored downsample 

254 value to be slightly off due to rounding 

255 (e.g. requesting 4.0 would match a level 4 +/- abs_tol) 

256 :return: the level that is best for fulfilling an image region request at the specified downsample 

257 """ 

258 if len(all_downsamples) == 1 or downsample <= all_downsamples[0]: 

259 return 0 

260 elif downsample >= all_downsamples[-1]: 

261 return len(all_downsamples) - 1 

262 else: 

263 # Allow a little bit of a tolerance because downsamples are often calculated 

264 # by rounding the ratio of image dimensions... and can end up a little bit off 

265 for level, d in reversed(list(enumerate(all_downsamples))): 

266 if downsample >= d - abs_tol: 

267 return level 

268 return 0 

269 

270 @staticmethod 

271 def _resize(image: Union[np.ndarray, Image.Image], target_size: tuple[int, int], resample: int = Image.Resampling.BICUBIC) -> Union[np.ndarray, Image.Image]: 

272 """ 

273 Resize an image to a target size. 

274 

275 This uses the implementation from PIL. 

276 

277 :param image: the image to resize. Either a 3-dimensional numpy array with dimensions (c, y, x) 

278 or a PIL image 

279 :param target_size: target size in (width, height) format 

280 :param resample: resampling mode to use, by default bicubic 

281 :return: the resized image, either a 3-dimensional numpy array with dimensions (c, y, x) or a PIL image 

282 """ 

283 

284 if ImageServer._get_size(image) == target_size: 

285 return image 

286 

287 # If we have a PIL image, just resize normally 

288 if isinstance(image, Image.Image): 

289 return image.resize(size=target_size, resample=resample) 

290 # If we have NumPy, do one channel at a time 

291 else: 

292 if image.ndim == 2: 

293 if image.dtype in [np.uint8, np.float32]: 

294 pilImage = Image.fromarray(image) 

295 elif np.issubdtype(image.dtype, np.integer): 

296 pilImage = Image.fromarray(image.astype(np.int32), mode='I') 

297 elif np.issubdtype(image.dtype, np.bool_): 

298 pilImage = Image.fromarray(image, "1") 

299 else: 

300 pilImage = Image.fromarray(image.astype(np.float32), mode='F') 

301 pilImage = ImageServer._resize(pilImage, target_size=target_size, resample=resample) 

302 return np.asarray(pilImage).astype(image.dtype) 

303 else: 

304 return np.stack([ 

305 ImageServer._resize(image[c, :, :], target_size=target_size, resample=resample) for c in range(image.shape[0]) 

306 ]) 

307 

308 @staticmethod 

309 def _get_size(image: Union[np.ndarray, Image.Image]): 

310 """ 

311 Get the size of an image as a two-element tuple (width, height). 

312 

313 :param image: the image whose size should be computed. Either a 3-dimensional numpy array with dimensions (c, y, x) 

314 or a PIL image 

315 """ 

316 return image.size if isinstance(image, Image.Image) else image.shape[1:][::-1]