V1 / app.py
michaelapplydesign's picture
test x formers
98eda10
raw
history blame
7.98 kB
import os
# os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES'] = '2'
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "caching_allocator"
import gradio as gr
import numpy as np
from models import make_inpainting
import utils
from transformers import MaskFormerImageProcessor, MaskFormerForInstanceSegmentation
from PIL import Image
import requests
from transformers import pipeline
import torch
import random
import io
import base64
import json
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionLatentUpscalePipeline, StableDiffusionPipeline
from diffusers import StableDiffusionUpscalePipeline
from diffusers import LDMSuperResolutionPipeline
import cv2
import onnxruntime
import xformers
# from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
# Number of gr.Image output slots the "FurnituRemove" button is wired to.
_FURNITURE_OUTPUT_SLOTS = 10


def _fill_output_slots(images, slots):
    """Return a new list of exactly `slots` items: `images` truncated, then padded with None."""
    filled = list(images)[:slots]
    return filled + [None] * (slots - len(filled))


def removeFurniture(input_img1,
                    input_img2,
                    positive_prompt,
                    negative_prompt,
                    num_of_images,
                    resolution
                    ):
    """Inpaint the masked area of `input_img1` and return images for the output slots.

    Args:
        input_img1: PIL image to edit.
        input_img2: PIL image of the painted mask; turned into the inpainting
            mask by ``utils.get_mask``.
        positive_prompt: prompt forwarded to ``make_inpainting``.
        negative_prompt: negative prompt forwarded to ``make_inpainting``.
        num_of_images: number of images to request; coerced to int.
        resolution: side length in pixels that both images are resized to;
            coerced to int.

    Returns:
        A list of exactly 10 entries: the images returned by ``make_inpainting``
        (truncated if more than 10 come back) followed by ``None`` padding, so
        every Gradio output slot receives a value.
    """
    print("removeFurniture")
    # gr.Number hands over floats unless precision is set, but PIL's resize
    # only accepts integer sizes, and an image count is integral by nature.
    resolution = int(resolution)
    num_of_images = int(num_of_images)
    input_img1 = input_img1.resize((resolution, resolution))
    input_img2 = input_img2.resize((resolution, resolution))
    canvas_mask = np.array(input_img2)
    mask = utils.get_mask(canvas_mask)
    print(input_img1, mask, positive_prompt, negative_prompt)
    images = make_inpainting(positive_prompt=positive_prompt,
                             image=input_img1,
                             mask_image=mask,
                             negative_prompt=negative_prompt,
                             num_of_images=num_of_images,
                             resolution=resolution
                             )
    return _fill_output_slots(images, _FURNITURE_OUTPUT_SLOTS)
def imageToString(img):
    """Encode `img` as PNG and return the encoded data.

    Despite the name, the result is ``bytes``, not ``str``. `img` may be any
    object with a PIL-style ``save(fp, format=...)`` method.
    """
    with io.BytesIO() as buffer:
        img.save(buffer, format="png")
        # Read the contents before the context manager closes the buffer.
        return buffer.getvalue()
# Lazily-built segmentation pipeline, shared across requests.
_SEGMENTATION_PIPELINE = None


def _get_segmentation_pipeline():
    """Build the MaskFormer image-segmentation pipeline on first use, then reuse it."""
    global _SEGMENTATION_PIPELINE
    if _SEGMENTATION_PIPELINE is None:
        _SEGMENTATION_PIPELINE = pipeline("image-segmentation", "facebook/maskformer-swin-large-ade")
    return _SEGMENTATION_PIPELINE


def segmentation(img):
    """Segment `img` and return the detected segments as a JSON string.

    Each entry is a dict produced by the transformers image-segmentation
    pipeline (``facebook/maskformer-swin-large-ade`` checkpoint). Its ``'mask'``
    value is replaced by the base64 text of ``utils.image_to_byte_array(mask)``
    so the list can be serialized; other keys are passed through unchanged.
    """
    print("segmentation")
    # Loading the model is the expensive part; doing it once instead of on
    # every request keeps latency down.
    results = _get_segmentation_pipeline()(img)
    for segment in results:
        mask_bytes = utils.image_to_byte_array(segment['mask'])
        segment['mask'] = base64.b64encode(mask_bytes).decode("utf-8")
    return json.dumps(results)
def upscale(image, prompt):
    """Upscale `image` with the Stable Diffusion x4 upscaler, conditioned on `prompt`.

    On CUDA the pipeline runs in float16 with xformers memory-efficient
    attention. On CPU it runs in float32 without xformers: diffusers refuses to
    enable xformers attention when CUDA is unavailable, and most CPU kernels do
    not support float16.

    ``guidance_scale=0`` is passed, which in diffusers disables
    classifier-free guidance (guidance applies only when the scale is > 1).

    NOTE(review): the pipeline is loaded on every call rather than cached, which
    is slow but means concurrent requests never share scheduler state.

    Returns:
        The first image of the pipeline output.
    """
    print("upscale", image, prompt)
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print("device", device)
    dtype = torch.float16 if use_cuda else torch.float32
    pipe = StableDiffusionUpscalePipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=dtype)
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()
    if use_cuda:
        # Import the submodule explicitly instead of relying on some other
        # library having imported xformers.ops as a side effect.
        import xformers.ops
        pipe.enable_xformers_memory_efficient_attention(attention_op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp)
        # Workaround for not accepting attention shape using VAE for Flash Attention
        pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
    ret = pipe(prompt=prompt,
               image=image,
               num_inference_steps=10,
               guidance_scale=0)
    print("ret", ret)
    upscaled_image = ret.images[0]
    print("up", upscaled_image)
    return upscaled_image
