import streamlit as st from diffusers import DiffusionPipeline import torch from PIL import Image import io import gc # Import garbage collection # Заголовок приложения st.title("Генератор изображений с LCM Dreamshaper") st.write("Используйте эту модель для быстрой генерации изображений на CPU") # Создаем область для настроек with st.sidebar: st.header("Настройки") prompt = st.text_area("Введите ваш запрос:", "hoaxx kitty", height=100) num_inference_steps = st.slider( "Количество шагов инференса:", min_value=1, max_value=50, value=5, help="Больше шагов = выше качество, но медленнее" ) guidance_scale = st.slider( "Guidance Scale:", min_value=1.0, max_value=15.0, value=8.0, step=0.5, help="Насколько строго модель следует промпту" ) lcm_origin_steps = st.slider( "LCM Origin Steps:", min_value=1, max_value=50, value=35 ) generate_button = st.button("Сгенерировать изображение") # Загружаем модель при первом запуске @st.cache_resource def load_model(): pipe = DiffusionPipeline.from_pretrained( "SimianLuo/LCM_Dreamshaper_v7", torch_dtype=torch.float32 ) pipe.to("cpu") pipe.enable_attention_slicing() pipe.safety_checker = None return pipe # Функция для генерации изображения def generate_image(pipe, prompt, steps, guidance, lcm_steps): try: with torch.inference_mode(): images = pipe( prompt=prompt, num_inference_steps=steps, guidance_scale=guidance, lcm_origin_steps=lcm_steps, output_type="pil" ).images return images[0] except Exception as e: st.error(f"Error generating image: {e}") return None # Загружаем модель pipe = load_model() # Отображаем прогресс и результат if generate_button: with st.spinner("Генерация изображения..."): # Создаем место для вывода изображения result_container = st.empty() # Генерируем изображение image = generate_image( pipe, prompt, num_inference_steps, guidance_scale, lcm_origin_steps ) # Показываем результат if image: # Only display if image generation was successful result_container.image(image, caption=f"Результат для: {prompt}", use_container_width=True) # Use use_container_width # Предлагаем скачать buf = io.BytesIO() image.save(buf, format="PNG") byte_im = buf.getvalue() st.download_button( label="Скачать изображение", data=byte_im, file_name="generated_image.png", mime="image/png" ) gc.collect() # Garbage collection after image generation # Инструкции по использованию if not generate_button: st.write("👈 Настройте параметры в боковой панели и нажмите 'Сгенерировать изображение'")