aiqtech committed
Commit 57bc130 · verified · 1 Parent(s): 3ff67eb

Update app.py

Files changed (1)
  1. app.py +64 -37
app.py CHANGED
@@ -34,35 +34,46 @@ class GlobalVars:
 g = GlobalVars()
 
 def initialize_models(device):
-    # 3D generation pipeline
-    g.trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
-        "JeffreyXiang/TRELLIS-image-large"
-    )
-    if torch.cuda.is_available():
-        g.trellis_pipeline = g.trellis_pipeline.to("cuda")
-
-    # Image generation pipeline
-    g.flux_pipe = FluxPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-dev",
-        torch_dtype=torch.bfloat16,
-        device_map="balanced"
-    )
-
-    # Load the Hyper-SD LoRA
-    lora_path = hf_hub_download(
-        "ByteDance/Hyper-SD",
-        "Hyper-FLUX.1-dev-8steps-lora.safetensors",
-        use_auth_token=HF_TOKEN
-    )
-    g.flux_pipe.load_lora_weights(lora_path)
-    g.flux_pipe.fuse_lora(lora_scale=0.125)
-
-    # Initialize the translator
-    g.translator = transformers_pipeline(
-        "translation",
-        model="Helsinki-NLP/opus-mt-ko-en",
-        device=device if device != "cuda" else 0
-    )
+    try:
+        print("Initializing models...")
+        # 3D generation pipeline
+        g.trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained(
+            "JeffreyXiang/TRELLIS-image-large"
+        )
+        if torch.cuda.is_available():
+            print("Moving trellis_pipeline to CUDA")
+            g.trellis_pipeline = g.trellis_pipeline.to("cuda")
+
+        # Image generation pipeline
+        print("Loading flux_pipe...")
+        g.flux_pipe = FluxPipeline.from_pretrained(
+            "black-forest-labs/FLUX.1-dev",
+            torch_dtype=torch.bfloat16,
+            device_map="balanced"
+        )
+
+        # Load the Hyper-SD LoRA
+        print("Loading LoRA weights...")
+        lora_path = hf_hub_download(
+            "ByteDance/Hyper-SD",
+            "Hyper-FLUX.1-dev-8steps-lora.safetensors",
+            use_auth_token=HF_TOKEN
+        )
+        g.flux_pipe.load_lora_weights(lora_path)
+        g.flux_pipe.fuse_lora(lora_scale=0.125)
+
+        # Initialize the translator
+        print("Initializing translator...")
+        g.translator = transformers_pipeline(
+            "translation",
+            model="Helsinki-NLP/opus-mt-ko-en",
+            device=device if device != "cuda" else 0
+        )
+        print("Model initialization completed successfully")
+
+    except Exception as e:
+        print(f"Error during model initialization: {str(e)}")
+        raise
 
 # CUDA memory management settings
 torch.cuda.empty_cache()
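For orientation (not part of the diff): a minimal sketch of how the pipelines initialized above are typically used together, assuming it runs inside app.py after initialize_models(). The generate_from_korean_prompt helper name and the sampling settings (8 steps, guidance_scale 3.5, 1024x1024) are assumptions based on the fused Hyper-FLUX.1-dev 8-step LoRA, not code from this commit.

    # Sketch only: helper name and sampling parameters are assumptions, not app.py code.
    def generate_from_korean_prompt(korean_prompt: str):
        # transformers translation pipelines return [{"translation_text": ...}]
        english_prompt = g.translator(korean_prompt, max_length=512)[0]["translation_text"]
        # The fused Hyper-SD 8-step LoRA is intended for low-step sampling.
        image = g.flux_pipe(
            prompt=english_prompt,
            num_inference_steps=8,
            guidance_scale=3.5,
            height=1024,
            width=1024,
            generator=torch.Generator("cpu").manual_seed(0),
        ).images[0]
        return image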
@@ -113,6 +124,10 @@ def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
         return "", None
 
     try:
+        if g.trellis_pipeline is None:
+            print("Error: trellis_pipeline is not initialized")
+            return "", None
+
         # Convert webp images to RGB
         if isinstance(image, str) and image.endswith('.webp'):
             image = Image.open(image).convert('RGB')
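A minimal caller-side sketch of how the ("", None) early return added above could be surfaced in the UI; the safe_preprocess wrapper and the use of gr.Warning (available in recent Gradio releases) are assumptions, not part of this commit.

    # Sketch only: wrapper name and gr.Warning usage are assumptions.
    def safe_preprocess(image):
        trial_id, processed = preprocess_image(image)
        if not trial_id or processed is None:
            gr.Warning("Models are still loading or the image could not be processed.")
        return trial_id, processed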
@@ -370,7 +385,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     trial_id = gr.Textbox(visible=False)
     output_buf = gr.State()
 
-    # Move the Examples gallery to the bottom
+    # Move the Examples gallery to the bottom
     if example_images:
         gr.Markdown("""### Example Images""")
         with gr.Row():
@@ -379,11 +394,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 label="Click an image to use it",
                 show_label=True,
                 elem_id="gallery",
-                columns=5,
-                rows=5,
-                height=600,
-                allow_preview=True
+                columns=12,   # 12 per row
+                rows=2,       # 2 rows
+                height=300,   # adjusted height
+                allow_preview=True,
+                object_fit="contain"  # keep image aspect ratio
             )
+
+
 
         def load_example(evt: gr.SelectData):
             selected_image = Image.open(example_images[evt.index])
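A minimal sketch of how a gr.Gallery select event is typically wired to the load_example handler defined in this hunk; the example_gallery and image_prompt component names are assumptions, since the actual wiring is outside the shown lines.

    # Sketch only: component names are assumptions; the real wiring is not shown in this hunk.
    example_gallery = gr.Gallery(
        example_images,
        label="Click an image to use it",
        columns=12,
        rows=2,
        height=300,
        object_fit="contain",
    )
    example_gallery.select(load_example, inputs=None, outputs=image_prompt)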
@@ -454,24 +472,33 @@ if __name__ == "__main__":
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Using device: {device}")
 
+        # Clear CUDA memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
         # Initialize models
         initialize_models(device)
 
         # Initial image-preprocessing test
         try:
             test_image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
-            g.trellis_pipeline.preprocess_image(test_image)
+            if g.trellis_pipeline is not None:
+                g.trellis_pipeline.preprocess_image(test_image)
+            else:
+                print("Warning: trellis_pipeline is None")
         except Exception as e:
             print(f"Warning: Initial preprocessing test failed: {e}")
 
         # Launch the Gradio interface
         demo.queue()  # enable queueing
         demo.launch(
-            allowed_paths=[PERSISTENT_DIR],
+            allowed_paths=[PERSISTENT_DIR, TMP_DIR],  # add TMP_DIR
             server_name="0.0.0.0",
            server_port=7860,
             show_error=True,
-            share=True  # set share to True
+            share=True,        # set share to True
+            enable_queue=True  # enable queueing
         )
 
     except Exception as e:
 