text2image / app.py
My-AI-Projects's picture
Update app.py
afb2266 verified
raw
history blame
1.27 kB
import gradio as gr
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration
# Hub checkpoint shared by the processor and the model so the two cannot drift apart.
# NOTE(review): `DalleBartProcessor` / `FlaxDalleBartForConditionalGeneration`
# do not appear to be exported by `transformers`; the dalle-mini project ships
# them as `dalle_mini.DalleBartProcessor` / `dalle_mini.DalleBart` — verify the import.
_CHECKPOINT = "dalle-mini/dalle-mega"

# Text-side preprocessing (prompt -> token ids), then the generator model itself.
processor = DalleBartProcessor.from_pretrained(_CHECKPOINT)
model = FlaxDalleBartForConditionalGeneration.from_pretrained(_CHECKPOINT)
# Function to generate an image from a text prompt
def generate_image(prompt):
inputs = processor([prompt], return_tensors="jax", padding="max_length", truncation=True, max_length=128)
# Generate images
images = model.generate(**replicate(inputs.data), do_sample=True, num_beams=1, num_return_sequences=1)
# Post-process image for display (convert to PIL image format)
image = images[0] # assuming single image output
return image # return the generated image
# Create the Gradio interface: a two-line prompt textbox in, an image out.
iface = gr.Interface(
    fn=generate_image,  # called with the textbox contents for each request
    # gr.Textbox is the current component API. The gr.inputs namespace was
    # deprecated in Gradio 3.x and removed in Gradio 4; gr.Textbox works in both.
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt"),
    outputs="image",
    title="DALL-E Mini Image Generator",
    description="Generate images from text prompts using DALL-E Mini model.",
)
# Launch the app's web server.
iface.launch()