text2image / app.py
My-AI-Projects's picture
Update app.py
d13afd7 verified
raw
history blame
1.36 kB
import io

import gradio as gr
import torch
from PIL import Image
from transformers import DalleMini, DalleMiniProcessor
# Load the model and its matching processor from one shared checkpoint id.
# NOTE(review): to my knowledge `transformers` does not export `DalleMini` /
# `DalleMiniProcessor`; dalle-mini checkpoints are usually loaded through the
# separate `dalle-mini` package (DalleBart / DalleBartProcessor) together with a
# VQGAN decoder. Confirm the import at the top of the file works before deploying.
model_id = "dalle-mini/dalle-mega"
model, processor = (
    DalleMini.from_pretrained(model_id),
    DalleMiniProcessor.from_pretrained(model_id),
)
# Function to generate image
def generate_image(prompt, num_inference_steps=50):
    """Generate an image for *prompt* using the module-level model/processor.

    Args:
        prompt: Text description entered by the user.
        num_inference_steps: Forwarded to ``model.generate``. A Gradio slider
            may deliver it as a float, so it is coerced to ``int`` first.

    Returns:
        A ``PIL.Image.Image`` suitable for the ``gr.Image`` output component.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only (shown in the UI).
        TypeError: If ``processor.decode`` returns neither a PIL image nor
            raw image bytes.
    """
    # Reject blank prompts up front instead of running the model on nothing.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    inputs = processor(prompt, return_tensors="pt")

    # Inference only: disable autograd bookkeeping to save memory.
    with torch.no_grad():
        outputs = model.generate(
            **inputs, num_inference_steps=int(num_inference_steps)
        )

    # NOTE(review): for typical HF processors, decode(..., skip_special_tokens=True)
    # returns text, not image bytes — confirm what this processor actually returns.
    decoded = processor.decode(outputs[0], skip_special_tokens=True)
    if isinstance(decoded, Image.Image):
        return decoded
    if isinstance(decoded, (bytes, bytearray)):
        return Image.open(io.BytesIO(decoded))
    raise TypeError(
        "Expected image bytes or a PIL image from processor.decode, "
        f"got {type(decoded).__name__}"
    )
# Build the Gradio UI: prompt and step controls on one row, the trigger
# button on the next, then the output image.
with gr.Blocks() as demo:
    gr.Markdown("# Text to Image Generation")

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter a prompt here...",
        )
        num_inference_steps = gr.Slider(
            minimum=1,
            maximum=50,
            step=1,
            value=28,
            label="Number of Inference Steps",
        )

    with gr.Row():
        generate_button = gr.Button("Generate Image")

    result = gr.Image(label="Generated Image")

    # On click, the textbox and slider values are passed (in this order) to
    # generate_image, and its return value is shown in the image component.
    generate_button.click(
        generate_image,
        [prompt, num_inference_steps],
        result,
    )
# Launch the Gradio app only when this file is executed as a script
# (e.g. `python app.py`), so merely importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()