Abinivesh commited on
Commit
fca3535
·
verified ·
1 Parent(s): 7e1de43

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -17
app.py CHANGED
@@ -1,42 +1,50 @@
1
  import gradio as gr
2
- from random import randint
3
- from all_models import models
4
- from externalmod import gr_Interface_load, randomize_seed
5
  import asyncio
6
  import os
 
7
  from threading import RLock
8
  from pathlib import Path
 
 
9
 
10
  # Create a lock for thread safety
11
  lock = RLock()
12
 
13
- # Load Hugging Face token from environment variable (if available)
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
 
16
- # Function to load models
17
  def load_fn(models):
18
  global models_load
19
  models_load = {}
 
20
  for model in models:
21
  if model not in models_load:
22
  try:
23
  print(f"Loading model: {model}")
24
- m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
 
 
 
 
 
25
  models_load[model] = m
26
  except Exception as e:
27
  print(f"Error loading model {model}: {e}")
28
- models_load[model] = gr.Interface(lambda: None, ['text'], ['image'])
29
 
30
  print("Loading models...")
31
  load_fn(models)
32
  print("Models loaded successfully.")
33
 
 
34
  num_models = 1
35
  starting_seed = randint(1941, 2024)
36
  MAX_SEED = 3999999999
37
- MAX_SEED = int(MAX_SEED)
38
  inference_timeout = 600
39
 
 
40
  def extend_choices(choices):
41
  return choices[:num_models] + ['NA'] * (num_models - len(choices))
42
 
@@ -44,58 +52,69 @@ def update_imgbox(choices):
44
  choices_extended = extend_choices(choices)
45
  return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_extended]
46
 
47
- async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
48
- if model_str not in models_load:
 
 
49
  return None
50
-
51
  kwargs = {"seed": seed}
52
  try:
53
  print(f"Running inference for model: {model_str} with prompt: '{prompt}'")
54
  result = await asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN)
 
55
  if result:
56
  with lock:
57
  png_path = "image.png"
58
  result.save(png_path)
59
  return str(Path(png_path).resolve())
 
 
60
  except Exception as e:
61
  print(f"Error during inference for {model_str}: {e}")
 
62
  return None
63
 
 
64
  def gen_fnseed(model_str, prompt, seed=1):
65
  if model_str == 'NA':
66
  return None
 
67
  try:
68
  loop = asyncio.new_event_loop()
69
  asyncio.set_event_loop(loop)
70
- result = loop.run_until_complete(infer(model_str, prompt, seed, inference_timeout))
71
  except Exception as e:
72
  print(f"Error generating image for {model_str}: {e}")
73
  result = None
74
  finally:
75
  loop.close()
 
76
  return result
77
 
 
78
  print("Creating Gradio interface...")
79
  with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
80
  gr.HTML("<center><h1>Compare-6</h1></center>")
 
81
  with gr.Tab('Compare-6'):
82
  txt_input = gr.Textbox(label='Your prompt:', lines=4)
83
  gen_button = gr.Button('Generate up to 6 images')
84
  seed = gr.Slider(label="Seed (0 to MAX)", minimum=0, maximum=MAX_SEED, value=starting_seed)
85
  seed_rand = gr.Button("Randomize Seed 🎲")
86
-
87
  seed_rand.click(randomize_seed, None, [seed], queue=False)
88
-
89
  output = [gr.Image(label=m) for m in models[:num_models]]
90
  current_models = [gr.Textbox(m, visible=False) for m in models[:num_models]]
91
-
92
  for m, o in zip(current_models, output):
93
  gen_button.click(gen_fnseed, inputs=[m, txt_input, seed], outputs=[o], queue=False)
94
-
95
  with gr.Accordion('Model selection'):
96
  model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models')
97
  model_choice.change(update_imgbox, model_choice, output)
98
  model_choice.change(extend_choices, model_choice, current_models)
99
 
100
- demo.queue(default_concurrency_limit=50, max_size=100)
101
  demo.launch(show_api=False)
 
1
  import gradio as gr
2
+ import torch
 
 
3
  import asyncio
4
  import os
5
+ from random import randint
6
  from threading import RLock
7
  from pathlib import Path
