# Page header from the hosted "Spaces" view: status "Runtime error";
# file size 5,222 bytes; commit 1ba3df3.
import os
import argparse
from pathlib import Path
from typing import Optional, Union, Tuple, List
import subprocess
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf, DictConfig
from inference import InferenceServicer
# Paths can be overridden through environment variables.
PATH_DOCS = os.getenv("PATH_DOCS", default="docs/ml-font-style-transfer.md")
MODEL_CONFIG = os.getenv("MODEL_CONFIG", default="config/models/google-font.yaml")
# Optional download URLs; when unset, the files must already exist locally.
MODEL_CHECKPOINT_PATH = os.getenv("MODEL_CHECKPOINT_PATH", default=None)
NOTO_SANS_ZIP_PATH = os.getenv("NOTO_SANS_ZIP_PATH", default=None)

# Local destinations that the app loads from.
LOCAL_CHECKPOINT_PATH = "checkpoint/checkpoint.ckpt"
LOCAL_NOTO_ZIP_PATH = "data/NotoSans.zip"
# Directory the NotoSans archive is expected to provide ("data/NotoSans").
LOCAL_NOTO_DIR = str(Path(LOCAL_NOTO_ZIP_PATH).with_suffix(""))


def _download(url: str, dest: str) -> None:
    """Download ``url`` to the local file ``dest`` using ``wget``.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
    """
    # `wget -O` does not create missing parent directories.
    Path(dest).parent.mkdir(parents=True, exist_ok=True)
    # Pass an argument list (no shell) so the URL taken from the environment is
    # never interpreted by a shell. check=True matters: on failure `wget -O`
    # still leaves an (empty) file behind, which the existence checks below
    # would otherwise accept as a successful download.
    # NOTE(review): --no-check-certificate disables TLS verification; kept for
    # parity with the original command, consider removing it.
    subprocess.run(["wget", "--no-check-certificate", "-O", dest, url], check=True)


if MODEL_CHECKPOINT_PATH is not None:
    _download(MODEL_CHECKPOINT_PATH, LOCAL_CHECKPOINT_PATH)

if NOTO_SANS_ZIP_PATH is not None:
    _download(NOTO_SANS_ZIP_PATH, LOCAL_NOTO_ZIP_PATH)
    # -o: overwrite already-extracted files instead of prompting interactively.
    _unzip = subprocess.run(
        ["unzip", "-o", LOCAL_NOTO_ZIP_PATH, "-d", str(Path(LOCAL_NOTO_ZIP_PATH).parent)]
    )
    # unzip exits with 1 for warnings where extraction still completed; >1 is a failure.
    if _unzip.returncode > 1:
        raise RuntimeError(
            f"unzip of {LOCAL_NOTO_ZIP_PATH} failed with exit code {_unzip.returncode}"
        )

# Fail fast at import time with a clear message. An explicit raise is used
# because `assert` is stripped when Python runs with -O.
for _required_path in (LOCAL_CHECKPOINT_PATH, LOCAL_NOTO_DIR):
    if not Path(_required_path).exists():
        raise FileNotFoundError(f"Required file or directory not found: {_required_path}")
