Update app.py
Browse files
app.py
CHANGED
@@ -49,9 +49,18 @@ def setup_environment():
|
|
49 |
raise ValueError("HF_TOKEN not found in environment variables")
|
50 |
login(token=HF_TOKEN)
|
51 |
return HF_TOKEN
|
|
|
52 |
@spaces.GPU()
|
53 |
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
|
54 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
# 파이프라인 초기화
|
56 |
pipe = DiffusionPipeline.from_pretrained(
|
57 |
BASE_MODEL,
|
@@ -73,13 +82,13 @@ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512,
|
|
73 |
# 이미지 생성
|
74 |
with torch.inference_mode():
|
75 |
result = pipe(
|
76 |
-
prompt=f"{
|
77 |
num_inference_steps=steps,
|
78 |
guidance_scale=cfg_scale,
|
79 |
width=width,
|
80 |
height=height,
|
81 |
generator=generator,
|
82 |
-
|
83 |
).images[0]
|
84 |
|
85 |
# 메모리 정리
|
@@ -92,6 +101,13 @@ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512,
|
|
92 |
clear_memory()
|
93 |
raise gr.Error(f"Generation failed: {str(e)}")
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
# 전역 변수 초기화
|
97 |
fashion_pipe = None
|
|
|
49 |
raise ValueError("HF_TOKEN not found in environment variables")
|
50 |
login(token=HF_TOKEN)
|
51 |
return HF_TOKEN
|
52 |
+
|
53 |
@spaces.GPU()
|
54 |
def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
|
55 |
try:
|
56 |
+
# 한글 처리
|
57 |
+
if contains_korean(prompt):
|
58 |
+
translator = get_translator()
|
59 |
+
translated = translator(prompt)[0]['translation_text']
|
60 |
+
actual_prompt = translated
|
61 |
+
else:
|
62 |
+
actual_prompt = prompt
|
63 |
+
|
64 |
# 파이프라인 초기화
|
65 |
pipe = DiffusionPipeline.from_pretrained(
|
66 |
BASE_MODEL,
|
|
|
82 |
# 이미지 생성
|
83 |
with torch.inference_mode():
|
84 |
result = pipe(
|
85 |
+
prompt=f"{actual_prompt} {trigger_word}",
|
86 |
num_inference_steps=steps,
|
87 |
guidance_scale=cfg_scale,
|
88 |
width=width,
|
89 |
height=height,
|
90 |
generator=generator,
|
91 |
+
joint_attention_kwargs={"scale": lora_scale},
|
92 |
).images[0]
|
93 |
|
94 |
# 메모리 정리
|
|
|
101 |
clear_memory()
|
102 |
raise gr.Error(f"Generation failed: {str(e)}")
|
103 |
|
104 |
+
def contains_korean(text):
|
105 |
+
return any(ord('가') <= ord(char) <= ord('힣') for char in text)
|
106 |
+
|
107 |
+
@spaces.GPU()
|
108 |
+
def get_translator():
|
109 |
+
return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda")
|
110 |
+
|
111 |
|
112 |
# 전역 변수 초기화
|
113 |
fashion_pipe = None
|