ginipick committed on
Commit
9fbd4b4
·
verified ·
1 Parent(s): 12cd271

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -49,9 +49,18 @@ def setup_environment():
49
  raise ValueError("HF_TOKEN not found in environment variables")
50
  login(token=HF_TOKEN)
51
  return HF_TOKEN
 
52
  @spaces.GPU()
53
  def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
54
  try:
 
 
 
 
 
 
 
 
55
  # 파이프라인 초기화
56
  pipe = DiffusionPipeline.from_pretrained(
57
  BASE_MODEL,
@@ -73,13 +82,13 @@ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512,
73
  # 이미지 생성
74
  with torch.inference_mode():
75
  result = pipe(
76
- prompt=f"{prompt} {trigger_word}",
77
  num_inference_steps=steps,
78
  guidance_scale=cfg_scale,
79
  width=width,
80
  height=height,
81
  generator=generator,
82
- cross_attention_kwargs={"scale": lora_scale},
83
  ).images[0]
84
 
85
  # 메모리 정리
@@ -92,6 +101,13 @@ def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512,
92
  clear_memory()
93
  raise gr.Error(f"Generation failed: {str(e)}")
94
 
 
 
 
 
 
 
 
95
 
96
  # 전역 변수 초기화
97
  fashion_pipe = None
 
49
  raise ValueError("HF_TOKEN not found in environment variables")
50
  login(token=HF_TOKEN)
51
  return HF_TOKEN
52
+
53
  @spaces.GPU()
54
  def generate_image(prompt, mode, cfg_scale=7.0, steps=30, seed=None, width=512, height=768, lora_scale=0.85):
55
  try:
56
+ # 한글 처리
57
+ if contains_korean(prompt):
58
+ translator = get_translator()
59
+ translated = translator(prompt)[0]['translation_text']
60
+ actual_prompt = translated
61
+ else:
62
+ actual_prompt = prompt
63
+
64
  # 파이프라인 초기화
65
  pipe = DiffusionPipeline.from_pretrained(
66
  BASE_MODEL,
 
82
  # 이미지 생성
83
  with torch.inference_mode():
84
  result = pipe(
85
+ prompt=f"{actual_prompt} {trigger_word}",
86
  num_inference_steps=steps,
87
  guidance_scale=cfg_scale,
88
  width=width,
89
  height=height,
90
  generator=generator,
91
+ joint_attention_kwargs={"scale": lora_scale},
92
  ).images[0]
93
 
94
  # 메모리 정리
 
101
  clear_memory()
102
  raise gr.Error(f"Generation failed: {str(e)}")
103
 
104
+ def contains_korean(text):
105
+ return any(ord('가') <= ord(char) <= ord('힣') for char in text)
106
+
107
+ @spaces.GPU()
108
+ def get_translator():
109
+ return pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda")
110
+
111
 
112
  # 전역 변수 초기화
113
  fashion_pipe = None