Abinivesh commited on
Commit
a51b33a
·
verified ·
1 Parent(s): a9293ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -16
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from random import randint
3
  from all_models import models
 
4
  from externalmod import gr_Interface_load, randomize_seed
5
 
6
  import asyncio
@@ -13,18 +14,16 @@ HF_TOKEN = os.environ.get("HF_TOKEN") if os.environ.get("HF_TOKEN") else None
13
  def load_fn(models):
14
  global models_load
15
  models_load = {}
16
-
17
  for model in models:
18
  if model not in models_load:
19
  try:
20
  print(f"Loading model: {model}")
21
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
22
- if not hasattr(m, 'predict'):
23
- raise ValueError(f"Model {model} does not have a 'predict' method.")
24
  print(f"Loaded model: {model}")
25
  except Exception as error:
26
  print(f"Error loading model {model}: {error}")
27
- m = None # Ensure failed models are not stored
28
  models_load[model] = m
29
 
30
  print("Loading models...")
@@ -46,24 +45,16 @@ def update_imgbox(choices):
46
  return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
47
 
48
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
49
- if model_str not in models_load or models_load[model_str] is None:
50
- print(f"Model {model_str} is not available.")
51
- return None
52
-
53
- model = models_load[model_str]
54
  kwargs = {"seed": seed}
55
-
56
  print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")
57
-
58
  try:
59
- task = asyncio.create_task(asyncio.to_thread(model.predict, prompt, **kwargs))
60
  result = await asyncio.wait_for(task, timeout=timeout)
61
  except Exception as e:
62
  print(f"Error during inference: {e}")
63
  if not task.done():
64
  task.cancel()
65
  return None
66
-
67
  if task.done() and result:
68
  with lock:
69
  result.save("image.png")
@@ -71,7 +62,7 @@ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
71
  return None
72
 
73
  def gen_fnseed(model_str, prompt, seed=1):
74
- if model_str == 'NA' or models_load.get(model_str) is None:
75
  return None
76
  loop = asyncio.new_event_loop()
77
  result = loop.run_until_complete(infer(model_str, prompt, seed))
@@ -89,10 +80,10 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
89
  seed_rand.click(randomize_seed, None, [seed])
90
  output = [gr.Image(label=m) for m in default_models]
91
  current_models = [gr.Textbox(m, visible=False) for m in default_models]
92
-
93
  for m, o in zip(current_models, output):
94
  gen_button.click(gen_fnseed, [m, txt_input, seed], o)
95
-
96
  with gr.Accordion('Model selection'):
97
  model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
98
  model_choice.change(update_imgbox, model_choice, output)
 
1
  import gradio as gr
2
  from random import randint
3
  from all_models import models
4
+
5
  from externalmod import gr_Interface_load, randomize_seed
6
 
7
  import asyncio
 
14
  def load_fn(models):
15
  global models_load
16
  models_load = {}
17
+
18
  for model in models:
19
  if model not in models_load:
20
  try:
21
  print(f"Loading model: {model}")
22
  m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
 
 
23
  print(f"Loaded model: {model}")
24
  except Exception as error:
25
  print(f"Error loading model {model}: {error}")
26
+ m = gr.Interface(lambda: None, ['text'], ['image'])
27
  models_load[model] = m
28
 
29
  print("Loading models...")
 
45
  return [gr.Image(None, label=m, visible=(m != 'NA')) for m in choices_plus]
46
 
47
  async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
 
 
 
 
 
48
  kwargs = {"seed": seed}
 
49
  print(f"Starting inference: {model_str} | Prompt: '{prompt}' | Seed: {seed}")
50
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
51
  try:
 
52
  result = await asyncio.wait_for(task, timeout=timeout)
53
  except Exception as e:
54
  print(f"Error during inference: {e}")
55
  if not task.done():
56
  task.cancel()
57
  return None
 
58
  if task.done() and result:
59
  with lock:
60
  result.save("image.png")
 
62
  return None
63
 
64
  def gen_fnseed(model_str, prompt, seed=1):
65
+ if model_str == 'NA':
66
  return None
67
  loop = asyncio.new_event_loop()
68
  result = loop.run_until_complete(infer(model_str, prompt, seed))
 
80
  seed_rand.click(randomize_seed, None, [seed])
81
  output = [gr.Image(label=m) for m in default_models]
82
  current_models = [gr.Textbox(m, visible=False) for m in default_models]
83
+
84
  for m, o in zip(current_models, output):
85
  gen_button.click(gen_fnseed, [m, txt_input, seed], o)
86
+
87
  with gr.Accordion('Model selection'):
88
  model_choice = gr.CheckboxGroup(models, label=f'Choose up to {num_models} models', value=default_models)
89
  model_choice.change(update_imgbox, model_choice, output)