Spaces:
Running
Running
import random | |
from enum import IntEnum | |
from io import BytesIO | |
from typing import Any, Optional, overload | |
from httpx import HTTPError | |
from hibiapi.api.sauce.constants import SauceConstants | |
from hibiapi.utils.decorators import enum_auto_doc | |
from hibiapi.utils.exceptions import ClientSideException | |
from hibiapi.utils.net import catch_network_error | |
from hibiapi.utils.routing import BaseEndpoint, BaseHostUrl | |
class UnavailableSourceException(ClientSideException):
    """Client error: the requested source image could not be fetched."""

    code = 422
    detail = "given image is not available to fetch"
class ImageSourceOversizedException(UnavailableSourceException):
    """Client error: the fetched image exceeds ``IMAGE_MAXIMUM_SIZE`` bytes."""

    code = 413
    # Built once at class-creation time, so the message embeds the limit
    # SauceConstants held when this module was imported.
    detail = (
        "given image size is larger than the maximum limit of "
        f"{SauceConstants.IMAGE_MAXIMUM_SIZE} bytes"
    )
class HostUrl(BaseHostUrl):
    """URL type for images the sauce endpoint is allowed to download.

    ``allowed_hosts`` comes from ``SauceConstants.IMAGE_ALLOWED_HOST``;
    presumably ``BaseHostUrl`` rejects URLs whose host is not listed there
    (confirm in ``hibiapi.utils.routing``).
    """

    allowed_hosts = SauceConstants.IMAGE_ALLOWED_HOST
class UploadFileIO(BytesIO):
    """In-memory image payload that doubles as a validated field type.

    ``__get_validators__`` is the pydantic v1 custom-type hook. The framework
    calls it on the class itself, not on an instance, so both hooks must be
    classmethods.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validators to run, in order, for this type."""
        yield cls.validate

    @classmethod
    def validate(cls, v: Any) -> BytesIO:
        """Return ``v`` unchanged if it is a ``BytesIO``.

        Raises:
            ValueError: if ``v`` is not a ``BytesIO`` instance.
        """
        if not isinstance(v, BytesIO):
            raise ValueError(f"Expected UploadFile, received: {type(v)}")
        return v
# NOTE(review): enum_auto_doc is imported by this module; presumably it reads
# the per-member string literals below into the enum's documentation.
# Confirm in hibiapi.utils.decorators.
@enum_auto_doc
class DeduplicateType(IntEnum):
    """Result-deduplication mode, sent to SauceNAO as the ``dedupe`` query value."""

    DISABLED = 0
    """no result deduplicating"""
    IDENTIFIER = 1
    """consolidate search results and deduplicate by item identifier"""
    ALL = 2
    """all implemented deduplicate methods such as by series name"""
class SauceEndpoint(BaseEndpoint, cache_endpoints=False):
    """SauceNAO reverse-image-search endpoint.

    An image is either downloaded from an allowed URL or supplied directly.
    It is then uploaded to ``{base}/search.php`` and the JSON reply is
    returned.
    """

    base = "https://saucenao.com"

    async def fetch(self, host: HostUrl) -> UploadFileIO:
        """Download the image at ``host`` into memory.

        Raises:
            UnavailableSourceException: on any ``httpx.HTTPError``, including
                the error raised by ``raise_for_status`` for a bad status.
            ImageSourceOversizedException: if the body is larger than
                ``SauceConstants.IMAGE_MAXIMUM_SIZE`` bytes.
        """
        try:
            response = await self.client.get(
                url=host,
                headers=SauceConstants.IMAGE_HEADERS,
                timeout=SauceConstants.IMAGE_TIMEOUT,
            )
            response.raise_for_status()
        except HTTPError as e:
            raise UnavailableSourceException(detail=str(e)) from e
        # NOTE(review): the whole body is buffered before this size check;
        # streaming the download would let oversized images be rejected early.
        if len(response.content) > SauceConstants.IMAGE_MAXIMUM_SIZE:
            raise ImageSourceOversizedException
        return UploadFileIO(response.content)

    # NOTE(review): catch_network_error is imported by this module; presumably
    # it converts httpx failures into API errors. Confirm in hibiapi.utils.net.
    @catch_network_error
    async def request(
        self, *, file: UploadFileIO, params: dict[str, Any]
    ) -> dict[str, Any]:
        """POST ``file`` to ``search.php`` and return the decoded JSON body.

        ``params`` is merged with a randomly chosen API key and
        ``output_type=2``; those two keys override same-named entries in
        ``params``. Only 5xx responses raise here. Other statuses are still
        decoded as JSON (presumably because SauceNAO reports client-side
        errors in the body).
        """
        response = await self.client.post(
            url=self._join(
                self.base,
                "search.php",
                params={
                    **params,
                    # Spread requests across the configured keys.
                    "api_key": random.choice(SauceConstants.API_KEY),
                    "output_type": 2,
                },
            ),
            files={"file": file},
        )
        if response.status_code >= 500:
            response.raise_for_status()
        return response.json()

    @overload
    async def search(
        self,
        *,
        url: HostUrl,
        size: int = 30,
        deduplicate: DeduplicateType = DeduplicateType.ALL,
        database: Optional[int] = None,
        enabled_mask: Optional[int] = None,
        disabled_mask: Optional[int] = None,
    ) -> dict[str, Any]:
        ...

    @overload
    async def search(
        self,
        *,
        file: UploadFileIO,
        size: int = 30,
        deduplicate: DeduplicateType = DeduplicateType.ALL,
        database: Optional[int] = None,
        enabled_mask: Optional[int] = None,
        disabled_mask: Optional[int] = None,
    ) -> dict[str, Any]:
        ...

    async def search(
        self,
        *,
        url: Optional[HostUrl] = None,
        file: Optional[UploadFileIO] = None,
        size: int = 30,
        deduplicate: DeduplicateType = DeduplicateType.ALL,
        database: Optional[int] = None,
        enabled_mask: Optional[int] = None,
        disabled_mask: Optional[int] = None,
    ) -> dict[str, Any]:
        """Search SauceNAO for an image given by ``url`` or ``file``.

        When ``url`` is given, the image is downloaded with ``fetch`` and any
        ``file`` argument is ignored.

        Args:
            url: remote image to download and search.
            file: image bytes to search directly.
            size: sent as ``numres``.
            deduplicate: sent as ``dedupe``.
            database: sent as ``db``.
            enabled_mask: sent as ``dbmask``.
            disabled_mask: sent as ``dbmaski``.

        Raises:
            TypeError: if neither ``url`` nor ``file`` is provided.
        """
        if url is not None:
            file = await self.fetch(url)
        # An explicit check instead of ``assert``, which is stripped under -O.
        if file is None:
            raise TypeError("search() requires either 'url' or 'file'")
        return await self.request(
            file=file,
            params={
                "dbmask": enabled_mask,
                "dbmaski": disabled_mask,
                "db": database,
                "numres": size,
                # Send a plain int: before Python 3.11, str() of an IntEnum
                # member is "DeduplicateType.ALL" rather than "2".
                "dedupe": int(deduplicate),
            },
        )