|  | from typing import List | 
					
						
						|  |  | 
					
						
						|  | import PIL.Image | 
					
						
						|  | import PIL.ImageOps | 
					
						
						|  | from packaging import version | 
					
						
						|  | from PIL import Image | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): | 
					
						
						|  | PIL_INTERPOLATION = { | 
					
						
						|  | "linear": PIL.Image.Resampling.BILINEAR, | 
					
						
						|  | "bilinear": PIL.Image.Resampling.BILINEAR, | 
					
						
						|  | "bicubic": PIL.Image.Resampling.BICUBIC, | 
					
						
						|  | "lanczos": PIL.Image.Resampling.LANCZOS, | 
					
						
						|  | "nearest": PIL.Image.Resampling.NEAREST, | 
					
						
						|  | } | 
					
						
						|  | else: | 
					
						
						|  | PIL_INTERPOLATION = { | 
					
						
						|  | "linear": PIL.Image.LINEAR, | 
					
						
						|  | "bilinear": PIL.Image.BILINEAR, | 
					
						
						|  | "bicubic": PIL.Image.BICUBIC, | 
					
						
						|  | "lanczos": PIL.Image.LANCZOS, | 
					
						
						|  | "nearest": PIL.Image.NEAREST, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def pt_to_pil(images): | 
					
						
						|  | """ | 
					
						
						|  | Convert a torch image to a PIL image. | 
					
						
						|  | """ | 
					
						
						|  | images = (images / 2 + 0.5).clamp(0, 1) | 
					
						
						|  | images = images.cpu().permute(0, 2, 3, 1).float().numpy() | 
					
						
						|  | images = numpy_to_pil(images) | 
					
						
						|  | return images | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def numpy_to_pil(images): | 
					
						
						|  | """ | 
					
						
						|  | Convert a numpy image or a batch of images to a PIL image. | 
					
						
						|  | """ | 
					
						
						|  | if images.ndim == 3: | 
					
						
						|  | images = images[None, ...] | 
					
						
						|  | images = (images * 255).round().astype("uint8") | 
					
						
						|  | if images.shape[-1] == 1: | 
					
						
						|  |  | 
					
						
						|  | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] | 
					
						
						|  | else: | 
					
						
						|  | pil_images = [Image.fromarray(image) for image in images] | 
					
						
						|  |  | 
					
						
						|  | return pil_images | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image: | 
					
						
						|  | """ | 
					
						
						|  | Prepares a single grid of images. Useful for visualization purposes. | 
					
						
						|  | """ | 
					
						
						|  | assert len(images) == rows * cols | 
					
						
						|  |  | 
					
						
						|  | if resize is not None: | 
					
						
						|  | images = [img.resize((resize, resize)) for img in images] | 
					
						
						|  |  | 
					
						
						|  | w, h = images[0].size | 
					
						
						|  | grid = Image.new("RGB", size=(cols * w, rows * h)) | 
					
						
						|  |  | 
					
						
						|  | for i, img in enumerate(images): | 
					
						
						|  | grid.paste(img, box=(i % cols * w, i // cols * h)) | 
					
						
						|  | return grid | 
					
						
						|  |  |