Spaces:
Sleeping
Sleeping
File size: 5,392 Bytes
40f772a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import torch
import gradio as gr
from pipeline_controlnet_sd_xl_raw import StableDiffusionXLControlNetRAWPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler
from torchvision import transforms
from PIL import Image
import traceback
# ========== 1. Load Models ==========
# NOTE: everything in this section runs at import time, so importing this
# module triggers the model load below.
#
# Disabled alternative kept for reference: build the pipeline from the SDXL
# base checkpoint plus a ControlNet checkpoint stored at a local path.
# base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
# controlnet_path = "/mnt/wencheng/RAWPami/diffusers/examples/controlnet/controlnet-model"
# controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
# pipe = StableDiffusionXLControlNetRAWPipeline.from_pretrained(
# base_model_path,
# controlnet=controlnet,
# torch_dtype=torch.float16
# )
# Active path: load the pipeline from the "wencheng256/DiffusionRAW"
# identifier with float16 weights. `pipe` is a module-level global that
# infer_fn calls.
pipe = StableDiffusionXLControlNetRAWPipeline.from_pretrained(
"wencheng256/DiffusionRAW",
torch_dtype=torch.float16
)
# Swap the pipeline's scheduler for UniPCMultistepScheduler, built from the
# configuration of the scheduler that was loaded with the pipeline.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Presumably keeps sub-models on the CPU until they are needed, to reduce GPU
# memory use — TODO confirm against the custom pipeline class.
pipe.enable_model_cpu_offload()
# ========== 2. Utility function: tensor -> PIL ==========
def tensor_to_pil(img_tensor: torch.Tensor) -> Image.Image:
    """Convert an image tensor into a PIL image.

    The tensor is moved to the CPU if it lives on a CUDA device, cast to
    float32 if it has any other dtype, and clamped to the range [0, 1]
    before being handed to torchvision's ``ToPILImage`` transform.

    Args:
        img_tensor: Image tensor to convert (assumed to be channel-first,
            as ``ToPILImage`` expects — TODO confirm against callers).

    Returns:
        The converted PIL image.
    """
    prepared = img_tensor.cpu() if img_tensor.is_cuda else img_tensor
    if prepared.dtype != torch.float32:
        prepared = prepared.float()
    # Clamp so out-of-range values cannot wrap when converted to 8-bit.
    clamped = prepared.clamp(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(clamped)
# ========== 3. Load a .pth file ==========
def load_pth_data(pth_path):
    """Load one sample file and prepare preview images for the UI.

    The file is read with ``torch.load`` and must contain the keys
    ``"rgb"``, ``"raw"``, ``"mask"`` and ``"condition"``. Only index 0 of
    the rgb/raw/mask entries is used for the previews.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
    the installed PyTorch version's defaults; only point this at trusted
    files.

    Args:
        pth_path: Path of the ``.pth`` file to load.

    Returns:
        A 6-tuple ``(rgb_preview, raw_preview, mask_preview, raw_tensor,
        mask_tensor, cond_tensor)``: three PIL images for display followed
        by the untouched raw, mask and condition entries from the file.
    """
    payload = torch.load(pth_path)
    rgb_batch = payload["rgb"]
    raw_tensor = payload["raw"]
    mask_tensor = payload["mask"]
    cond_tensor = payload["condition"]

    # Assuming each key can contain multiple images; using the first index only.
    # The [:, :448] slice keeps all of dim 0 and the first 448 entries of dim 1.
    first_raw = raw_tensor[0]
    raw_preview = tensor_to_pil(first_raw[:, :448])

    # Reverse dim 0 of the rgb entry before cropping — presumably a channel
    # order swap (e.g. BGR -> RGB); TODO confirm against the data format.
    flipped_rgb = torch.flip(rgb_batch[0], dims=[0])
    rgb_preview = tensor_to_pil(flipped_rgb[:, :448])

    # Show the complement of the stored mask.
    mask_preview = tensor_to_pil(1 - mask_tensor[0])

    return rgb_preview, raw_preview, mask_preview, raw_tensor, mask_tensor, cond_tensor
# ========== 4. Inference function ==========
def infer_fn(prompt, mask_edited, raw_tensor_state, mask_tensor_state, cond_tensor_state):
    """Run the pipeline on the loaded sample using the user-edited mask.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        mask_edited: Value of the mask editor. Using tool='sketch' returns a
            dict containing {'image': PIL, 'mask': PIL}; in that case only
            the drawn ``'mask'`` entry is used. A plain image is also
            accepted.
        raw_tensor_state: The ``"raw"`` entry stored by ``load_pth_data``;
            ``None`` until a file has been loaded.
        mask_tensor_state: The ``"mask"`` entry stored by ``load_pth_data``.
            Not used by this function; kept so the Gradio input list stays
            unchanged.
        cond_tensor_state: The ``"condition"`` entry stored by
            ``load_pth_data``; ``None`` until a file has been loaded.

    Returns:
        The first generated image, converted with ``tensor_to_pil``.

    Raises:
        gr.Error: If no sample has been loaded, no mask is available, or the
            pipeline fails. Raising (instead of returning a message string)
            lets Gradio show the message to the user; a plain string returned
            to an Image output would not be displayed as a message.
    """
    # The gr.State holders start out empty, so generating before loading a
    # file would otherwise fail with an opaque AttributeError.
    if raw_tensor_state is None or cond_tensor_state is None:
        raise gr.Error("Please load a PTH file before generating.")

    # Usually we only need the drawn mask.
    if isinstance(mask_edited, dict):
        mask_image = mask_edited.get("mask")
    else:
        mask_image = mask_edited
    if mask_image is None:
        raise gr.Error("No mask available. Please load a PTH file and draw a mask first.")

    try:
        mask_edited_tensor = transforms.ToTensor()(mask_image)
        # Keep only one channel as grayscale mask, then add a leading
        # batch dimension and match the half-precision inputs below.
        mask_edited_tensor = mask_edited_tensor[:1].unsqueeze(0).half()

        raw_t = raw_tensor_state.half()
        cond_t = cond_tensor_state.half()

        # Fixed seed so repeated runs with the same inputs are reproducible.
        generator = torch.manual_seed(0)

        print("Mask shape:", mask_edited_tensor.shape)
        print("Raw shape:", raw_t.shape)
        print("Cond shape:", cond_t.shape)

        result = pipe(
            prompt=prompt,
            num_inference_steps=20,
            generator=generator,
            image=raw_t,
            mask_image=mask_edited_tensor,
            control_image=cond_t,
        ).images[0]
        return tensor_to_pil(result)
    except Exception as exc:
        # Keep the full traceback in the server log, and surface a short
        # message in the UI.
        traceback.print_exc()
        raise gr.Error(
            "Error occurred during inference. Please check the terminal logs!"
        ) from exc
def build_demo():
    """Assemble the Gradio Blocks interface for the demo and return it.

    The UI consists of a sample-file selector with a "Load" button, two
    read-only image previews, a sketch-enabled mask editor, a prompt box,
    a "Generate" button and an output image. "Load" calls ``load_pth_data``
    and "Generate" calls ``infer_fn``.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# DiffusionRAW ")

        # Provide a dropdown to select pth file
        sample_paths = ["./data1.pth", "./data2.pth", "./data3.pth"]
        with gr.Row():
            pth_selector = gr.Dropdown(
                sample_paths,
                value=sample_paths[0],
                label="Select a PTH file",
            )
            load_button = gr.Button("Load")

        with gr.Row():
            # Read-only previews of the loaded sample
            raw_display = gr.Image(label="Raw Image (Display Only)", interactive=False)
            rgb_display = gr.Image(label="sRGB Image (Display Only)", interactive=False)
            # Mask editor with sketch tool
            mask_editor = gr.Image(
                label="Mask (Sketch)",
                tool="sketch",
                type="pil",
                brush_color="#FFFFFF",
                interactive=True,
                width=512,
                height=512,
            )

        # States to store tensors between the "Load" and "Generate" callbacks
        raw_tensor_state = gr.State()
        mask_tensor_state = gr.State()
        cond_tensor_state = gr.State()
        tensor_states = [raw_tensor_state, mask_tensor_state, cond_tensor_state]

        # Output order must match the tuple returned by load_pth_data:
        # rgb preview, raw preview, mask preview, then the three tensors.
        load_button.click(
            fn=load_pth_data,
            inputs=[pth_selector],
            outputs=[rgb_display, raw_display, mask_editor, *tensor_states],
        )

        prompt_input = gr.Textbox(label="Prompt", value="An RAW Image.", lines=1)
        generate_button = gr.Button("Generate")
        output_image = gr.Image(label="Output", show_download_button=False)

        # Input order must match infer_fn's parameter order.
        generate_button.click(
            fn=infer_fn,
            inputs=[prompt_input, mask_editor, *tensor_states],
            outputs=[output_image],
        )

    return demo
if __name__ == "__main__":
    # Build the UI and serve it on all network interfaces, port 9112, with
    # Gradio's debug mode enabled.
    demo = build_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=9112,
        debug=True,
    )
|