def upscale2(image, prompt):
    """Upscale `image` with the CompVis LDM super-resolution (4x, OpenImages) pipeline.

    `prompt` is only logged; it is not passed to the pipeline. It is kept so the
    function matches the two-input Gradio wiring.

    Weights are loaded in float16 only on CUDA. On CPU float32 is used, because
    most CPU kernels lack float16 support.

    Returns:
        The first image of the pipeline output.
    """
    print("upscale2", image, prompt)
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print("device", device)
    dtype = torch.float16 if use_cuda else torch.float32
    pipe = LDMSuperResolutionPipeline.from_pretrained("CompVis/ldm-super-resolution-4x-openimages", torch_dtype=dtype)
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()
    upscaled_image = pipe(image, num_inference_steps=10, eta=1).images[0]
    return upscaled_image
def convert_pil_to_cv2(image):
# pil_image = image.convert("RGB")
open_cv_image = np.array(image)
# RGB to BGR
open_cv_image = open_cv_image[:, :, ::-1].copy()
return open_cv_image
# ONNX Runtime sessions keyed by model path. Creating a session loads the
# model, so it is done once per path instead of on every request.
_ORT_SESSIONS = {}


def _get_ort_session(model_path):
    """Return a cached ONNX Runtime session for `model_path` (1 intra-op / 1 inter-op thread)."""
    session = _ORT_SESSIONS.get(model_path)
    if session is None:
        options = onnxruntime.SessionOptions()
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        session = onnxruntime.InferenceSession(model_path, options)
        _ORT_SESSIONS[model_path] = session
    return session


def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run the ONNX model at `model_path` on `img_array` and return its first output.

    `img_array` is fed to the model's first input. The caller must supply the
    shape and dtype that the model expects.
    """
    ort_session = _get_ort_session(model_path)
    ort_inputs = {ort_session.get_inputs()[0].name: img_array}
    return ort_session.run(None, ort_inputs)[0]
def post_process(img: np.ndarray) -> np.ndarray:
    """Turn a (1, C, H, W) or (C, H, W) model output into an (H, W, C) uint8 image.

    The channel axis is reversed, and values are clipped to [0, 255] before the
    uint8 cast. Without clipping, out-of-range floats would wrap around or hit
    an undefined float-to-uint8 conversion, showing up as speckle artifacts.
    """
    arr = np.asarray(img)
    # 1, C, H, W -> C, H, W. Unlike np.squeeze, this never drops an H or W of size 1.
    arr = arr.reshape(arr.shape[-3:])
    # C, H, W -> H, W, C, then reverse channel order
    arr = np.transpose(arr, (1, 2, 0))[:, :, ::-1]
    return np.clip(arr, 0, 255).astype(np.uint8)
def pre_process(img: np.ndarray) -> np.ndarray:
    """Turn an (H, W, C) or (H, W) image into a (1, 3, H, W) float32 model input.

    Only the first three channels are kept, so an alpha channel is dropped.
    A 2-D grayscale image is replicated into three identical channels first.
    """
    if img.ndim == 2:
        img = np.repeat(img[:, :, np.newaxis], 3, axis=2)
    # H, W, C -> C, H, W
    img = np.transpose(img[:, :, 0:3], (2, 0, 1))
    # C, H, W -> 1, C, H, W
    return np.expand_dims(img, axis=0).astype(np.float32)
def upscale3(image, model_path="up_models/modelx4.ort"):
    """Upscale a PIL image with a local ONNX Runtime model and return a uint8 array.

    Pipeline: convert_pil_to_cv2 (channel order reversed) -> pre_process
    ((1, 3, H, W) float32) -> inference -> post_process ((H, W, C) uint8,
    channels reversed again).

    NOTE(review): channels are reversed on the way in and again on the way out,
    so the returned array is in the input's RGB order only if the model
    preserves channel order. Confirm against the model's training convention.
    The "x4" in the default filename suggests a 4x model; this is not verified here.

    Args:
        image: PIL image from the Gradio input.
        model_path: path to the ``.ort`` model; defaults to the path the app
            has always used.
    """
    print("upscale3", image)
    img = convert_pil_to_cv2(image)
    return post_process(inference(model_path, pre_process(img)))
# Gradio UI: one column per tool, each with a button wired to its handler.
# Components are created in the same order as before, so the layout is unchanged.
with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            gr.Button("FurnituRemove").click(
                removeFurniture,
                inputs=[
                    gr.Image(label="img", type="pil"),
                    gr.Image(label="mask", type="pil"),
                    gr.Textbox(label="positive_prompt", value="empty room"),
                    gr.Textbox(label="negative_prompt", value=""),
                    # precision=0 makes gr.Number deliver ints instead of floats;
                    # these values are used as an image count and a PIL resize size.
                    gr.Number(label="num_of_images", value=2, precision=0),
                    gr.Number(label="resolution", value=512, precision=0),
                ],
                # Ten result slots; removeFurniture pads its list with None up to 10.
                outputs=[gr.Image() for _ in range(10)],
            )
        with gr.Column():
            gr.Button("Segmentation").click(segmentation, inputs=gr.Image(type="pil"), outputs=gr.JSON())
        with gr.Column():
            gr.Button("Upscale").click(
                upscale,
                inputs=[gr.Image(type="pil"), gr.Textbox(label="prompt", value="empty room")],
                outputs=gr.Image(),
            )
        with gr.Column():
            gr.Button("Upscale2").click(
                upscale2,
                inputs=[gr.Image(type="pil"), gr.Textbox(label="prompt", value="empty room")],
                outputs=gr.Image(),
            )
        with gr.Column():
            gr.Button("Upscale3").click(upscale3, inputs=[gr.Image(type="pil")], outputs=gr.Image())

app.launch(debug=True, share=True)
# UP 1