Abinivesh commited on
Commit
2d762c7
·
verified ·
1 Parent(s): 9bfd4be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -14,7 +14,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN") if os.environ.get("HF_TOKEN") else None
14
  def load_fn(models):
15
  global models_load
16
  models_load = {}
17
-
18
  for model in models:
19
  if model not in models_load:
20
  try:
@@ -23,7 +23,7 @@ def load_fn(models):
23
  print(f"Loaded model: {model}")
24
  except Exception as error:
25
  print(f"Error loading model {model}: {error}")
26
- m = gr.Interface(lambda: None, ['text'], ['image'])
27
  models_load[model] = m
28
 
29
  print("Loading models...")
@@ -45,49 +45,54 @@ def update_imgbox(choices):
45
  return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
46
 
47
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
 
 
 
 
48
  kwargs = {"seed": seed}
49
  print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")
50
- task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
51
  try:
52
- result = await asyncio.wait_for(task, timeout=timeout)
 
 
 
 
 
 
 
 
53
  except Exception as e:
54
  print(f"Error during inference: {e}")
55
- if not task.done():
56
- task.cancel()
57
- return None
58
- if task.done() and result:
59
- with lock:
60
- result.save("image.png")
61
- return "image.png"
62
  return None
63
 
64
  def gen_fnseed(model_str, prompt, seed=1):
65
  if model_str == 'NA':
66
  return None
67
- loop = asyncio.new_event_loop()
68
- result = loop.run_until_complete(infer(model_str, prompt, seed))
69
- loop.close()
70
- return result
71
 
72
  print("Creating Gradio interface...")
73
  with gr.Blocks(theme="gradio/soft") as demo:
74
  gr.HTML("<center><h1>TEXT-IMAGE-USING-MULTIMODELS</h1></center>")
 
75
  with gr.Tab():
76
  txt_input = gr.Textbox(label='Your prompt:', lines=4)
77
  gen_button = gr.Button('Generate')
78
  seed = gr.Slider("Seed", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed)
79
  seed_rand = gr.Button("Randomize Seed 🎲")
80
  seed_rand.click(randomize_seed, None, [seed])
 
81
  output = [gr.Image(label=m) for m in default_models]
82
  current_models = [gr.Textbox(m, visible=False) for m in default_models]
83
-
84
  for m, o in zip(current_models, output):
85
  gen_button.click(gen_fnseed, [m, txt_input, seed], o)
86
-
87
  with gr.Accordion('Model selection'):
88
  model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
89
  model_choice.change(update_imgbox, model_choice, output)
90
  model_choice.change(extend_choices, model_choice, current_models)
91
 
92
- demo.queue(default_concurrency_limit=200, max_size=200)
93
  demo.launch(show_api=False, max_threads=400)
 
14
  def load_fn(models):
15
  global models_load
16
  models_load = {}
17
+
18
  for model in models:
19
  if model not in models_load:
20
  try:
 
23
  print(f"Loaded model: {model}")
24
  except Exception as error:
25
  print(f"Error loading model {model}: {error}")
26
+ m = None # Avoid using gr.Interface here
27
  models_load[model] = m
28
 
29
  print("Loading models...")
 
45
  return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
46
 
47
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
48
+ if model_str not in models_load or models_load[model_str] is None:
49
+ print(f"Model {model_str} is not available.")
50
+ return None
51
+
52
  kwargs = {"seed": seed}
53
  print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")
54
+
55
  try:
56
+ result = await asyncio.wait_for(
57
+ asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs),
58
+ timeout=timeout
59
+ )
60
+ if result:
61
+ save_path = "image.png"
62
+ with lock:
63
+ result.save(save_path)
64
+ return save_path
65
  except Exception as e:
66
  print(f"Error during inference: {e}")
67
+
 
 
 
 
 
 
68
  return None
69
 
70
  def gen_fnseed(model_str, prompt, seed=1):
71
  if model_str == 'NA':
72
  return None
73
+ return asyncio.run(infer(model_str, prompt, seed))
 
 
 
74
 
75
  print("Creating Gradio interface...")
76
  with gr.Blocks(theme="gradio/soft") as demo:
77
  gr.HTML("<center><h1>TEXT-IMAGE-USING-MULTIMODELS</h1></center>")
78
+
79
  with gr.Tab():
80
  txt_input = gr.Textbox(label='Your prompt:', lines=4)
81
  gen_button = gr.Button('Generate')
82
  seed = gr.Slider("Seed", minimum=0, maximum=MAX_SEED, step=1, value=starting_seed)
83
  seed_rand = gr.Button("Randomize Seed 🎲")
84
  seed_rand.click(randomize_seed, None, [seed])
85
+
86
  output = [gr.Image(label=m) for m in default_models]
87
  current_models = [gr.Textbox(m, visible=False) for m in default_models]
88
+
89
  for m, o in zip(current_models, output):
90
  gen_button.click(gen_fnseed, [m, txt_input, seed], o)
91
+
92
  with gr.Accordion('Model selection'):
93
  model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
94
  model_choice.change(update_imgbox, model_choice, output)
95
  model_choice.change(extend_choices, model_choice, current_models)
96
 
97
+ demo.queue(default_concurrency_limit=500, max_size=500)
98
  demo.launch(show_api=False, max_threads=400)