|
import streamlit as st |
|
from diffusers import DiffusionPipeline |
|
import torch |
|
from PIL import Image |
|
from io import BytesIO |
|
|
|
|
|
@st.cache_resource
def load_pipeline(model_id="stabilityai/stable-diffusion-xl-base-1.0"):
    """Load the diffusion pipeline once and place it on the best available device.

    Args:
        model_id: Model identifier (or local path) passed to
            ``DiffusionPipeline.from_pretrained``. Defaults to the SDXL base
            1.0 checkpoint used originally.

    Returns:
        The loaded pipeline, moved to ``"cuda"`` when a CUDA device is
        available and to ``"cpu"`` otherwise.
    """
    use_cuda = torch.cuda.is_available()
    # float16 weights take half the memory of float32, which matters for a
    # model the size of SDXL on a GPU. Many CPU ops are slow or unsupported in
    # float16, so the CPU path keeps the default float32.
    dtype = torch.float16 if use_cuda else torch.float32
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    return pipe.to("cuda" if use_cuda else "cpu")
|
|
|
# Module-level pipeline shared by generate_image(). Streamlit re-executes this
# script on every interaction; because load_pipeline() is decorated with
# @st.cache_resource, reruns reuse the already-loaded object instead of
# reloading the model weights each time.
pipeline = load_pipeline()
|
|
|
def generate_image(prompt, pipe=None, **pipeline_kwargs):
    """Generate one image for *prompt*.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        pipe: Pipeline to run. Defaults to the module-level ``pipeline``.
            Passing one explicitly lets callers use a different model, or a
            stand-in object in tests.
        **pipeline_kwargs: Extra keyword arguments forwarded unchanged to the
            pipeline call (for example ``num_inference_steps`` or
            ``negative_prompt``).

    Returns:
        The first entry of the pipeline result's ``images`` list.
    """
    if pipe is None:
        pipe = pipeline
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = pipe(prompt, **pipeline_kwargs)
    return output.images[0]
|
|
|
def main():
    """Render the app: a prompt box plus a button that generates and shows an image."""
    st.title("Stable Diffusion Image Generator")

    prompt = st.text_input("Enter a prompt for image generation:")

    if not st.button("Generate Image"):
        return

    # Treat a whitespace-only prompt as empty rather than sending it to the model.
    cleaned_prompt = prompt.strip()
    if not cleaned_prompt:
        st.warning("Please enter a prompt.")
        return

    # Generation can take a long time (especially on CPU), so show feedback
    # while it runs.
    with st.spinner("Generating image..."):
        image = generate_image(cleaned_prompt)
    st.image(image, caption="Generated Image")
|
|
|
# Entry point. Launch with `streamlit run <this file>`; Streamlit executes the
# app script as __main__, so this guard fires on every rerun.
if __name__ == "__main__":
    main()
|
|