ruslanmv commited on
Commit
8e24c1f
·
verified ·
1 Parent(s): bcd53e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -12
app.py CHANGED
@@ -26,10 +26,10 @@ if os.environ.get("SPACES_ZERO_GPU") is not None:
26
  else:
27
  class spaces:
28
  @staticmethod
29
- def GPU(func):
30
- def wrapper(*args, **kwargs):
31
- return func(*args, **kwargs)
32
- return wrapper
33
 
34
  # Download necessary NLTK data
35
  def setup_nltk():
@@ -82,23 +82,35 @@ def check_gpu_availability():
82
  check_gpu_availability()
83
 
84
  # GPU-Safe MinDalle Model Loading
85
- def initialize_min_dalle_with_gpu():
86
  """Load the MinDalle model with GPU support."""
87
- @spaces.GPU()
88
- def load_model():
89
- print("Loading MinDalle model...")
 
 
 
 
 
 
 
 
 
 
 
 
90
  return MinDalle(
91
  is_mega=True,
92
  models_root='pretrained',
93
  is_reusable=False,
94
  is_verbose=True,
95
- dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
96
- device='cuda' if torch.cuda.is_available() else 'cpu'
97
  )
98
- return load_model()
99
 
100
  # Initialize MinDalle Model
101
- min_dalle_model = initialize_min_dalle_with_gpu()
 
102
 
103
 
104
  def generate_image_with_min_dalle(
 
26
  else:
27
  class spaces:
28
  @staticmethod
29
+ def GPU(func=None, duration=None):
30
+ def wrapper(fn):
31
+ return fn
32
+ return wrapper if func is None else wrapper(func)
33
 
34
  # Download necessary NLTK data
35
  def setup_nltk():
 
82
  check_gpu_availability()
83
 
84
  # GPU-Safe MinDalle Model Loading
85
+ def initialize_min_dalle():
86
  """Load the MinDalle model with GPU support."""
87
+ if torch.cuda.is_available():
88
+ @spaces.GPU(duration=60 * 3)
89
+ def load_model():
90
+ print("Loading MinDalle model on GPU...")
91
+ return MinDalle(
92
+ is_mega=True,
93
+ models_root='pretrained',
94
+ is_reusable=False,
95
+ is_verbose=True,
96
+ dtype=torch.float16,
97
+ device='cuda'
98
+ )
99
+ return load_model()
100
+ else:
101
+ print("Loading MinDalle model on CPU...")
102
  return MinDalle(
103
  is_mega=True,
104
  models_root='pretrained',
105
  is_reusable=False,
106
  is_verbose=True,
107
+ dtype=torch.float32,
108
+ device='cpu'
109
  )
 
110
 
111
  # Initialize MinDalle Model
112
+ min_dalle_model = initialize_min_dalle()
113
+
114
 
115
 
116
  def generate_image_with_min_dalle(