Spaces:
Runtime error
Runtime error
File size: 1,274 Bytes
5e69d3c afb2266 d13afd7 afb2266 d13afd7 afb2266 ddc7e76 afb2266 5e69d3c ddc7e76 afb2266 5e69d3c ddc7e76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import random

import gradio as gr
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from transformers import DalleBartProcessor, FlaxDalleBartForConditionalGeneration
# Load the DALL-E Mega text processor and Flax model once, at import time,
# so every Gradio request reuses the same objects.
# NOTE(review): `transformers` does not appear to export DalleBartProcessor or
# FlaxDalleBartForConditionalGeneration; the DALL-E mini project ships them as
# `dalle_mini.DalleBartProcessor` / `dalle_mini.DalleBart`. The Space reports a
# runtime error, plausibly from this import — confirm and switch packages.
_CHECKPOINT = "dalle-mini/dalle-mega"
processor = DalleBartProcessor.from_pretrained(_CHECKPOINT)
model = FlaxDalleBartForConditionalGeneration.from_pretrained(_CHECKPOINT)
# Function to generate an image from a text prompt
def generate_image(prompt, seed=None):
    """Run DALL-E Mega generation for a single text prompt.

    Args:
        prompt: The text entered in the Gradio textbox.
        seed: Optional integer seed for sampling. When ``None`` (the default,
            and what Gradio uses since only the prompt is wired as an input),
            a fresh random seed is drawn so that repeated submissions of the
            same prompt can produce different samples.

    Returns:
        The first element of the ``model.generate`` output (``outputs[0]``).
    """
    # Tokenise the prompt as a batch of one, padded/truncated to 128 tokens.
    inputs = processor(
        [prompt],
        return_tensors="jax",
        padding="max_length",
        truncation=True,
        max_length=128,
    )

    # transformers' Flax `generate` defaults to a fixed jax.random.PRNGKey(0)
    # when no `prng_key` is supplied, which makes sampling repeat exactly.
    if seed is None:
        seed = random.randint(0, 2**31 - 1)  # stay within int32 for PRNGKey
    prng_key = jax.random.PRNGKey(seed)

    # Inputs are passed unreplicated: `flax.jax_utils.replicate` adds a leading
    # per-device axis intended for a `jax.pmap`-ed call, but `model.generate`
    # is called directly here, so that extra axis would corrupt the shapes.
    outputs = model.generate(
        **inputs.data,
        do_sample=True,
        num_beams=1,
        num_return_sequences=1,
        prng_key=prng_key,
    )

    # NOTE(review): generation yields image *token* output, not pixels. The
    # DALL-E mini pipeline decodes these codes with a VQGAN model (vqgan-jax)
    # before producing a PIL image; without that step Gradio's "image" output
    # likely cannot render this value — TODO add decoding.
    return outputs[0]
# Create the Gradio interface: a multi-line text prompt in, an image out.
iface = gr.Interface(
    fn=generate_image,
    # `gr.inputs.*` was deprecated in Gradio 3.x and removed in 4.0;
    # components are constructed directly from the top-level namespace.
    inputs=gr.Textbox(lines=2, placeholder="Enter your prompt"),
    outputs="image",
    title="DALL-E Mini Image Generator",
    description="Generate images from text prompts using DALL-E Mini model.",
)
# Launch the app
iface.launch()
|