Spaces:
Sleeping
Sleeping
File size: 5,625 Bytes
346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a 1d20b52 346533a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import copy
from argparse import ArgumentError, ArgumentParser
from collections.abc import Awaitable
from contextvars import ContextVar
from dataclasses import dataclass, field
from io import BytesIO
from pathlib import Path
from typing import (
IO,
Any,
Callable,
Literal,
Optional,
TypeVar,
Union,
cast,
)
from pil_utils import BuildImage
from pydantic import BaseModel, ValidationError
from .exception import (
ArgModelMismatch,
ArgParserExit,
ImageNumberMismatch,
OpenImageFailed,
ParserExit,
TextNumberMismatch,
TextOrNameNotEnough,
)
from .utils import is_coroutine_callable, random_image, random_text, run_sync
class UserInfo(BaseModel):
    """Name and gender of a single user.

    Used as the element type of ``MemeArgsModel.user_infos``.
    NOTE(review): presumably describes the people whose images/names appear in
    a meme — confirm against callers.
    """

    # User's name; defaults to an empty string.
    name: str = ""
    # Restricted to "male", "female" or "unknown"; defaults to "unknown".
    gender: Literal["male", "female", "unknown"] = "unknown"
class MemeArgsModel(BaseModel):
    """Base pydantic model for a meme's keyword arguments.

    ``ArgsModel`` is bound to this class, and ``MemeArgsType.model`` is typed
    as a (sub)class of it. ``Meme.__call__`` validates against this class
    directly when a meme defines no ``args_type``.
    """

    # Per-user information entries (see UserInfo).
    # NOTE(review): the shared ``[]`` default relies on pydantic copying field
    # defaults per instance — confirm for the pydantic version in use.
    user_infos: list[UserInfo] = []
# Type variable for a meme's argument model (MemeArgsModel or a subclass).
ArgsModel = TypeVar("ArgsModel", bound=MemeArgsModel)
# Signature of a meme generator: takes (images, texts, args) and returns a
# BytesIO, either directly or as an awaitable. Meme.__call__ supplies these
# as the keyword arguments ``images``, ``texts`` and ``args``.
MemeFunction = Union[
    Callable[[list[BuildImage], list[str], ArgsModel], BytesIO],
    Callable[[list[BuildImage], list[str], ArgsModel], Awaitable[BytesIO]],
]
# Capture buffer for argparse output. While it holds a string (Meme.parse_args
# sets it to "" for the duration of a parse), MemeArgsParser appends printed
# text to it instead of writing to a stream; when unset, output is printed
# normally.
parser_message: ContextVar[str] = ContextVar("parser_message")
class MemeArgsParser(ArgumentParser):
    """Argument parser for `shell_like` commands that does not exit the program
    when parsing fails.

    Usage:
        Identical to `argparse.ArgumentParser`; see the documentation at
        [argparse](https://docs.python.org/3/library/argparse.html).

    Differences from the stock parser:
    * Text argparse would print (usage, help, error messages) is appended to
      the `parser_message` context variable while that variable is set, and
      printed normally otherwise.
    * `exit` raises `ParserExit`, carrying the captured text, instead of
      terminating the interpreter.
    """

    def _print_message(self, message: str, file: Optional[IO[str]] = None):
        captured = parser_message.get(None)
        if captured is None:
            # No capture in progress: defer to argparse's own printing.
            super()._print_message(message, file)
            return
        parser_message.set(captured + message)

    def exit(self, status: int = 0, message: Optional[str] = None):
        # Route the final message through the capture path first, so the
        # exception below carries everything argparse tried to print.
        if message:
            self._print_message(message)
        collected = parser_message.get(None)
        raise ParserExit(status=status, error_message=collected)
@dataclass
class MemeArgsType:
    """Extra-argument specification for a meme: a parser plus a pydantic model."""

    # Command-line parser. Meme.parse_args deep-copies it before adding the
    # positional ``texts`` argument, so this shared instance is not mutated.
    parser: MemeArgsParser
    # Model class that Meme.__call__ uses (via ``parse_obj``) to validate the
    # raw ``args`` mapping.
    model: type[MemeArgsModel]
    # Model instances; not read by Meme's methods.
    # NOTE(review): presumably example argument sets (e.g. for previews) — confirm.
    instances: list[MemeArgsModel] = field(default_factory=list)
@dataclass
class MemeParamsType:
    """Input-count constraints, default texts and extra-argument spec for a meme."""

    # Inclusive bounds on the number of images Meme.__call__ accepts.
    min_images: int = 0
    max_images: int = 0
    # Inclusive bounds on the number of texts Meme.__call__ accepts.
    min_texts: int = 0
    max_texts: int = 0
    # Texts used by Meme.generate_preview when their count lies within
    # [min_texts, max_texts]; otherwise random texts are generated instead.
    default_texts: list[str] = field(default_factory=list)
    # Optional parser/model for extra arguments. When None, Meme falls back to
    # MemeArgsModel for validation and a bare MemeArgsParser for parsing.
    args_type: Optional[MemeArgsType] = None
