Spaces:
Runtime error
Runtime error
File size: 2,469 Bytes
970c399 ebe47d5 970c399 07012d2 970c399 d17dbf0 b7a4d35 ebe47d5 07012d2 c7ed554 970c399 c7ed554 970c399 c7ed554 9e13d8c c7ed554 8d99288 c7ed554 870e64e c7ed554 870e64e c7ed554 9e13d8c c7ed554 8d99288 c7ed554 870e64e c7ed554 8d99288 c7ed554 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import inspect
import os
from io import BytesIO
from typing import List, Optional, Union

import gradio as gr
import numpy as np
import PIL
import PIL.Image
import PIL.ImageOps
import requests
import torch
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import login
from rembg import remove
# Authenticate with the Hugging Face Hub using the token stored in the
# WRITE_TOKEN environment variable (e.g. a Space secret).
token = os.getenv("WRITE_TOKEN")
if token:
    # add_to_git_credential=True matches the original positional ``True``.
    login(token, add_to_git_credential=True)
else:
    # login() with no token falls back to an interactive prompt, which cannot
    # be answered in a headless deployment. Skip it instead; a token cached by
    # a previous `huggingface-cli login` is still picked up by from_pretrained.
    print("WRITE_TOKEN is not set; skipping Hugging Face Hub login.")
def image_grid(imgs, rows, cols):
    """Paste PIL images into a ``rows`` x ``cols`` grid, row by row.

    The cell size is taken from the first image; image ``i`` is pasted at
    column ``i % cols`` and row ``i // cols``.

    Args:
        imgs: Sequence of PIL images; exactly ``rows * cols`` of them.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    # Explicit checks instead of ``assert``, which is stripped under -O.
    if len(imgs) != rows * cols:
        raise ValueError(
            f"expected rows * cols = {rows * cols} images, got {len(imgs)}"
        )
    if not imgs:
        raise ValueError("image_grid needs at least one image")
    w, h = imgs[0].size
    grid = PIL.Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
def predict(dict, prompt):
    """Inpaint the masked region of an uploaded image according to ``prompt``.

    Args:
        dict: Payload from the sketch-enabled Gradio image input (created with
            ``type='pil'``): ``"image"`` is the uploaded picture and ``"mask"``
            the drawn mask. The name shadows the builtin but is kept for
            compatibility with callers.
        prompt: Text describing what to paint into the masked area.

    Returns:
        The first image produced by the module-level inpainting ``pipe``.
    """
    target_size = (512, 512)
    # Both inputs are normalised the same way: RGB mode, fixed 512x512 size.
    source, drawn_mask = (
        dict[key].convert("RGB").resize(target_size) for key in ("image", "mask")
    )
    result = pipe(prompt=prompt, image=source, mask_image=drawn_mask)
    return result.images[0]
def download_image(url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never handed to PIL as if it were image data.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content)).convert("RGB")
model_path = "runwayml/stable-diffusion-inpainting"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Options shared by both devices. On GPU, additionally request the "fp16"
# weight revision and load it as float16; on CPU the dtype is not overridden.
# NOTE(review): newer diffusers releases deprecate revision="fp16" in favour of
# variant="fp16", and use_auth_token in favour of token — confirm against the
# installed diffusers version.
pipeline_kwargs = {"use_auth_token": True}
if device == "cuda":
    pipeline_kwargs.update(revision="fp16", torch_dtype=torch.float16)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_path, **pipeline_kwargs
).to(device)
# Startup smoke test: inpaint the background of a sample photo so the pipeline
# runs once before the UI starts. It downloads an image and generates
# num_samples images, which is slow on CPU, so it can be turned off with
# RUN_STARTUP_DEMO=0 (or false/no/off). It stays enabled by default.
_run_startup_demo = os.getenv("RUN_STARTUP_DEMO", "1").strip().lower()
if _run_startup_demo not in {"0", "false", "no", "off"}:
    img_url = "https://cdn.faire.com/fastly/893b071985d70819da5f0d485f1b1bb97ee4f16a6e14ef1bdd4a086b3588be58.png"  # wino
    image = download_image(img_url).resize((512, 512))
    # rembg with only_mask=True returns a mask of the detected foreground.
    # Inverting it presumably selects the background as the area to repaint
    # (diffusers inpainting repaints the white pixels of mask_image).
    inverted_mask_image = remove(data=image, only_mask=True)
    mask_image = PIL.ImageOps.invert(inverted_mask_image)
    prompt = "crazy portal universe"
    guidance_scale = 7.5
    num_samples = 3
    # Fixed seed for reproducibility; change it to get different results.
    generator = torch.Generator(device=device).manual_seed(0)
    images = pipe(
        prompt=prompt,
        image=image,
        mask_image=mask_image,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_samples,
    ).images
    # Original image first, then the generated samples, in a single row.
    images.insert(0, image)
    # Bind the grid instead of discarding it: outside a notebook a bare
    # expression statement displays nothing, and the result would be lost.
    demo_grid = image_grid(images, 1, num_samples + 1)
# Gradio UI. predict() receives the uploaded image (with a sketch tool for
# drawing the mask) and a text prompt; the inpainted result is shown as a
# single output image.
# NOTE(review): gr.Image(source=..., tool=...) is the Gradio 3.x API; Gradio 4
# removed these arguments (the sketch tool moved to gr.ImageEditor). Pin
# gradio<4 or port this if the Space runs a newer Gradio.
input_components = [
    gr.Image(source='upload', tool='sketch', type='pil'),
    gr.Textbox(label='prompt'),
]
output_components = [gr.Image()]
demo = gr.Interface(
    predict,
    title='Stable Diffusion In-Painting',
    inputs=input_components,
    outputs=output_components,
)
demo.launch(debug=True)
|