Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -10,10 +10,17 @@ import numpy as np
 from diffusers import DiffusionPipeline
 from transformers import pipeline as hf_pipeline
 
-# ----------------------
-
-
+# ---------------------- GPU setup for the ZeroGPU environment ----------------------
+# Assume the ZERO_GPU environment variable is set in a Hugging Face Spaces ZeroGPU environment
+if os.getenv("ZERO_GPU"):
+    device = "cuda"
+    torch.cuda.set_device(0)
+else:
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+dtype = torch.bfloat16 if device == "cuda" else torch.float32
 
+# ---------------------- Image generation settings ----------------------
 # Load the Korean-English translation model (uses CPU or GPU depending on the device)
 translator = hf_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device=0 if device=="cuda" else -1)
 
@@ -79,7 +86,6 @@ logger = logging.getLogger("idea_generator")
 GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 genai.configure(api_key=GEMINI_API_KEY)
 
-# Helper that picks just one of the two options in a transformation string containing a slash ("/")
 def choose_alternative(transformation):
     if "/" not in transformation:
         return transformation
@@ -101,7 +107,7 @@ def choose_alternative(transformation):
     else:
         return random.choice([left, right])
 
-# Categories for creative model/concept/shape transformation ideas
+# Categories for creative model/concept/shape transformation ideas (15 in total)
 physical_transformation_categories = {
     "공간 이동": [
         "앞/뒤 이동", "좌/우 이동", "위/아래 이동", "세로축 회전 (고개 끄덕임)",
@@ -209,7 +215,7 @@ physical_transformation_categories = {
         "음향 반사/흡수", "음향 도플러 효과", "음파 간섭", "음향 공진",
         "진동 패턴 변화", "타악 효과", "음향 피드백", "음향 차폐/증폭",
         "소리 지향성", "음향 왜곡", "비트 생성", "하모닉스 생성", "주파수 변조",
-        "음향 충격파", "음향 필터링"
+        "음향 충격파", "음향 필터링"
     ],
 
     "생물학적 변화": [
@@ -244,7 +250,7 @@ physical_transformation_categories = {
 }
 
 ##############################################################################
-# Gemini API call function (
+# Gemini API call function (model: gemini-2.0-flash-thinking-exp-01-21)
 ##############################################################################
 def query_gemini_api(prompt):
     try:
@@ -253,15 +259,13 @@ def query_gemini_api(prompt):
         try:
             if hasattr(response, 'text'):
                 return response.text
-
             if hasattr(response, 'candidates') and response.candidates:
-
-
-
-
-                if
-
-                    return content.parts[0].text
+                candidate = response.candidates[0]
+                if hasattr(candidate, 'content'):
+                    content = candidate.content
+                    if hasattr(content, 'parts') and content.parts:
+                        if len(content.parts) > 0:
+                            return content.parts[0].text
             if hasattr(response, 'parts') and response.parts:
                 if len(response.parts) > 0:
                     return response.parts[0].text
@@ -290,48 +294,42 @@ def enhance_with_llm(base_description, obj_name, category):
     return query_gemini_api(prompt)
 
 ##############################################################################
-#
+# Creative transformation idea generation functions targeting only the selected category
 ##############################################################################
-def
-
-
-
-
-
-    return
+def generate_single_object_transformation_for_category(obj, selected_category):
+    transformations = physical_transformation_categories.get(selected_category)
+    if not transformations:
+        return {}
+    transformation = choose_alternative(random.choice(transformations))
+    base_description = f"{obj}이(가) {transformation} 현상을 보인다"
+    return {selected_category: {"base": base_description, "enhanced": None}}
 
-
-
-
-
-
-
-
-
-
-
-
-        base_description = template.format(obj1=obj1, obj2=obj2, change=transformation)
-        results[category] = {"base": base_description, "enhanced": None}
-    return results
+def generate_two_objects_interaction_for_category(obj1, obj2, selected_category):
+    transformations = physical_transformation_categories.get(selected_category)
+    if not transformations:
+        return {}
+    transformation = choose_alternative(random.choice(transformations))
+    template = random.choice([
+        "{obj1}이(가) {obj2}와 결합하여 {change}가 발생한다",
+        "{obj1}과(와) {obj2}이(가) 충돌하면서 {change}가 일어났다"
+    ])
+    base_description = template.format(obj1=obj1, obj2=obj2, change=transformation)
+    return {selected_category: {"base": base_description, "enhanced": None}}
 
-
-
-
-
-
-
-
-
-
-
-
-        base_description = template.format(obj1=obj1, obj2=obj2, obj3=obj3, change=transformation)
-        results[category] = {"base": base_description, "enhanced": None}
-    return results
+def generate_three_objects_interaction_for_category(obj1, obj2, obj3, selected_category):
+    transformations = physical_transformation_categories.get(selected_category)
+    if not transformations:
+        return {}
+    transformation = choose_alternative(random.choice(transformations))
+    template = random.choice([
+        "{obj1}, {obj2}, {obj3}이(가) 삼각형 구조로 결합하여 {change}가 발생한다",
+        "{obj1}이(가) {obj2}와(과) {obj3} 사이에서 매개체 역할을 하며 {change}를 촉진한다"
+    ])
+    base_description = template.format(obj1=obj1, obj2=obj2, obj3=obj3, change=transformation)
+    return {selected_category: {"base": base_description, "enhanced": None}}
 
 ##############################################################################
-# Expand the generated base descriptions via the LLM
+# Expand the generated base descriptions via the LLM (applies only to the selected category)
 ##############################################################################
 def enhance_descriptions(results, objects):
     obj_name = " 및 ".join([obj for obj in objects if obj])
@@ -340,17 +338,17 @@ def enhance_descriptions(results, objects):
     return results
 
 ##############################################################################
-# User input (up to 3 keywords)
+# Generate creative transformation ideas from the user input (up to 3 keywords) and the selected category
 ##############################################################################
-def generate_transformations(text1, text2
+def generate_transformations(text1, text2, text3, selected_category):
     if text2 and text3:
-        results =
+        results = generate_three_objects_interaction_for_category(text1, text2, text3, selected_category)
         objects = [text1, text2, text3]
     elif text2:
-        results =
+        results = generate_two_objects_interaction_for_category(text1, text2, selected_category)
         objects = [text1, text2]
     else:
-        results =
+        results = generate_single_object_transformation_for_category(text1, selected_category)
         objects = [text1]
     return enhance_descriptions(results, objects)
 
@@ -378,13 +376,8 @@ def process_inputs(text1, text2, text3, selected_category, progress=gr.Progress(
     time.sleep(0.3)
     progress(0.1, desc="창의적인 아이디어 생성 시작...")
 
-
-
-    # Filter the results to only the selected category
-    if selected_category in results:
-        results = {selected_category: results[selected_category]}
-    else:
-        return "선택한 카테고리가 결과에 존재하지 않습니다."
+    # Generate ideas for the selected category only
+    results = generate_transformations(text1, text2, text3, selected_category)
 
     progress(0.8, desc="결과 포맷팅 중...")
     formatted = format_results(results)
@@ -395,9 +388,7 @@ def process_inputs(text1, text2, text3, selected_category, progress=gr.Progress(
 # New combined function: generate the idea text and then an image
 ##############################################################################
 def process_all(text1, text2, text3, selected_category, progress=gr.Progress()):
-    # Generate the expanded idea text
     idea_result = process_inputs(text1, text2, text3, selected_category, progress)
-    # Use the generated idea directly as the image generation prompt
     image_result = generate_design_image(idea_result, seed=42, randomize_seed=True, width=1024, height=1024, num_inference_steps=4)
     return idea_result, image_result
 
@@ -444,7 +435,7 @@ with gr.Blocks(title="키워드 기반 창의적 변환 아이디어 및 디자
         value=list(physical_transformation_categories.keys())[0],
         info="출력할 카테고리를 선택하세요."
     )
-    status_msg = gr.Markdown("💡 '아이디어 생성하기' 버튼을 클릭하면
+    status_msg = gr.Markdown("💡 '아이디어 생성하기' 버튼을 클릭하면 선택한 카테고리에 해당하는 아이디어와 디자인 이미지가 생성됩니다.")
     processing_indicator = gr.HTML("""
     <div style="display: flex; justify-content: center; align-items: center; margin: 10px 0;">
     <div style="border: 5px solid #f3f3f3; border-top: 5px solid #3498db; border-radius: 50%; width: 30px; height: 30px; animation: spin 2s linear infinite;"></div>
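On the query_gemini_api hunk: the deeper fallback added there unwraps the response object of the google-generativeai client. A rough sketch of the call being unwrapped, assuming that client (the request code itself sits outside this hunk):

import os
import google.generativeai as genai

genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
model = genai.GenerativeModel("gemini-2.0-flash-thinking-exp-01-21")
response = model.generate_content("prompt text")

# response.text can raise when the first candidate carries no plain text part,
# which is why the hunk walks response.candidates[0].content.parts[0].text
# before falling back to response.parts.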
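The first hunk keys GPU setup off a ZERO_GPU environment variable. For context, ZeroGPU Spaces also expose a spaces.GPU decorator that attaches a GPU only for the duration of the decorated call. A minimal sketch assuming the standard spaces package; the function here is illustrative, not app.py's actual entry point:

import spaces  # available in Hugging Face ZeroGPU Spaces
import torch

@spaces.GPU  # a GPU is attached only while this function runs
def run_on_gpu(prompt: str) -> str:
    # Outside decorated calls the process sees no CUDA device; inside, it does.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"would run '{prompt}' on {device}"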