text2image / app.py
My-AI-Projects's picture
Update app.py
60c6128 verified
raw
history blame
1.44 kB
import gradio as gr
from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration
from PIL import Image
import numpy as np
import jax
import jax.numpy as jnp
# Load the DALL-E Mega processor and model once, at import time, from the
# same checkpoint id so generate_image() can reuse them on every call.
# NOTE(review): confirm that the installed `transformers` really exports
# DalleBartProcessor / FlaxDalleBartForConditionalGeneration; the DALL-E Mini
# project publishes its classes in a separate package — TODO verify.
_CHECKPOINT = "dalle-mini/dalle-mega"
processor = DalleBartProcessor.from_pretrained(_CHECKPOINT)
model = FlaxDalleBartForConditionalGeneration.from_pretrained(_CHECKPOINT)
# Function to generate an image from a text prompt
def generate_image(prompt):
    """Generate a single image for ``prompt`` and return it as a PIL image.

    The prompt is encoded with the module-level ``processor`` (JAX tensors,
    ``padding="max_length"``, ``truncation=True``, ``max_length=64``), fed to
    ``model.generate`` with sampling enabled and 4 beams, and the generated
    sequences are passed to ``model.decode``. The first decoded item is cast
    to ``uint8`` and wrapped with ``Image.fromarray``.

    Args:
        prompt: Text description of the image to generate.

    Returns:
        The ``PIL.Image.Image`` built from the first decoded array.
    """
    encoded = processor(
        [prompt],
        return_tensors="jax",
        padding="max_length",
        truncation=True,
        max_length=64,
    )
    generated = model.generate(
        **encoded,
        do_sample=True,
        num_beams=4,
        num_return_sequences=1,
    )

    # NOTE(review): this assumes model.decode() turns generated token
    # sequences into pixel arrays — TODO confirm against the model class.
    decoded = jax.device_get(model.decode(generated.sequences))

    # NOTE(review): astype(np.uint8) truncates values; if the decoded array
    # is float in [0, 1] the result would be almost black — confirm the range.
    first = np.asarray(decoded[0])
    return Image.fromarray(first.astype(np.uint8))
# Build the Gradio UI: a two-line prompt textbox wired to generate_image(),
# with the returned picture shown through the "image" output component.
_prompt_box = gr.Textbox(lines=2, placeholder="Enter your text prompt")

iface = gr.Interface(
    fn=generate_image,
    inputs=_prompt_box,
    outputs="image",
    title="DALL-E Mini/Mega Image Generator",
    description="Generate images from text prompts using the DALL-E Mega model.",
)

# Start serving the app as soon as the script is executed.
iface.launch()