Surn commited on
Commit
ba4dacb
·
1 Parent(s): 995bb60
Files changed (1) hide show
  1. audiocraft/utils/extend.py +69 -2
audiocraft/utils/extend.py CHANGED
@@ -13,6 +13,7 @@ from io import BytesIO
13
  from huggingface_hub import hf_hub_download
14
  import librosa
15
  import gradio as gr
 
16
 
17
 
18
  INTERRUPTING = False
@@ -224,10 +225,76 @@ def save_image(image):
224
  finally:
225
  return file_path
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  def hex_to_rgba(hex_color):
228
  try:
229
- # Convert hex color to RGBA tuple
230
- rgba = ImageColor.getcolor(hex_color, "RGBA")
 
 
 
 
231
  except ValueError:
232
  # If the hex color is invalid, default to yellow
233
  rgba = (255,255,0,255)
 
13
  from huggingface_hub import hf_hub_download
14
  import librosa
15
  import gradio as gr
16
+ import re
17
 
18
 
19
  INTERRUPTING = False
 
225
  finally:
226
  return file_path
227
 
228
+ def detect_color_format(color):
229
+ """
230
+ Detects if the color is in RGB, RGBA, or hex format,
231
+ and converts it to an RGBA tuple with integer components.
232
+
233
+ Args:
234
+ color (str or tuple): The color to detect.
235
+
236
+ Returns:
237
+ tuple: The color in RGBA format as a tuple of 4 integers.
238
+
239
+ Raises:
240
+ ValueError: If the input color is not in a recognized format.
241
+ """
242
+ # Handle color as a tuple of floats or integers
243
+ if isinstance(color, tuple):
244
+ if len(color) == 3 or len(color) == 4:
245
+ # Ensure all components are numbers
246
+ if all(isinstance(c, (int, float)) for c in color):
247
+ r, g, b = color[:3]
248
+ a = color[3] if len(color) == 4 else 255
249
+ return (
250
+ max(0, min(255, int(round(r)))),
251
+ max(0, min(255, int(round(g)))),
252
+ max(0, min(255, int(round(b)))),
253
+ max(0, min(255, int(round(a * 255)) if a <= 1 else round(a))),
254
+ )
255
+ else:
256
+ raise ValueError(f"Invalid color tuple length: {len(color)}")
257
+ # Handle hex color codes
258
+ if isinstance(color, str):
259
+ color = color.strip()
260
+ # Try to use PIL's ImageColor
261
+ try:
262
+ rgba = ImageColor.getcolor(color, "RGBA")
263
+ return rgba
264
+ except ValueError:
265
+ pass
266
+ # Handle 'rgba(r, g, b, a)' string format
267
+ rgba_match = re.match(r'rgba\(\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)\s*\)', color)
268
+ if rgba_match:
269
+ r, g, b, a = map(float, rgba_match.groups())
270
+ return (
271
+ max(0, min(255, int(round(r)))),
272
+ max(0, min(255, int(round(g)))),
273
+ max(0, min(255, int(round(b)))),
274
+ max(0, min(255, int(round(a * 255)) if a <= 1 else round(a))),
275
+ )
276
+ # Handle 'rgb(r, g, b)' string format
277
+ rgb_match = re.match(r'rgb\(\s*([0-9.]+),\s*([0-9.]+),\s*([0-9.]+)\s*\)', color)
278
+ if rgb_match:
279
+ r, g, b = map(float, rgb_match.groups())
280
+ return (
281
+ max(0, min(255, int(round(r)))),
282
+ max(0, min(255, int(round(g)))),
283
+ max(0, min(255, int(round(b)))),
284
+ 255,
285
+ )
286
+
287
+ # If none of the above conversions work, raise an error
288
+ raise ValueError(f"Invalid color format: {color}")
289
+
290
  def hex_to_rgba(hex_color):
291
  try:
292
+ if hex_color.startswith("#"):
293
+ clean_hex = hex_color.replace('#','')
294
+ # Use a generator expression to convert pairs of hexadecimal digits to integers and create a tuple
295
+ rgba = tuple(int(clean_hex[i:i+2], 16) for i in range(0, len(clean_hex),2))
296
+ else:
297
+ rgba = tuple(map(int,detect_color_format(hex_color)))
298
  except ValueError:
299
  # If the hex color is invalid, default to yellow
300
  rgba = (255,255,0,255)