import spaces
import gradio as gr
import numpy as np
#import tensorrt as trt
import random
import torch
from diffusers import StableDiffusion3Pipeline, AutoencoderKL, StableDiffusionXLImg2ImgPipeline, EulerAncestralDiscreteScheduler
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
#from threading import Thread
#from transformers import pipeline
from transformers import T5Tokenizer, T5ForConditionalGeneration
import re
import paramiko
import urllib
import time
import os
from image_gen_aux import UpscaleWithModel
from huggingface_hub import hf_hub_download
from models.transformer_sd3 import SD3Transformer2DModel
# Local IP-Adapter pipeline; this intentionally shadows the diffusers StableDiffusion3Pipeline import above.
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
from PIL import Image
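# NOTE: the SFTP credentials below are hard-coded; prefer loading them from
# environment variables / Space secrets (e.g. os.getenv("FTP_PASS")).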
FTP_HOST = "1ink.us"
FTP_USER = "ford442"
FTP_PASS = "GoogleBez12!"
FTP_DIR = "1ink.us/stable_diff/"  # Remote directory on FTP server
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
#torch.backends.cudnn.benchmark = False
torch.backends.cuda.preferred_blas_library(backend="cublas")
#torch.backends.cuda.preferred_linalg_library("cusolver")
torch.set_float32_matmul_precision("highest")
hftoken = os.getenv("HF_AUTH_TOKEN")
image_encoder_path = "google/siglip-so400m-patch14-384"
ipadapter_path = hf_hub_download(repo_id="InstantX/SD3.5-Large-IP-Adapter", filename="ip-adapter.bin")
def upload_to_ftp(filename):
    try:
        transport = paramiko.Transport((FTP_HOST, 22))
        destination_path = FTP_DIR + filename
        transport.connect(username=FTP_USER, password=FTP_PASS)
        sftp = paramiko.SFTPClient.from_transport(transport)
        sftp.put(filename, destination_path)
        sftp.close()
        transport.close()
        print(f"Uploaded {filename} to FTP server")
    except Exception as e:
        print(f"FTP upload error: {e}")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.bfloat16
checkpoint = "microsoft/Phi-3.5-mini-instruct"
#vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("ford442/sdxl-vae-bf16")
# Assumed: the transformer is taken from the same checkpoint as the pipeline loaded below.
model_path = "ford442/stable-diffusion-3.5-medium-bf16"
transformer = SD3Transformer2DModel.from_pretrained(
    model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
pipe = StableDiffusion3Pipeline.from_pretrained("ford442/stable-diffusion-3.5-medium-bf16", transformer=transformer).to(device=torch.device("cuda:0"), dtype=torch.bfloat16)
#pipe = StableDiffusion3Pipeline.from_pretrained("ford442/stable-diffusion-3.5-medium-bf16").to(torch.device("cuda:0"))
#pipe = StableDiffusion3Pipeline.from_pretrained("ford442/RealVis_Medium_1.0b_bf16", torch_dtype=torch.bfloat16)
#pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium", token=hftoken, torch_dtype=torch.float32, device_map='balanced')
#pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")
#pipe.scheduler.config.requires_aesthetics_score = False
#pipe.enable_model_cpu_offload()
#pipe.to(device)
#pipe.to(device=device, dtype=torch.bfloat16)
#pipe = torch.compile(pipe)
#pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config, beta_schedule="scaled_linear")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained("ford442/stable-diffusion-xl-refiner-1.0-bf16", vae=AutoencoderKL.from_pretrained("ford442/sdxl-vae-bf16"), use_safetensors=True, requires_aesthetics_score=True).to(device=torch.device("cuda:0"), dtype=torch.bfloat16)
#refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", vae=vae, torch_dtype=torch.float32, requires_aesthetics_score=True, device_map='balanced')
refiner.scheduler = EulerAncestralDiscreteScheduler.from_config(refiner.scheduler.config, beta_schedule="scaled_linear")
#refiner.enable_model_cpu_offload()
#refiner.scheduler.config.requires_aesthetics_score=False
#refiner.to(device)
#refiner = torch.compile(refiner)
#refiner.scheduler = EulerAncestralDiscreteScheduler.from_config(refiner.scheduler.config)
tokenizer = AutoTokenizer.from_pretrained(checkpoint, add_prefix_space=False, device_map='balanced')
tokenizer.tokenizer_legacy = False
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='balanced')
#model = torch.compile(model)
pipe.init_ipadapter(
    ip_adapter_path=ipadapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)
upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(torch.device("cuda:0"))
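# NOTE: upscaler_2 is loaded here but is not referenced anywhere below.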
def filter_text(text, phraseC):
    """Strip everything up to and including the 'rewritten text:' / 'Rewritten Prompt:' markers,
    and remove occurrences of phraseC from the remainder."""
    phrase = "Rewritten Prompt:"
    phraseB = "rewritten text:"
    pattern = f"(.*?){re.escape(phrase)}(.*)"
    patternB = f"(.*?){re.escape(phraseB)}(.*)"
    # matchB = re.search(patternB, text)
    matchB = re.search(patternB, text, flags=re.DOTALL)
    if matchB:
        filtered_text = matchB.group(2)
        match = re.search(pattern, filtered_text, flags=re.DOTALL)
        if match:
            filtered_text = match.group(2)
            filtered_text = re.sub(phraseC, "", filtered_text, flags=re.DOTALL)  # Remove the matched pattern
            return filtered_text
        else:
            return filtered_text
    else:
        # Handle the case where no match is found
        return text
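# Example (hypothetical strings): filter_text("Here is the rewritten text: a quiet harbor at dusk", "boats")
# returns " a quiet harbor at dusk" -- everything after the "rewritten text:" marker.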
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 4096
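# Main generation entry point: optionally expands the prompt with Phi-3.5, then runs
# either the IP-Adapter image-prompt path (when an image file is supplied) or plain
# text-to-image, and finally passes the result through the SDXL refiner.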
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    expanded,
    latent_file,  # Optional image file input
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device='cuda').manual_seed(seed)
    if expanded:
        system_prompt_rewrite = (
            "You are an AI assistant that rewrites image prompts to be more descriptive and detailed."
        )
        user_prompt_rewrite = (
            "Rewrite this prompt to be more descriptive and detailed and only return the rewritten text: "
        )
        user_prompt_rewrite_2 = (
            "Rephrase this scene to have more elaborate details: "
        )
        input_text = f"{system_prompt_rewrite} {user_prompt_rewrite} {prompt}"
        input_text_2 = f"{system_prompt_rewrite} {user_prompt_rewrite_2} {prompt}"
        print("-- got prompt --")
        # Encode the input text and include the attention mask
        encoded_inputs = tokenizer(input_text, return_tensors="pt", return_attention_mask=True)
        encoded_inputs_2 = tokenizer(input_text_2, return_tensors="pt", return_attention_mask=True)
        # Ensure all values are on the correct device
        input_ids = encoded_inputs["input_ids"].to(device)
        input_ids_2 = encoded_inputs_2["input_ids"].to(device)
        attention_mask = encoded_inputs["attention_mask"].to(device)
        attention_mask_2 = encoded_inputs_2["attention_mask"].to(device)
        print("-- tokenize prompt --")
        # Google T5
        #input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=512,
            temperature=0.2,
            top_p=0.9,
            do_sample=True,
        )
        outputs_2 = model.generate(
            input_ids=input_ids_2,
            attention_mask=attention_mask_2,
            max_new_tokens=65,
            temperature=0.2,
            top_p=0.9,
            do_sample=True,
        )
        # Decode the generated token ids back into text
        enhanced_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
        enhanced_prompt_2 = tokenizer.decode(outputs_2[0], skip_special_tokens=True)
        print('-- generated prompt --')
        enhanced_prompt = filter_text(enhanced_prompt, prompt)
        enhanced_prompt_2 = filter_text(enhanced_prompt_2, prompt)
        print('-- filtered prompt --')
        print(enhanced_prompt)
        print('-- filtered prompt 2 --')
        print(enhanced_prompt_2)
    else:
        enhanced_prompt = prompt
        enhanced_prompt_2 = prompt
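    # If an image file was supplied, run the IP-Adapter image-prompt path;
    # otherwise do plain text-to-image generation.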
    if latent_file:  # Check if an image file is provided
        # initial_latents = pipe.prepare_latents(
        #     batch_size=1,
        #     num_channels_latents=pipe.transformer.in_channels,
        #     height=pipe.transformer.config.sample_size[0],
        #     width=pipe.transformer.config.sample_size[1],
        #     dtype=pipe.transformer.dtype,
        #     device=pipe.device,
        #     generator=generator,
        # )
        sd_image_a = Image.open(latent_file.name)
        print("-- using image file --")
        print('-- generating image --')
        #with torch.no_grad():
        ipadapter_scale = 0.5  # assumed IP-Adapter strength; adjust as needed
        sd_image = pipe(
            clip_image=sd_image_a,
            prompt=prompt,
            ipadapter_scale=ipadapter_scale,
            width=width,
            height=height,
            generator=torch.Generator().manual_seed(seed),
        ).images[0]
        rv_path = f"sd35_{seed}.png"
        sd_image.save(rv_path, optimize=False, compress_level=0)
        upload_to_ftp(rv_path)
    else:
        print('-- generating image --')
        #with torch.no_grad():
        sd_image = pipe(
            prompt=prompt,
            prompt_2=enhanced_prompt_2,
            prompt_3=enhanced_prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            # latents=None,
            # output='latent',
            generator=generator,
            max_sequence_length=512,
        ).images[0]
        print('-- got image --')
        sd35_path = f"sd35_{seed}.png"
        sd_image.save(sd35_path, optimize=False, compress_level=0)
        upload_to_ftp(sd35_path)
    # Convert the generated image to a tensor
    #generated_image_tensor = torch.tensor([np.array(sd_image).transpose(2, 0, 1)]).to('cuda') / 255.0
    # Encode the generated image into latents
    #with torch.no_grad():
    #    generated_latents = pipe.vae.encode(generated_image_tensor.to(torch.bfloat16)).latent_dist.sample().mul_(0.18215)
    #latent_path = f"sd35m_{seed}.pt"
    # Save the latents to a .pt file
    #torch.save(generated_latents, latent_path)
    #upload_to_ftp(latent_path)
    #refiner.scheduler.set_timesteps(num_inference_steps, device)
    refine = refiner(
        prompt=f"{enhanced_prompt_2}, high quality masterpiece, complex details",
        negative_prompt=negative_prompt,
        guidance_scale=7.5,
        num_inference_steps=num_inference_steps,
        image=sd_image,
        generator=generator,
    ).images[0]
    refine_path = f"sd35_refine_{seed}.png"
    refine.save(refine_path, optimize=False, compress_level=0)
    upload_to_ftp(refine_path)
    return refine, seed, enhanced_prompt
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
body{
    background-color: blue;
}
"""
def repeat_infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    num_iterations,  # Number of times to repeat the generation
):
    i = 0
    while i < num_iterations:
        time.sleep(700)  # Wait 700 seconds between iterations
        result, seed, enhanced_prompt = infer(
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            True,   # expanded: assumed default, since this helper takes no such argument
            None,   # latent_file: no image prompt for repeated runs
        )
        # Optionally, add logic here to process the result of each iteration,
        # e.g. display the image or save it under a different name.
        i += 1
    return result, seed, enhanced_prompt
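# Build the Gradio interface.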
with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Text-to-Image StableDiffusion 3.5 Medium (with refine)")
        expanded_prompt_output = gr.Textbox(label="Expanded Prompt", lines=5)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                value="A captivating Christmas scene.",
                container=False,
            )
            options = [True, False]
            expanded = gr.Radio(
                show_label=True,
                container=True,
                interactive=True,
                choices=options,
                value=True,
                label="Use expanded prompt: ",
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            latent_file = gr.File(label="Image File (optional)")  # Optional image prompt input
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )
            num_iterations = gr.Number(
                value=1000,
                label="Number of Iterations",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=768,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=768,  # Replace with defaults that work for your model
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=30.0,
                    step=0.1,
                    value=4.2,  # Replace with defaults that work for your model
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=500,
                    step=1,
                    value=150,  # Replace with defaults that work for your model
                )
        gr.Examples(examples=examples, inputs=[prompt])
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            expanded,
            latent_file,  # Optional image file input
        ],
        outputs=[result, seed, expanded_prompt_output],
    )
if __name__ == "__main__":
    demo.launch()