Spaces:
Paused
Paused
| # Copyright 2024 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| from typing import Union | |
| import numpy as np | |
| import PIL.Image | |
| import torch | |
| from spandrel import ImageModelDescriptor, ModelLoader | |
| from ..image_processor import ImageMixin | |
| from ..utils import get_model_path, tiled_upscale | |
class UpscaleWithModel(ImageMixin):
    r"""
    Upscaler class that uses a pytorch model.

    Args:
        model ([`ImageModelDescriptor`]):
            Upscaler model, must be supported by spandrel.
        scale (`int`, *optional*):
            Accepted for API compatibility. NOTE: the value is currently not stored or used by this class; tiled
            upscaling always reads the scale from `model.scale`.
    """

    def __init__(self, model: ImageModelDescriptor, scale: Union[int, None] = None):
        super().__init__()
        self.model = model

    def to(self, device) -> "UpscaleWithModel":
        r"""
        Move the wrapped model to `device`.

        Args:
            device:
                Target device, forwarded unchanged to `self.model.to`.

        Returns:
            `UpscaleWithModel`: this same instance, so calls can be chained.
        """
        self.model.to(device)
        return self

    # Alternate constructor: must be a classmethod so that it can be called on the class itself
    # (`UpscaleWithModel.from_pretrained(...)`) and receive the class as `cls`.
    @classmethod
    def from_pretrained(
        cls, pretrained_model_or_path: Union[str, os.PathLike], filename: str = None, subfolder: str = None
    ) -> "UpscaleWithModel":
        r"""
        Instantiate the Upscaler class from pretrained weights.

        Parameters:
            pretrained_model_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *repo id* (for example `OzzyGT/UltraSharp`) of a pretrained model
                      hosted on the Hub, must be saved in safetensors. If there's more than one checkpoint
                      in the repository and the filename wasn't specified, the first one found will be loaded.
                    - A path to a *directory* (for example `./upscaler_model/`) containing a pretrained
                      upscaler checkpoint.
            filename (`str`, *optional*):
                The name of the file in the repo.
            subfolder (`str`, *optional*):
                An optional value corresponding to a folder inside the model repo.

        Returns:
            `UpscaleWithModel`: a new instance wrapping the loaded model.

        Raises:
            TypeError: if the loaded checkpoint is not an `ImageModelDescriptor`.
        """
        model_path = get_model_path(pretrained_model_or_path, filename, subfolder)
        model = ModelLoader().load_from_file(model_path)

        # Validate that it's the correct kind of model. An explicit raise is used instead of `assert`
        # because assertions are removed when Python runs with `-O`.
        if not isinstance(model, ImageModelDescriptor):
            raise TypeError(
                f"Expected the checkpoint at '{model_path}' to load as an ImageModelDescriptor, "
                f"but got {type(model).__name__}."
            )

        return cls(model)

    def __call__(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        tiling: bool = False,
        tile_width: int = 512,
        tile_height: int = 512,
        overlap: int = 8,
        return_type: str = "pil",
    ) -> Union[torch.Tensor, PIL.Image.Image, np.ndarray]:
        r"""
        Upscales the given image, optionally using tiling.

        Args:
            image (Union[PIL.Image.Image, np.ndarray, torch.Tensor]):
                The image to be upscaled. Can be a PIL Image, NumPy array, or PyTorch tensor.
            tiling (bool, optional):
                Whether to use tiling for upscaling. Default is False.
            tile_width (int, optional):
                The width of each tile if tiling is used. Default is 512.
            tile_height (int, optional):
                The height of each tile if tiling is used. Default is 512.
            overlap (int, optional):
                The overlap between tiles if tiling is used. Default is 8.
            return_type (str, optional):
                The type of the returned image. Can be 'pil', 'numpy', or 'tensor'. Default is 'pil'.

        Returns:
            Union[torch.Tensor, PIL.Image.Image, np.ndarray]:
                The upscaled image, in the format specified by `return_type`.
        """
        # Anything that is not already a tensor goes through the mixin's converter first.
        if not isinstance(image, torch.Tensor):
            image = self.convert_image_to_tensor(image)

        # Run on whatever device the model currently lives on.
        image = image.to(self.model.device)

        if tiling:
            upscaled_tensor = tiled_upscale(image, self.model, self.model.scale, tile_width, tile_height, overlap)
        else:
            upscaled_tensor = self.model(image)

        # Only the first element of the post-processed result is returned
        # (presumably the single image of the batch -- confirm against `post_process_image`).
        image = self.post_process_image(upscaled_tensor, return_type)[0]

        return image