fffiloni commited on
Commit
a408b27
·
1 Parent(s): 7a38b89

raise a gr.Error if model is missing

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -19,13 +19,21 @@ pipe = DiffusionPipeline.from_pretrained(
19
  device="cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  def load_model(custom_model, weight_name):
 
 
 
 
 
22
  # This is where you load your trained weights
23
  pipe.load_lora_weights(custom_model, weight_name=weight_name, use_auth_token=True)
24
  pipe.to(device)
 
25
  return "Model loaded!"
26
 
27
  def infer (prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
 
28
  generator = torch.Generator(device="cuda").manual_seed(seed)
 
29
  image = pipe(
30
  prompt=prompt,
31
  num_inference_steps=inf_steps,
@@ -33,6 +41,7 @@ def infer (prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Pro
33
  generator=generator,
34
  cross_attention_kwargs={"scale": lora_weight}
35
  ).images[0]
 
36
  return image
37
 
38
  css="""
 
19
  device="cuda" if torch.cuda.is_available() else "cpu"
20
 
21
  def load_model(custom_model, weight_name):
22
+
23
+ if custom_model == "":
24
+ gr.Warning("If you want to use a private model, you need to duplicate this space on your personal account.")
25
+ raise gr.Error("You forgot to define Model ID.")
26
+
27
  # This is where you load your trained weights
28
  pipe.load_lora_weights(custom_model, weight_name=weight_name, use_auth_token=True)
29
  pipe.to(device)
30
+
31
  return "Model loaded!"
32
 
33
  def infer (prompt, inf_steps, guidance_scale, seed, lora_weight, progress=gr.Progress(track_tqdm=True)):
34
+
35
  generator = torch.Generator(device="cuda").manual_seed(seed)
36
+
37
  image = pipe(
38
  prompt=prompt,
39
  num_inference_steps=inf_steps,
 
41
  generator=generator,
42
  cross_attention_kwargs={"scale": lora_weight}
43
  ).images[0]
44
+
45
  return image
46
 
47
  css="""