pigeon-avatar / app.py
ItzRoBeerT's picture
Update app.py
8e4035c verified
raw
history blame
2.34 kB
import functools

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
# Hugging Face model identifiers. ``revision`` is passed to ``from_pretrained``
# when the captioning model and its tokenizer are loaded.
model_id_image = "CompVis/stable-diffusion-v1-4"
model_id_image_description = "vikhyatk/moondream2"
revision = "2024-08-26"

# Select the compute device and tensor dtype once, in order of preference:
# CUDA (bfloat16), then Apple MPS (float32), then CPU (float32).
# ``torch.backends.mps.is_available()`` is used for the MPS probe because
# ``torch.mps.is_available`` does not exist on older PyTorch releases and
# would raise AttributeError on every non-CUDA machine.
if torch.cuda.is_available():
    device = "cuda"
    torch_dtype = torch.bfloat16
elif torch.backends.mps.is_available():
    device = "mps"
    torch_dtype = torch.float32
else:
    device = "cpu"
    torch_dtype = torch.float32
@functools.lru_cache(maxsize=1)
def _load_description_model():
    """Load the captioning model and tokenizer once and reuse them.

    The result is cached so the weights are not re-read on every request.
    The trade-off is that the model stays in memory for the lifetime of the
    process.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id_image_description,
        trust_remote_code=True,
        revision=revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id_image_description,
        revision=revision,
    )
    return model, tokenizer


def generate_description(image):
    """Ask the captioning model to describe an image for avatar creation.

    Args:
        image: Anything ``PIL.Image.open`` accepts; the UI in this file
            supplies a file path.

    Returns:
        Whatever ``model.answer_question`` returns for the fixed question
        (presumably a text description; it is later interpolated into the
        image prompt).
    """
    model, tokenizer = _load_description_model()
    # Encode while the file is still open: PIL reads pixel data lazily, and
    # the ``with`` block closes the file handle as soon as encoding is done.
    with Image.open(image) as source_image:
        enc_image = model.encode_image(source_image)
    return model.answer_question(
        enc_image, "Describe this image to create an avatar", tokenizer
    )
@functools.lru_cache(maxsize=1)
def _load_image_pipeline():
    """Build the Stable Diffusion pipeline once and reuse it.

    Loading the weights, moving them to the device and enabling attention
    slicing are done a single time instead of on every request. The pipeline
    then stays in memory for the lifetime of the process.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id_image, torch_dtype=torch_dtype
    )
    pipe = pipe.to(device)
    pipe.enable_attention_slicing()
    return pipe


def generate_image_by_description(description, avatar_style=None):
    """Generate a pigeon avatar image from a text description.

    Args:
        description: Text that is embedded in the generation prompt.
        avatar_style: Optional style name; when truthy, a style sentence is
            appended to the prompt.

    Returns:
        The first image produced by the pipeline for the prompt.
    """
    pipe = _load_image_pipeline()
    prompt = (
        "Create a pigeon profile avatar. "
        f"Use the following description: {description}. "
    )
    # NOTE(review): the style sentence comes last, after a description of
    # unbounded length -- confirm it survives any prompt-length truncation
    # applied by the text encoder.
    if avatar_style:
        prompt += f"Use {avatar_style} avatar style."
    return pipe(prompt).images[0]
def process_and_generate(image, avatar_style):
    """Click handler: describe the uploaded image, then draw an avatar from it.

    Args:
        image: Value of the image input; ``None`` when nothing was uploaded.
        avatar_style: Selected style, or ``None`` when no style was chosen.

    Returns:
        The image returned by ``generate_image_by_description``.

    Raises:
        gr.Error: If no image was uploaded, so the user sees a readable
            message instead of a traceback from the image loader.
    """
    if image is None:
        raise gr.Error("Please upload an image of the pigeon first.")
    description = generate_description(image)
    return generate_image_by_description(description, avatar_style)
# Two equal-width columns: upload and options on the left, result on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2, min_width=300):
            selected_image = gr.Image(
                type="filepath",
                label="Upload an Image of the Pigeon",
                height=300,
            )
            avatar_style = gr.Radio(
                ["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
                label="(optional) Select the avatar style:",
            )
            generate_button = gr.Button("Generate Avatar", variant="primary")
        with gr.Column(scale=2, min_width=300):
            generated_image = gr.Image(
                type="numpy",
                label="Generated Avatar",
                height=300,
            )

    # Wire the button to the describe-then-generate handler.
    generate_button.click(
        process_and_generate,
        inputs=[selected_image, avatar_style],
        outputs=generated_image,
    )

demo.launch()