harmionestark commited on
Commit
a31088e
Β·
verified Β·
1 Parent(s): 5d5af54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -16,9 +16,9 @@ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
16
 
17
  # Explicitly convert submodules to float32 to prevent dtype mismatch
18
  pipe.to(device)
19
- pipe.text_encoder.to(dtype=torch.float32)
20
- pipe.vae.to(dtype=torch.float32)
21
- pipe.unet.to(dtype=torch.float32)
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  MAX_IMAGE_SIZE = 1024
@@ -38,7 +38,11 @@ def infer(
38
  if randomize_seed:
39
  seed = random.randint(0, MAX_SEED)
40
 
41
- generator = torch.Generator().manual_seed(seed)
 
 
 
 
42
 
43
  image = pipe(
44
  prompt=prompt,
@@ -152,4 +156,4 @@ with gr.Blocks(css=css) as demo:
152
  )
153
 
154
  if __name__ == "__main__":
155
- demo.launch()
 
16
 
17
  # Explicitly convert submodules to float32 to prevent dtype mismatch
18
  pipe.to(device)
19
+ pipe.text_encoder.to(device, dtype=torch.float32)
20
+ pipe.vae.to(device, dtype=torch.float32)
21
+ pipe.unet.to(device, dtype=torch.float32)
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  MAX_IMAGE_SIZE = 1024
 
38
  if randomize_seed:
39
  seed = random.randint(0, MAX_SEED)
40
 
41
+ generator = torch.Generator(device).manual_seed(seed)
42
+
43
+ # Ensure text inputs are moved to the correct device and dtype
44
+ prompt = str(prompt) if prompt else ""
45
+ negative_prompt = str(negative_prompt) if negative_prompt else ""
46
 
47
  image = pipe(
48
  prompt=prompt,
 
156
  )
157
 
158
  if __name__ == "__main__":
159
+ demo.launch()