ruslanmv committed
Commit 0036d0a · verified
1 Parent(s): 5c9fdb8

Update app.py

Files changed (1)
  1. app.py +27 -7
app.py CHANGED
@@ -19,6 +19,17 @@ import matplotlib.pyplot as plt
 import gc # Import the garbage collector
 from audio import *
 import os
+# Define a fallback for environments without GPU
+if os.environ.get("SPACES_ZERO_GPU") is not None:
+    import spaces
+else:
+    class spaces:
+        @staticmethod
+        def GPU(func):
+            def wrapper(*args, **kwargs):
+                return func(*args, **kwargs)
+            return wrapper
+
 
 # Download necessary NLTK data
 try:
@@ -46,30 +57,39 @@ def log_gpu_memory():
         print("CUDA is not available. Cannot log GPU memory.")
 
 # --------- MinDalle Image Generation Functions ---------
-
 # Load MinDalle model once
-def load_min_dalle_model(models_root: str = 'pretrained', fp16: bool = True):
+# Dynamically determine device and precision
+def load_min_dalle_model(models_root: str = 'pretrained'):
     """
-    Load the MinDalle model.
+    Load the MinDalle model, automatically selecting device and precision.
 
     Args:
         models_root: Path to the directory containing MinDalle models.
-        fp16: Whether to use float16 for faster generation (requires CUDA).
 
     Returns:
         An instance of the MinDalle model.
     """
     print("DEBUG: Loading MinDalle model...")
+
+    if torch.cuda.is_available():
+        device = 'cuda'
+        dtype = torch.float16
+        print("DEBUG: Using GPU with float16 precision.")
+    else:
+        device = 'cpu'
+        dtype = torch.float32
+        print("DEBUG: Using CPU with float32 precision.")
+
     return MinDalle(
         is_mega=True,
         models_root=models_root,
-        is_reusable=False, # Set is_reusable to False
+        is_reusable=False,
        is_verbose=True,
-        dtype=torch.float16 if fp16 else torch.float32,
+        dtype=dtype,
         device=device
     )
 
-# Initialize the MinDalle model
+# Initialize the MinDalle model (will now automatically use GPU if available)
 min_dalle_model = load_min_dalle_model()
 
 def generate_image_with_min_dalle(
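
For readers unfamiliar with the ZeroGPU pattern this commit touches: the shim added at the top of app.py lets the same @spaces.GPU decorator syntax work both on a Hugging Face ZeroGPU Space (where the real spaces package is available and SPACES_ZERO_GPU is set) and on an ordinary machine, where the decorator becomes a no-op. Below is a minimal, self-contained sketch of that usage; generate_story_image and its body are illustrative placeholders, not code from this repository.

import os
import torch

# Same fallback pattern as in the commit: use the real ZeroGPU decorator when
# SPACES_ZERO_GPU is set, otherwise make @spaces.GPU a pass-through wrapper.
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    class spaces:
        @staticmethod
        def GPU(func):
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            return wrapper

@spaces.GPU  # requests GPU time on a ZeroGPU Space; does nothing locally
def generate_story_image(prompt: str) -> str:
    # Hypothetical entry point; the real app decorates its own generation functions.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"would render '{prompt}' on {device}"

if __name__ == "__main__":
    print(generate_story_image("a lighthouse at dawn"))

The second hunk applies the same idea inside load_min_dalle_model: device and dtype are chosen at load time from torch.cuda.is_available() instead of a caller-supplied fp16 flag, so the model runs in float16 on GPU and falls back to float32 on CPU without any call-site changes.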