vivaceailab committed on
Commit
bfaf167
·
verified ·
1 Parent(s): 9974d04

Create app.py

Files changed (1)
  1. app.py +135 -0
app.py ADDED
@@ -0,0 +1,135 @@
+ import sys, types, importlib.machinery, importlib
+
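+ # Register a stub 'flash_attn' module so that any later `import flash_attn`
+ # (as in some remote model code) succeeds even though the package is not installed here.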
+ spec = importlib.machinery.ModuleSpec('flash_attn', loader=None)
+ mod = types.ModuleType('flash_attn')
+ mod.__spec__ = spec
+ sys.modules['flash_attn'] = mod
+
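+ # Newer huggingface_hub releases dropped cached_download; alias it to hf_hub_download
+ # for any dependency that still imports it.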
+ import huggingface_hub as _hf_hub
+ _hf_hub.cached_download = _hf_hub.hf_hub_download
+
+ import gradio as gr
+ import torch
+ import random
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from diffusers import DiffusionPipeline
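+ # Fall back to the plain Euler scheduler when this diffusers build has no FlowMatchEulerDiscreteScheduler.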
+ try:
+     from diffusers import FlowMatchEulerDiscreteScheduler
+ except ImportError:
+     from diffusers import EulerDiscreteScheduler as FlowMatchEulerDiscreteScheduler
+
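+ # Force transformers' availability checks to report flash-attn as missing so models fall back to SDPA.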
+ import transformers.utils.import_utils as _import_utils
+ from transformers.utils import is_flash_attn_2_available
+ _import_utils._is_package_available = lambda pkg: False
+ _import_utils.is_flash_attn_2_available = lambda: False
+
+ hf_utils = importlib.import_module('transformers.utils')
+ hf_utils.is_flash_attn_2_available = lambda *a, **k: False
+ hf_utils.is_flash_attn_greater_or_equal_2_10 = lambda *a, **k: False
+
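+ # Provide no-op SDPA mask helpers when the installed transformers version does not export them.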
+ mask_utils = importlib.import_module("transformers.modeling_attn_mask_utils")
+ for fn in ("_prepare_4d_attention_mask_for_sdpa", "_prepare_4d_causal_attention_mask_for_sdpa"):
+     if not hasattr(mask_utils, fn):
+         setattr(mask_utils, fn, lambda *a, **k: None)
+
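+ # Patch PretrainedConfig so every model config reports "sdpa" as its attention implementation.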
+ cfg_mod = importlib.import_module("transformers.configuration_utils")
+ _PrC = cfg_mod.PretrainedConfig
+ _orig_getattr = _PrC.__getattribute__
+ def _getattr(self, name):
+     if name == "_attn_implementation":
+         return "sdpa"
+     return _orig_getattr(self, name)
+ _PrC.__getattribute__ = _getattr
+
+ REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"
+
+ # Load Florence-2
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ florence_model = AutoModelForCausalLM.from_pretrained(
+     'microsoft/Florence-2-base', revision=REVISION, trust_remote_code=True,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+ )
+ # Keep the captioning model on the same device as the inputs prepared in generate_prompt()
+ florence_model.to(device)
+ florence_model.eval()
+ florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', revision=REVISION, trust_remote_code=True)
+
+ # Load Stable Diffusion 3.5 TurboX
+ model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"
+ pipe = DiffusionPipeline.from_pretrained(
+     model_repo,
+     trust_remote_code=True,
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+ )
+ pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_repo, subfolder="scheduler", shift=5)
+ pipe = pipe.to(device)
+
+ MAX_SEED = 2**31 - 1
+
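+ # Note: despite its name, the helper below does not translate anything; it only wraps
+ # the English caption in a fixed cartoon-style template.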
+ def pseudo_translate_to_korean_style(en_prompt: str) -> str:
+     # Apply the style without translating
+     return f"Cartoon styled {en_prompt} handsome or pretty people"
+
+ def generate_prompt(image):
+     """Image → English caption → cartoon-style prompt."""
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image)
+
+     # Match the inputs to the captioning model's device and dtype
+     inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device, florence_model.dtype)
+     generated_ids = florence_model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=512,
+         num_beams=3
+     )
+     generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = florence_processor.post_process_generation(
+         generated_text,
+         task="<MORE_DETAILED_CAPTION>",
+         image_size=(image.width, image.height)
+     )
+     prompt_en = parsed_answer["<MORE_DETAILED_CAPTION>"]
+
+     # Apply the style without a translator
+     cartoon_prompt = pseudo_translate_to_korean_style(prompt_en)
+     return cartoon_prompt
+
+ def generate_image(prompt, seed=42, randomize_seed=False):
+     """Text prompt → generated image."""
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+     generator = torch.Generator().manual_seed(seed)
+     # Few-step TurboX settings: low guidance scale, 8 inference steps
+     image = pipe(
+         prompt=prompt,
+         negative_prompt="distorted hands, blurry, strange face",
+         guidance_scale=1.5,
+         num_inference_steps=8,
+         width=768,
+         height=768,
+         generator=generator
+     ).images[0]
+     return image, seed
+
+ # Gradio UI layout
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🖼 Image → Caption → Automatic Cartoon Image Generator")
+
+     gr.Markdown("**📌 How to use**\n"
+                 "- Upload an image on the left.\n"
+                 "- The AI writes an English caption and internally rewrites it into a styled prompt.\n"
+                 "- The generated image appears on the right.")
+
+     with gr.Row():
+         with gr.Column():
+             input_img = gr.Image(label="🎨 Upload source image")
+             run_button = gr.Button("✨ Generate")
+
+         with gr.Column():
+             prompt_out = gr.Textbox(label="📝 Styled prompt", lines=3, show_copy_button=True)
+             output_img = gr.Image(label="🎉 Generated image")
+
+     def full_process(img):
+         if img is None:
+             raise gr.Error("Please upload an image first.")
+         prompt = generate_prompt(img)
+         image, seed = generate_image(prompt, randomize_seed=True)
+         return prompt, image
+
+     run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])
+
+ demo.launch()