# Font files offered in the style-font dropdown. After sorting, the first
# entry is used as the dropdown's default value.
_EXAMPLE_FONT_FILES = (
    "example_fonts/BalooDa2-Bold.ttf",
    "example_fonts/BalooDa2-Regular.ttf",
    "example_fonts/Lalezar-Regular.ttf",
    "example_fonts/MaShanZheng-Regular.ttf",
)
EXAMPLE_FONTS = sorted(_EXAMPLE_FONT_FILES)
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command-line options for the font style transfer demo.

    Unrecognised arguments are silently ignored (``parse_known_args``).

    Args:
        argv: Arguments to parse. ``None`` (the default) reads ``sys.argv``.

    Returns:
        Namespace with ``docs`` (Path), ``config`` (Path), ``local`` (bool)
        and ``port`` (int). String defaults are converted by their ``type``,
        so ``docs`` and ``config`` are ``Path`` objects even when defaulted.
    """
    # The previous description was a copy-paste leftover from an unrelated tool.
    parser = argparse.ArgumentParser(description="Multilingual font style transfer demo (Gradio)")
    # -------- User arguments ----------------------------------------
    parser.add_argument(
        '--docs', type=Path, default=PATH_DOCS,
        help="Docs string file")
    parser.add_argument(
        '--config', type=Path, default=MODEL_CONFIG,
        help="Config for model")
    parser.add_argument(
        '--local', action='store_true',
        help="Whether to run in local environment or not")
    parser.add_argument(
        '--port', type=int, default=50003,
        help="Service port (only applicable when running on local server)")
    args, _ = parser.parse_known_args(argv)
    return args
class InferenceServiceResolver(InferenceServicer):
    """Gradio-facing adapter around ``InferenceServicer``.

    Flattens the servicer's output into the list of images the UI displays.
    Any failure is re-raised as ``gr.Error`` so that Gradio shows it to the user.
    """

    def __init__(self, hp, checkpoint_path, content_image_dir, imsize=64, gpu_id='0') -> None:
        # All arguments are forwarded positionally to the base class.
        super().__init__(hp, checkpoint_path, content_image_dir, imsize, gpu_id)

    def generate(self, content_char: str, style_font: Union[str, Path]) -> List[Image.Image]:
        """Run style transfer for one character and return the images to display.

        Args:
            content_char: Character to render (the UI asks for a single character).
            style_font: Path of the font file whose style is transferred.

        Returns:
            ``[content_image, *style_images, result]``. The order matches the
            output components wired up in ``launch_gradio``, which expects seven
            images, so presumably ``style_images`` has five entries (TODO confirm).

        Raises:
            gr.Error: Wraps any exception raised while running inference or
                building the result list.
        """
        try:
            content_image, style_images, result = self.inference(content_char=content_char, style_font=style_font)
            return [content_image, *style_images, result]
        except Exception as e:
            # Some exceptions (e.g. a bare AssertionError) have an empty message.
            # Fall back to the type name so the UI never shows a blank error.
            # Chaining with `from e` keeps the original traceback in server logs.
            raise gr.Error(str(e) or type(e).__name__) from e
def launch_gradio(docs_path: Path, hp: DictConfig, checkpoint_path: Path, content_image_dir: Path, is_local: bool, port: Optional[int] = None):
    """Build the font style transfer UI and launch the Gradio server.

    Args:
        docs_path: Markdown file rendered at the top of the page.
        hp: Model config, forwarded to the inference servicer.
        checkpoint_path: Model checkpoint passed to the servicer.
        content_image_dir: Directory passed to the servicer as its content image
            source (the NotoSans extraction directory when called from ``__main__``).
        is_local: If True, bind to ``0.0.0.0`` on ``port``. Otherwise use
            Gradio's default launch settings.
        port: Server port; only used when ``is_local`` is True.
    """
    # NOTE(review): gpu_id=None presumably means CPU inference — confirm in InferenceServicer.
    servicer = InferenceServiceResolver(hp, checkpoint_path, content_image_dir, gpu_id=None)
    # Number of style reference images shown. Must match the number returned
    # by servicer.generate between the content image and the result.
    num_style_images = 5
    with gr.Blocks(title="Multilingual Font Style Transfer (training with Google Fonts)") as demo:
        # Explicit UTF-8: the page contains non-ASCII text (e.g. the '간' / 'ជ'
        # examples), and the locale default encoding may not be UTF-8.
        gr.Markdown(docs_path.read_text(encoding="utf-8"))
        with gr.Row(equal_height=True):
            character_input = gr.Textbox(max_lines=1, value="7", info="Only single character is acceptable (e.g. '간', '7', or 'ជ')")
            style_font = gr.Dropdown(label="Select example font: ", choices=EXAMPLE_FONTS, value=EXAMPLE_FONTS[0])
            run_button = gr.Button(value="Generate", variant='primary')

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Content character</h3></center>")
                    content_char = gr.Image(label="Content character", show_label=False)
            with gr.Column(scale=5):
                with gr.Group():
                    gr.Markdown("<center><h3>Style font images</h3></center>")
                    with gr.Row(equal_height=True):
                        style_chars = [
                            gr.Image(label=f"Style #{index}", show_label=False)
                            for index in range(1, num_style_images + 1)
                        ]
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("<center><h3>Generated font image</h3></center>")
                    generated_font = gr.Image(label="Generated font image", show_label=False)

        # Order must match the list returned by InferenceServiceResolver.generate.
        outputs = [content_char, *style_chars, generated_font]
        run_inputs = [character_input, style_font]
        run_button.click(servicer.generate, inputs=run_inputs, outputs=outputs)

    if is_local:
        demo.launch(server_name="0.0.0.0", server_port=port)
    else:
        demo.launch()
if __name__ == "__main__":
    args = parse_args()
    hp = OmegaConf.load(args.config)
    checkpoint_path = Path(LOCAL_CHECKPOINT_PATH)
    # The NotoSans archive is unzipped into its own parent directory, so the
    # content image directory is the zip path without its suffix ("data/NotoSans").
    content_image_dir = Path(LOCAL_NOTO_ZIP_PATH).with_suffix("")
    launch_gradio(args.docs, hp, checkpoint_path, content_image_dir, args.local, args.port)