# blog-post / app.py
# Author: dami1996
# Commit f19162e: "cover generation and ui changes" (4.01 kB)
import gradio as gr
import torch
from transformers import pipeline
from diffusers import StableDiffusionPipeline
# Hugging Face Hub checkpoints used by each stage of the blog-post generator.
ARTICLE_GENERATOR_MODEL = "gpt2"
SUMMARIZER_MODEL = "Falconsai/text_summarization"
TITLE_GENERATOR_MODEL = "czearing/article-title-generator"
IMAGE_GENERATOR_MODEL = "prompthero/openjourney-v4"

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"{DEVICE = }")

# Text pipelines: the article body, its summary, and its title.
text_generator = pipeline(
    task="text-generation", model=ARTICLE_GENERATOR_MODEL, device=DEVICE
)
summarizer = pipeline(task="summarization", model=SUMMARIZER_MODEL, device=DEVICE)
title_generator = pipeline(
    task="text2text-generation", model=TITLE_GENERATOR_MODEL, device=DEVICE
)

# Cover-image diffusion pipeline. float16 weights are requested only when
# running on CUDA; on the CPU the pipeline stays in float32.
image_generator = StableDiffusionPipeline.from_pretrained(
    IMAGE_GENERATOR_MODEL,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
).to(DEVICE)
def generate_blog_post(query, article_length, title_length, summary_length):
    """Generate a blog post (title, summary, article and cover image) from a topic.

    The article is generated first and everything else is derived from it:
    the title and the summary are produced from the article text, and the
    cover image is rendered from the summary.

    Args:
        query: Topic / prompt text that seeds the article generator.
        article_length: ``max_length`` for the text-generation pipeline.
        title_length: ``max_length`` for the title generator.
        summary_length: ``max_length`` for the summarizer; its ``min_length``
            is capped at ``min(30, summary_length)``.

    Returns:
        A ``(title, summary, article, image)`` tuple, where ``image`` is the
        first image produced by the diffusion pipeline.
    """
    # Gradio documents Slider values as floats, while the generation APIs
    # take integer token counts; normalize once up front.
    article_length = int(article_length)
    title_length = int(title_length)
    summary_length = int(summary_length)

    print("Generating article.")
    article = text_generator(
        query, max_length=article_length, num_return_sequences=1
    )[0]["generated_text"]
    print(f"{article = }")

    print("Generating the title.")
    # The article can run to ~1000 tokens, which may exceed the input limit
    # of the seq2seq tokenizers; truncate rather than pass an over-long input.
    title = title_generator(
        article,
        max_length=title_length,
        num_return_sequences=1,
        truncation=True,
    )[0]["generated_text"]
    print(f"{title = }")

    print("Generating the summary.")
    summary = summarizer(
        article,
        max_length=summary_length,
        # Never ask for a minimum longer than the requested maximum.
        min_length=min(30, summary_length),
        do_sample=False,
        truncation=True,
    )[0]["summary_text"]
    print(f"{summary = }")

    print("Generating the cover image.")
    image = image_generator(
        summary, num_inference_steps=40, guidance_scale=7.5, width=512, height=512
    ).images[0]

    return title, summary, article, image
with gr.Blocks() as iface:
    # Page header.
    gr.Markdown("# Blog Post Generator")
    gr.Markdown(
        "Enter a topic, and I'll generate a blog post with a title, cover image, and optional summary!"
    )

    # Topic entry and the trigger button, each on a row of its own.
    with gr.Row():
        input_prompt = gr.Textbox(placeholder="Enter your blog post topic...", lines=2)
    with gr.Row():
        generate_button = gr.Button("Generate Blog Post", size="sm")

    # Results area: a wider left column (title + article body) beside a
    # narrower right column (cover image + summary).
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Blocks() as title_block:
                gr.Markdown("## Title")
                with gr.Accordion("Options", open=False):
                    title_length = gr.Slider(
                        label="Title Length", minimum=10, maximum=50, step=5, value=30
                    )
                title_output = gr.Textbox(label="Title")

            with gr.Blocks() as body_block:
                gr.Markdown("## Body")
                with gr.Accordion("Options", open=False):
                    article_length = gr.Slider(
                        label="Article Length",
                        minimum=100,
                        maximum=1000,
                        step=50,
                        value=500,
                    )
                article_output = gr.Textbox(label="Article", lines=10)

        with gr.Column(scale=1):
            with gr.Blocks() as image_block:
                gr.Markdown("## Cover Image")
                image_output = gr.Image(label="Cover Image")

            with gr.Blocks() as summary_block:
                gr.Markdown("## Summary")
                with gr.Accordion("Options", open=False):
                    summary_length = gr.Slider(
                        label="Summary Length",
                        minimum=30,
                        maximum=200,
                        step=10,
                        value=100,
                    )
                summary_output = gr.Textbox(label="Summary", lines=5)

    # Event wiring. The input list follows generate_blog_post's parameter
    # order, and the output list follows the order of its returned tuple.
    blog_post_inputs = [input_prompt, article_length, title_length, summary_length]
    blog_post_outputs = [title_output, summary_output, article_output, image_output]
    job = generate_button.click(
        generate_blog_post, inputs=blog_post_inputs, outputs=blog_post_outputs
    )

iface.launch()