Abinivesh committed on
Commit
a9293ca
·
verified ·
1 Parent(s): e95bdda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  from random import randint
3
  from all_models import models
4
-
5
  from externalmod import gr_Interface_load, randomize_seed
6
 
7
  import asyncio
@@ -14,16 +13,18 @@ HF_TOKEN = os.environ.get("HF_TOKEN") if os.environ.get("HF_TOKEN") else None
14
  def load_fn(models):
15
  global models_load
16
  models_load = {}
17
-
18
  for model in models:
19
  if model not in models_load:
20
  try:
21
  print(f"Loading model: {model}")
22
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
 
 
23
  print(f"Loaded model: {model}")
24
  except Exception as error:
25
  print(f"Error loading model {model}: {error}")
26
- m = gr.Interface(lambda: None, ['text'], ['image'])
27
  models_load[model] = m
28
 
29
  print("Loading models...")
@@ -45,16 +46,24 @@ def update_imgbox(choices):
45
  return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
46
 
47
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
 
 
 
 
 
48
  kwargs = {"seed": seed}
 
49
  print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")
50
- task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
51
  try:
 
52
  result = await asyncio.wait_for(task, timeout=timeout)
53
  except Exception as e:
54
  print(f"Error during inference: {e}")
55
  if not task.done():
56
  task.cancel()
57
  return None
 
58
  if task.done() and result:
59
  with lock:
60
  result.save("image.png")
@@ -62,7 +71,7 @@ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
62
  return None
63
 
64
  def gen_fnseed(model_str, prompt, seed=1):
65
- if model_str == 'NA':
66
  return None
67
  loop = asyncio.new_event_loop()
68
  result = loop.run_until_complete(infer(model_str, prompt, seed))
@@ -80,10 +89,10 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
80
  seed_rand.click(randomize_seed, None, [seed])
81
  output = [gr.Image(label=m) for m in default_models]
82
  current_models = [gr.Textbox(m, visible=False) for m in default_models]
83
-
84
  for m, o in zip(current_models, output):
85
  gen_button.click(gen_fnseed, [m, txt_input, seed], o)
86
-
87
  with gr.Accordion('Model selection'):
88
  model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
89
  model_choice.change(update_imgbox, model_choice, output)
 
1
  import gradio as gr
2
  from random import randint
3
  from all_models import models
 
4
  from externalmod import gr_Interface_load, randomize_seed
5
 
6
  import asyncio
 
13
def load_fn(models):
    """Populate the global `models_load` mapping from model id -> loaded model.

    Each id in `models` is loaded via `gr_Interface_load`; a loaded object
    lacking a `predict` attribute is treated as a failure. Any failure is
    logged and recorded as `None` so callers can detect unavailable models
    with a simple `is None` check.
    """
    global models_load
    models_load = {}

    for model in models:
        # Only resolve ids we have not seen in this pass.
        if model in models_load:
            continue
        try:
            print(f"Loading model: {model}")
            m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
            # Validate the loaded object; the raise is caught just below,
            # funnelling bad loads through the same failure path.
            if not hasattr(m, 'predict'):
                raise ValueError(f"Model {model} does not have a 'predict' method.")
            print(f"Loaded model: {model}")
        except Exception as error:
            print(f"Error loading model {model}: {error}")
            m = None  # Ensure failed models are not stored
        models_load[model] = m
29
 
30
  print("Loading models...")
 
46
  return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
47
 
48
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
49
+ if model_str not in models_load or models_load[model_str] is None:
50
+ print(f"Model {model_str} is not available.")
51
+ return None
52
+
53
+ model = models_load[model_str]
54
  kwargs = {"seed": seed}
55
+
56
  print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")
57
+
58
  try:
59
+ task = asyncio.create_task(asyncio.to_thread(model.predict, prompt, **kwargs))
60
  result = await asyncio.wait_for(task, timeout=timeout)
61
  except Exception as e:
62
  print(f"Error during inference: {e}")
63
  if not task.done():
64
  task.cancel()
65
  return None
66
+
67
  if task.done() and result:
68
  with lock:
69
  result.save("image.png")
 
71
  return None
72
 
73
  def gen_fnseed(model_str, prompt, seed=1):
74
+ if model_str == 'NA' or models_load.get(model_str) is None:
75
  return None
76
  loop = asyncio.new_event_loop()
77
  result = loop.run_until_complete(infer(model_str, prompt, seed))
 
89
  seed_rand.click(randomize_seed, None, [seed])
90
  output = [gr.Image(label=m) for m in default_models]
91
  current_models = [gr.Textbox(m, visible=False) for m in default_models]
92
+
93
  for m, o in zip(current_models, output):
94
  gen_button.click(gen_fnseed, [m, txt_input, seed], o)
95
+
96
  with gr.Accordion('Model selection'):
97
  model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
98
  model_choice.change(update_imgbox, model_choice, output)