"""Gradio app that turns a topic into a blog post.

For a user-supplied topic the app generates an article, a title for that
article, a summary of it, and a cover image prompted by the summary.
All models are loaded once at import time.
"""

import gradio as gr
import torch
from transformers import pipeline
from diffusers import StableDiffusionPipeline

ARTICLE_GENERATOR_MODEL = "gpt2"
SUMMARIZER_MODEL = "Falconsai/text_summarization"
TITLE_GENERATOR_MODEL = "czearing/article-title-generator"
IMAGE_GENERATOR_MODEL = "prompthero/openjourney-v4"

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{DEVICE = }")

text_generator = pipeline(
    "text-generation", model=ARTICLE_GENERATOR_MODEL, device=DEVICE
)
summarizer = pipeline("summarization", model=SUMMARIZER_MODEL, device=DEVICE)
title_generator = pipeline(
    "text2text-generation",
    model=TITLE_GENERATOR_MODEL,
    device=DEVICE,
)
image_generator = StableDiffusionPipeline.from_pretrained(
    IMAGE_GENERATOR_MODEL,
    # Half precision is only used on CUDA; the CPU path stays in float32.
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
)
image_generator = image_generator.to(DEVICE)


def generate_blog_post(query, article_length, title_length, summary_length):
    """Generate a blog post (title, summary, article, cover image) for a topic.

    Args:
        query: Topic / prompt text that seeds the article generator.
        article_length: ``max_length`` passed to the article generator.
        title_length: ``max_length`` passed to the title generator.
        summary_length: ``max_length`` passed to the summarizer; the
            summary's ``min_length`` is ``min(30, summary_length)``.

    Returns:
        A ``(title, summary, article, image)`` tuple, in the order expected
        by the click handler's ``outputs`` list.

    Raises:
        gr.Error: If ``query`` is empty or contains only whitespace.
    """
    if not query or not query.strip():
        # Surface a readable message in the UI instead of silently
        # generating an article from an empty prompt.
        raise gr.Error("Please enter a blog post topic.")

    # The lengths come from Gradio sliders, which may deliver floats;
    # the generation length arguments are expected to be integers.
    article_length = int(article_length)
    title_length = int(title_length)
    summary_length = int(summary_length)

    print("Generating article.")
    article_candidates = text_generator(
        query, max_length=article_length, num_return_sequences=1
    )
    article = article_candidates[0]["generated_text"]
    print(f"{article = }")

    print("Generating the title.")
    # The article may be longer than the title model's maximum input
    # length, so let the tokenizer truncate it to the model limit.
    title_candidates = title_generator(
        article,
        max_length=title_length,
        num_return_sequences=1,
        truncation=True,
    )
    title = title_candidates[0]["generated_text"]
    print(f"{title = }")

    print("Generating the summary.")
    # Same reasoning as for the title: truncate over-long articles.
    summary_candidates = summarizer(
        article,
        max_length=summary_length,
        # Never ask for a minimum longer than the requested maximum.
        min_length=min(30, summary_length),
        do_sample=False,
        truncation=True,
    )
    summary = summary_candidates[0]["summary_text"]
    print(f"{summary = }")

    print("Generating the cover image.")
    # The summary doubles as the text prompt for the image model.
    image = image_generator(
        summary, num_inference_steps=40, guidance_scale=7.5, width=512, height=512
    ).images[0]

    return title, summary, article, image


with gr.Blocks() as iface:
    gr.Markdown("# Blog Post Generator")
    gr.Markdown(
        "Enter a topic, and I'll generate a blog post with a title, cover image, and optional summary!"
    )

    with gr.Row():
        input_prompt = gr.Textbox(lines=2, placeholder="Enter your blog post topic...")

    with gr.Row():
        generate_button = gr.Button("Generate Blog Post", size="sm")

    with gr.Row():
        # Left column (scale=2): title and article body.
        with gr.Column(scale=2):
            with gr.Blocks() as title_block:
                gr.Markdown("## Title")
                with gr.Accordion("Options", open=False):
                    title_length = gr.Slider(
                        minimum=10, maximum=50, value=30, step=5, label="Title Length"
                    )
                title_output = gr.Textbox(label="Title")

            with gr.Blocks() as body_block:
                gr.Markdown("## Body")
                with gr.Accordion("Options", open=False):
                    article_length = gr.Slider(
                        minimum=100,
                        maximum=1000,
                        value=500,
                        step=50,
                        label="Article Length",
                    )
                article_output = gr.Textbox(label="Article", lines=10)

        # Right column (scale=1): cover image and summary.
        with gr.Column(scale=1):
            with gr.Blocks() as image_block:
                gr.Markdown("## Cover Image")
                image_output = gr.Image(label="Cover Image")

            with gr.Blocks() as summary_block:
                gr.Markdown("## Summary")
                with gr.Accordion("Options", open=False):
                    summary_length = gr.Slider(
                        minimum=30,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Summary Length",
                    )
                summary_output = gr.Textbox(label="Summary", lines=5)

    # The order of `inputs` matches the parameters of generate_blog_post and
    # the order of `outputs` matches the tuple it returns.
    job = generate_button.click(
        generate_blog_post,
        inputs=[
            input_prompt,
            article_length,
            title_length,
            summary_length,
        ],
        outputs=[title_output, summary_output, article_output, image_output],
    )

iface.launch()