8
+ from all_models import models
9
+ from externalmod import gr_Interface_load, randomize_seed
10
 
11
  # Create a lock for thread safety
12
  lock = RLock()
13
 
14
+ # Load Hugging Face token from environment variable
15
  HF_TOKEN = os.getenv("HF_TOKEN")
16
 
17
+ # Function to load models with optimized settings
18
  def load_fn(models):
19
  global models_load
20
  models_load = {}
21
+
22
  for model in models:
23
  if model not in models_load:
24
  try:
25
  print(f"Loading model: {model}")
26
+ m = gr_Interface_load(
27
+ f'models/{model}',
28
+ hf_token=HF_TOKEN,
29
+ torch_dtype=torch.float16 # Reduce memory usage
30
+ )
31
+ m.enable_model_cpu_offload() # Offload to CPU when not in use
32
  models_load[model] = m
33
  except Exception as e:
34
  print(f"Error loading model {model}: {e}")
35
+ models_load[model] = None
36
 
37
  print("Loading models...")
38
  load_fn(models)
39
  print("Models loaded successfully.")
40
 
41
+ # Constants
42
  num_models = 1
43
  starting_seed = randint(1941, 2024)
44
  MAX_SEED = 3999999999
 
45
  inference_timeout = 600
46
 
47
+ # Update UI components
48
  def extend_choices(choices):
49
  return choices[:num_models] + ['NA'] * (num_models - len(choices))
50
 
 
52
  choices_extended = extend_choices(choices)
53
  return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_extended]
54
 
55
+ # Async inference function
56
+ async def infer(model_str, prompt, seed=1):
57
+ if model_str not in models_load or models_load[model_str] is None:
58
+ print(f"Model {model_str} is unavailable.")
59
  return None
60
+
61
  kwargs = {"seed": seed}
62
  try:
63
  print(f"Running inference for model: {model_str} with prompt: '{prompt}'")
64
  result = await asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN)
65
+
66
  if result:
67
  with lock:
68
  png_path = "image.png"
69
  result.save(png_path)
70
  return str(Path(png_path).resolve())
71
+ except torch.cuda.OutOfMemoryError:
72
+ print(f"CUDA memory error for {model_str}. Try reducing image size.")
73
  except Exception as e:
74
  print(f"Error during inference for {model_str}: {e}")
75
+
76
  return None
77
 
78
+ # Synchronous wrapper
79
  def gen_fnseed(model_str, prompt, seed=1):
80
  if model_str == 'NA':
81
  return None
82
+
83
  try:
84
  loop = asyncio.new_event_loop()
85
  asyncio.set_event_loop(loop)
86
+ result = loop.run_until_complete(infer(model_str, prompt, seed))
87
  except Exception as e:
88
  print(f"Error generating image for {model_str}: {e}")
89
  result = None
90
  finally:
91
  loop.close()
92
+
93
  return result
94
 
95
+ # Gradio UI
96
  print("Creating Gradio interface...")
97
  with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
98
  gr.HTML("<center><h1>Compare-6</h1></center>")
99
+
100
  with gr.Tab('Compare-6'):
101
  txt_input = gr.Textbox(label='Your prompt:', lines=4)
102
  gen_button = gr.Button('Generate up to 6 images')
103
  seed = gr.Slider(label="Seed (0 to MAX)", minimum=0, maximum=MAX_SEED, value=starting_seed)
104
  seed_rand = gr.Button("Randomize Seed 🎲")
105
+
106
  seed_rand.click(randomize_seed, None, [seed], queue=False)
107
+
108
  output = [gr.Image(label=m) for m in models[:num_models]]
109
  current_models = [gr.Textbox(m, visible=False) for m in models[:num_models]]
110
+
111
  for m, o in zip(current_models, output):
112
  gen_button.click(gen_fnseed, inputs=[m, txt_input, seed], outputs=[o], queue=False)
113
+
114
  with gr.Accordion('Model selection'):
115
  model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models')
116
  model_choice.change(update_imgbox, model_choice, output)
117
  model_choice.change(extend_choices, model_choice, current_models)
118
 
119
+ demo.queue(default_concurrency_limit=20, max_size=50) # Adjusted for better stability
120
  demo.launch(show_api=False)