noahfaire MosaFaire commited on
Commit
870e64e
·
1 Parent(s): f7f137e

more tweaks to get it working on cpu (#1)

Browse files

- more tweaks to get cpu to work (28604b78ceb26ffc39a74b097d71da236ea6e088)
- automatically pick cpu vs cuda (9224f11f64ba6bf3e570bc07b79eb0abf7bf3788)


Co-authored-by: Mosa <[email protected]>

Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -37,13 +37,22 @@ def download_image(url):
37
  return PIL.Image.open(BytesIO(response.content)).convert("RGB")
38
 
39
  model_path = "runwayml/stable-diffusion-inpainting"
 
40
 
41
- pipe = StableDiffusionInpaintPipeline.from_pretrained(
42
- model_path,
43
- revision="fp16",
44
- torch_dtype=torch.float16,
45
- use_auth_token=True
46
- )
 
 
 
 
 
 
 
 
47
 
48
  img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png" # wino
49
  image = download_image(img_url).resize((512, 512))
@@ -53,7 +62,7 @@ prompt = "crazy portal universe"
53
 
54
  guidance_scale=7.5
55
  num_samples = 3
56
- generator = torch.Generator(device="cpu").manual_seed(0) # change the seed to get different results
57
  images = pipe(
58
  prompt=prompt,
59
  image=image,
 
37
  return PIL.Image.open(BytesIO(response.content)).convert("RGB")
38
 
39
  model_path = "runwayml/stable-diffusion-inpainting"
40
+ device = "cuda" if torch.cuda.is_available() else "cpu"
41
 
42
+ if device == "cuda":
43
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
44
+ model_path,
45
+ revision="fp16",
46
+ torch_dtype=torch.float16,
47
+ use_auth_token=True
48
+ ).to(device)
49
+ else:
50
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
51
+ model_path,
52
+ # revision="fp16",
53
+ # torch_dtype=torch.float16,
54
+ use_auth_token=True
55
+ ).to(device)
56
 
57
  img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png" # wino
58
  image = download_image(img_url).resize((512, 512))
 
62
 
63
  guidance_scale=7.5
64
  num_samples = 3
65
+ generator = torch.Generator(device=device).manual_seed(0) # change the seed to get different results
66
  images = pipe(
67
  prompt=prompt,
68
  image=image,