|
from huggingface_hub import from_pretrained_keras |
|
from keras_cv import models |
|
import gradio as gr |
|
|
|
from tensorflow import keras |
|
|
|
# Use mixed precision (float16 compute) for inference. This has to run before
# the models below are created: Keras layers pick up the global dtype policy
# when they are constructed.
keras.mixed_precision.set_global_policy("mixed_float16")

# Side length, in pixels, of the square images the pipeline generates.
resolution = 512

# Base KerasCV Stable Diffusion pipeline at the chosen (square) resolution.
sd_dreambooth_model = models.StableDiffusion(img_height=resolution, img_width=resolution)

# DreamBooth-fine-tuned diffusion model pulled from the Hugging Face Hub.
db_diffusion_model = from_pretrained_keras("merve/dreambooth_kedis")

# Replace the pipeline's diffusion model with the fine-tuned one.
# NOTE(review): this writes the *private* `_diffusion_model` attribute of
# keras_cv's StableDiffusion; confirm the keras_cv version in use still reads
# it, since private attributes can change without notice.
sd_dreambooth_model._diffusion_model = db_diffusion_model
|
|
|
|
|
def infer(prompt):
    """Generate two images for *prompt* with the DreamBooth pipeline.

    Args:
        prompt: Text typed into the Gradio textbox.

    Returns:
        The result of ``sd_dreambooth_model.text_to_image`` for a batch of 2,
        handed straight to the Gallery output. (Presumably a batch of images;
        KerasCV documents a NumPy image array — confirm for the pinned version.)
    """
    return sd_dreambooth_model.text_to_image(prompt, batch_size=2)
|
|
|
# Gallery that displays the images returned by infer().
# Per the Gradio 3.x docs, `grid` gives images-per-row at successive
# screen-width breakpoints: 1 per row on the narrowest screens, 2 otherwise.
# NOTE(review): `Component.style()` was deprecated during Gradio 3.x and
# removed in Gradio 4; on newer Gradio use `gr.Gallery(label=..., columns=2)`.
# Confirm against the Gradio version this app pins.
output = gr.Gallery(label="Outputs").style(grid=(1, 2))

title = "Dreambooth Demo on Dog Images"

# NOTE(review): these user-facing strings disagree with each other. The title
# and description say "dog images", the description asks for the
# "{kedis cat}" concept, and the example prompt uses "sks dog". Confirm which
# instance token "merve/dreambooth_kedis" was trained with and align all three.
description = "This is a dreambooth model fine-tuned on dog images. To try it, input the concept with {kedis cat}."

examples = [["sks dog in space"]]

# Bind the interface to a module-level name so it can be referenced (e.g. by
# tooling that looks for `demo`) without being launched on import.
demo = gr.Interface(
    infer,
    inputs=["text"],
    outputs=[output],
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    # Enable Gradio's request queue (suited to slow handlers such as image
    # generation) and start the server only when run as a script.
    demo.queue().launch()
|
|