File size: 2,469 Bytes
970c399
ebe47d5
970c399
 
 
 
07012d2
970c399
 
d17dbf0
 
b7a4d35
 
ebe47d5
 
07012d2
c7ed554
 
970c399
c7ed554
 
 
 
 
 
 
970c399
 
c7ed554
9e13d8c
c7ed554
 
 
8d99288
 
 
 
 
c7ed554
870e64e
c7ed554
870e64e
 
 
 
 
 
 
 
 
 
 
 
 
 
c7ed554
 
9e13d8c
c7ed554
 
 
8d99288
c7ed554
 
870e64e
c7ed554
 
 
 
 
 
 
 
 
 
8d99288
c7ed554
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import inspect
import os
from io import BytesIO
from typing import List, Optional, Union

import numpy as np
import torch
import PIL
import PIL.Image
import PIL.ImageOps
import gradio as gr
from diffusers import StableDiffusionInpaintPipeline
from rembg import remove
import requests
from huggingface_hub import login

# Authenticate with the Hugging Face Hub using the token provided in the
# WRITE_TOKEN environment variable (e.g. a Space secret). When the variable is
# unset, skip login instead of calling login(None, ...), which would fall back
# to an interactive prompt that cannot be answered in a headless deployment.
token = os.getenv("WRITE_TOKEN")
if token:
    # add_to_git_credential=True additionally stores the token in git's
    # credential helper (this was the positional `True` in the original call).
    login(token=token, add_to_git_credential=True)

def image_grid(imgs, rows, cols):
    """Tile PIL images row-major into a single RGB grid image.

    Image ``i`` is placed in row ``i // cols``, column ``i % cols``. The cell
    size is taken from ``imgs[0]``, so all images are assumed to share that
    size.

    Args:
        imgs: Sequence of PIL images; must contain exactly ``rows * cols`` items.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is ``imgs[0].size``.

    Raises:
        ValueError: If ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows*cols = {rows * cols} images, got {len(imgs)}"
        )
    if not imgs:
        # rows*cols == 0 would otherwise fail below with an IndexError on imgs[0].
        raise ValueError("image_grid needs at least one image")

    w, h = imgs[0].size
    grid = PIL.Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid
    

def predict(dict, prompt):
    """Inpaint the masked region of a sketched image according to ``prompt``.

    Gradio callback for the interface built in this file. ``dict`` is the value
    of the sketch-tool image input and is read for its ``'image'`` and
    ``'mask'`` entries (presumably PIL images, given ``type='pil'`` on the
    input component — confirm against the installed Gradio version). Both are
    converted to RGB and resized to 512x512 before being passed to the
    module-level ``pipe``.

    Args:
        dict: Mapping with ``'image'`` and ``'mask'`` entries. The name shadows
            the builtin but is kept so keyword callers keep working.
        prompt: Text describing what to paint into the masked area.

    Returns:
        The first image produced by the inpainting pipeline.

    Raises:
        gr.Error: If no image/mask was supplied, so the UI shows a readable
            message instead of a TypeError/AttributeError.
    """
    if not dict or dict.get("image") is None or dict.get("mask") is None:
        raise gr.Error("Please upload an image and paint a mask first.")

    image = dict["image"].convert("RGB").resize((512, 512))
    mask_image = dict["mask"].convert("RGB").resize((512, 512))
    images = pipe(prompt=prompt, image=image, mask_image=mask_image).images
    return images[0]

def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, forwarded to ``requests.get``.
            Without a timeout, a stalled server could block the call
            indefinitely.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server responds with a 4xx/5xx status. This
            is raised before any attempt to decode an error page as an image.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")

model_path = "runwayml/stable-diffusion-inpainting"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loading options shared by both devices. On CUDA, additionally load the
# "fp16" revision in half precision; on CPU, keep the default revision/dtype.
_pipe_kwargs = {"use_auth_token": True}
if device == "cuda":
    _pipe_kwargs.update(revision="fp16", torch_dtype=torch.float16)

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_path,
    **_pipe_kwargs,
).to(device)

# Startup demo: run the pipeline once on a sample product photo.
# rembg's `only_mask` output is inverted so that the area it does NOT mark
# (presumably the background — confirm against rembg's mask convention)
# becomes the region the pipeline repaints.
img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png" # wino
image = download_image(img_url).resize((512, 512))
inverted_mask_image = remove(data = image, only_mask = True)
mask_image = PIL.ImageOps.invert(inverted_mask_image)
prompt = "crazy portal universe"

guidance_scale=7.5
num_samples = 3
generator = torch.Generator(device=device).manual_seed(0) # change the seed to get different results
images = pipe(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    guidance_scale=guidance_scale,
    generator=generator,
    num_images_per_prompt=num_samples,
).images
# Put the source image first, followed by the generated samples, in one row.
images.insert(0, image)
# Keep the grid instead of discarding it (the bare expression only displayed
# in a notebook; in a script its result was thrown away).
demo_grid = image_grid(images, 1, num_samples + 1)

# Web UI: an uploaded image with a sketch tool for painting the mask, plus a
# text prompt; `predict` returns the inpainted result shown as one image.
input_components = [
    gr.Image(source = 'upload', tool = 'sketch', type = 'pil'),
    gr.Textbox(label = 'prompt'),
]
output_components = [gr.Image()]

demo = gr.Interface(
    predict,
    title='Stable Diffusion In-Painting',
    inputs=input_components,
    outputs=output_components,
)
demo.launch(debug=True)