@dataclass
class Meme:
    """A meme generator together with the metadata needed to invoke it.

    Attributes:
        key: Identifier included in the errors raised for this meme.
        function: The generator. It is called with the keyword arguments
            ``images``, ``texts`` and ``args`` and may be a plain function or
            a coroutine function.
        params_type: Image/text count limits, default texts and the optional
            extra-argument parser/model.
        keywords: Keyword list (not read by the methods below).
        patterns: Pattern list (not read by the methods below).
    """

    key: str
    function: MemeFunction
    params_type: MemeParamsType
    keywords: list[str] = field(default_factory=list)
    patterns: list[str] = field(default_factory=list)

    async def __call__(
        self,
        *,
        images: Optional[
            Union[list[str], list[Path], list[bytes], list[BytesIO]]
        ] = None,
        texts: Optional[list[str]] = None,
        args: Optional[dict[str, Any]] = None,
    ) -> BytesIO:
        """Validate the inputs and run the meme generator.

        Args:
            images: Image sources. ``bytes`` items are wrapped in ``BytesIO``;
                every item is then opened with ``BuildImage.open``. Defaults
                to no images.
            texts: Texts passed to the generator. Defaults to no texts.
            args: Raw argument mapping, validated against the meme's args
                model (``MemeArgsModel`` when the meme has no ``args_type``).
                Defaults to an empty mapping.

        Returns:
            The ``BytesIO`` produced by ``self.function``.

        Raises:
            ImageNumberMismatch: ``len(images)`` is outside
                ``[min_images, max_images]``.
            TextNumberMismatch: ``len(texts)`` is outside
                ``[min_texts, max_texts]``.
            ArgModelMismatch: ``args`` fails model validation.
            OpenImageFailed: an image could not be opened.
            Anything raised by ``self.function`` propagates unchanged.
        """
        # Fresh containers per call: a shared mutable default would be handed
        # to the user's generator, which could mutate it across calls.
        images = [] if images is None else images
        texts = [] if texts is None else texts
        args = {} if args is None else args

        params = self.params_type
        if not (params.min_images <= len(images) <= params.max_images):
            raise ImageNumberMismatch(
                self.key, params.min_images, params.max_images
            )
        if not (params.min_texts <= len(texts) <= params.max_texts):
            raise TextNumberMismatch(self.key, params.min_texts, params.max_texts)

        args_model = params.args_type.model if params.args_type else MemeArgsModel
        try:
            model = args_model.parse_obj(args)
        except ValidationError as e:
            raise ArgModelMismatch(self.key, str(e)) from e

        imgs: list[BuildImage] = []
        try:
            for image in images:
                if isinstance(image, bytes):
                    image = BytesIO(image)
                imgs.append(BuildImage.open(image))  # type: ignore
        except Exception as e:
            # Any failure while decoding is reported uniformly to callers.
            raise OpenImageFailed(str(e)) from e

        values = {"images": imgs, "texts": texts, "args": model}
        if is_coroutine_callable(self.function):
            return await cast(Callable[..., Awaitable[BytesIO]], self.function)(
                **values
            )
        # Synchronous generators are run through run_sync so they can be awaited.
        return await run_sync(cast(Callable[..., BytesIO], self.function))(**values)

    def parse_args(self, args: Optional[list[str]] = None) -> dict[str, Any]:
        """Parse shell-style arguments into a dict.

        Remaining positional arguments are collected under the ``"texts"``
        key. The meme's own parser is deep-copied, so adding ``texts`` does
        not mutate the shared parser.

        Args:
            args: Argument strings to parse. Defaults to an empty list.

        Raises:
            ArgParserExit: whenever argparse errors or would exit (for
                example on a parse error or ``-h``), carrying the captured
                parser output.
        """
        # argparse reads sys.argv when given None, so normalise explicitly to
        # preserve the original "parse nothing" default.
        argv = [] if args is None else args
        parser = (
            copy.deepcopy(self.params_type.args_type.parser)
            if self.params_type.args_type
            else MemeArgsParser()
        )
        parser.add_argument("texts", nargs="*", default=[])
        # Capture argparse output for the duration of this parse only.
        token = parser_message.set("")
        try:
            return vars(parser.parse_args(argv))
        except ArgumentError as e:
            raise ArgParserExit(self.key, str(e)) from e
        except ParserExit as e:
            raise ArgParserExit(self.key, e.error_message) from e
        finally:
            parser_message.reset(token)

    async def generate_preview(
        self, *, args: Optional[dict[str, Any]] = None
    ) -> BytesIO:
        """Render a preview using placeholder inputs.

        Uses ``min_images`` random images. It uses the configured
        ``default_texts`` when their count is within bounds, and otherwise
        ``min_texts`` random texts. If the generator raises
        ``TextOrNameNotEnough``, one more random text is appended and the
        call is retried. The retries stop once the text count exceeds
        ``max_texts``, because ``__call__`` then raises
        ``TextNumberMismatch``.

        Args:
            args: Raw argument mapping forwarded to ``__call__``. Defaults to
                an empty mapping.
        """
        params = self.params_type
        images = [random_image() for _ in range(params.min_images)]
        if params.min_texts <= len(params.default_texts) <= params.max_texts:
            # Copy so retries below never mutate the configured defaults.
            texts = params.default_texts.copy()
        else:
            texts = [random_text() for _ in range(params.min_texts)]

        # Iterative retry (instead of recursion) so a large max_texts cannot
        # exhaust the recursion limit.
        while True:
            try:
                return await self(images=images, texts=texts, args=args)
            except TextOrNameNotEnough:
                texts.append(random_text())
|