noahfaire commited on
Commit
c7ed554
·
1 Parent(s): d17dbf0

rearranged fun

Browse files
Files changed (1) hide show
  1. app.py +52 -21
app.py CHANGED
@@ -9,34 +9,65 @@ from rembg import remove
9
  import requests
10
  from io import BytesIO
11
 
12
- # def image_grid(imgs, rows, cols):
13
- # assert len(imgs) == rows*cols
14
 
15
- # w, h = imgs[0].size
16
- # grid = PIL.Image.new('RGB', size=(cols*w, rows*h))
17
- # grid_w, grid_h = grid.size
 
 
 
 
18
 
19
- # for i, img in enumerate(imgs):
20
- # grid.paste(img, box=(i%cols*w, i//cols*h))
21
- # return grid
22
 
23
- # def predict(dict, prompt):
24
- # image = dict['image'].convert("RGB").resize((512, 512))
25
- # mask_image = dict['mask'].convert("RGB").resize((512, 512))
26
- # images = pipe(prompt=prompt, image=image, mask_image=mask_image).images
27
- # return(images[0])
28
 
29
  def download_image(url):
30
  response = requests.get(url)
31
  return PIL.Image.open(BytesIO(response.content)).convert("RGB")
32
 
33
- def greet(name):
34
- img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png" # wino
35
- image = download_image(img_url).resize((512, 512))
36
- inverted_mask_image = remove(data = image, only_mask = True)
37
- mask_image = PIL.ImageOps.invert(inverted_mask_image)
 
 
 
 
 
 
 
 
 
38
 
39
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
42
- iface.launch()
 
 
 
 
 
 
 
 
 
 
9
  import requests
10
  from io import BytesIO
11
 
12
+ def image_grid(imgs, rows, cols):
13
+ assert len(imgs) == rows*cols
14
 
15
+ w, h = imgs[0].size
16
+ grid = PIL.Image.new('RGB', size=(cols*w, rows*h))
17
+ grid_w, grid_h = grid.size
18
+
19
+ for i, img in enumerate(imgs):
20
+ grid.paste(img, box=(i%cols*w, i//cols*h))
21
+ return grid
22
 
 
 
 
23
 
24
+ def predict(dict, prompt):
25
+ image = dict['image'].convert("RGB").resize((512, 512))
26
+ mask_image = dict['mask'].convert("RGB").resize((512, 512))
27
+ images = pipe(prompt=prompt, image=image, mask_image=mask_image).images
28
+ return(images[0])
29
 
30
  def download_image(url):
31
  response = requests.get(url)
32
  return PIL.Image.open(BytesIO(response.content)).convert("RGB")
33
 
34
+ device = "cuda"
35
+ model_path = "runwayml/stable-diffusion-inpainting"
36
+
37
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
38
+ model_path,
39
+ revision="fp16",
40
+ torch_dtype=torch.float16,
41
+ use_auth_token=True
42
+ ).to(device)
43
+
44
+ img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png" # wino
45
+ inverted_mask_image = remove(data = image, only_mask = True)
46
+ mask_image = PIL.ImageOps.invert(inverted_mask_image)
47
+ prompt = "crazy portal universe"
48
 
49
+ guidance_scale=7.5
50
+ num_samples = 3
51
+ generator = torch.Generator(device="cuda").manual_seed(0) # change the seed to get different results
52
+ images = pipe(
53
+ prompt=prompt,
54
+ image=image,
55
+ mask_image=mask_image,
56
+ guidance_scale=guidance_scale,
57
+ generator=generator,
58
+ num_images_per_prompt=num_samples,
59
+ ).images
60
+ images.insert(0, image)
61
+ image_grid(images, 1, num_samples + 1)
62
 
63
+ gr.Interface(
64
+ predict,
65
+ title = 'Stable Diffusion In-Painting',
66
+ inputs=[
67
+ gr.Image(source = 'upload', tool = 'sketch', type = 'pil'),
68
+ gr.Textbox(label = 'prompt')
69
+ ],
70
+ outputs = [
71
+ gr.Image()
72
+ ]
73
+ ).launch(debug=True)