File size: 4,008 Bytes
76e2719 11a93c7 f19162e 76e2719 f19162e 76e2719 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 76e2719 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 f19162e 11a93c7 76e2719 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import gradio as gr
import torch
from transformers import pipeline
from diffusers import StableDiffusionPipeline
# Hugging Face model identifiers, one per stage of the blog-post pipeline.
ARTICLE_GENERATOR_MODEL = "gpt2"
SUMMARIZER_MODEL = "Falconsai/text_summarization"
TITLE_GENERATOR_MODEL = "czearing/article-title-generator"
IMAGE_GENERATOR_MODEL = "prompthero/openjourney-v4"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{DEVICE = }")

# Every model is instantiated once, at module load, and reused by each request.
text_generator = pipeline(
    "text-generation",
    model=ARTICLE_GENERATOR_MODEL,
    device=DEVICE,
)
summarizer = pipeline(
    "summarization",
    model=SUMMARIZER_MODEL,
    device=DEVICE,
)
title_generator = pipeline(
    "text2text-generation", model=TITLE_GENERATOR_MODEL, device=DEVICE
)

# Half precision is only requested on CUDA; the CPU path uses float32.
image_generator = StableDiffusionPipeline.from_pretrained(
    IMAGE_GENERATOR_MODEL,
    torch_dtype=(torch.float16 if DEVICE == "cuda" else torch.float32),
).to(DEVICE)
def generate_blog_post(query, article_length, title_length, summary_length):
    """Generate a blog post (title, summary, article and cover image) for a topic.

    The stages run in sequence and each one feeds the next: the article is
    generated from ``query``, the title and the summary are both derived from
    the article, and the cover image is rendered from the summary.

    Args:
        query: Topic text used as the prompt for the article generator.
        article_length: ``max_length`` handed to the article generator.
        title_length: ``max_length`` handed to the title generator.
        summary_length: ``max_length`` handed to the summarizer; the minimum
            summary length is ``min(30, summary_length)``.

    Returns:
        A ``(title, summary, article, image)`` tuple, where ``image`` is the
        first image returned by the image pipeline.
    """
    # Generation lengths must be integers; coerce in case the caller (e.g. a UI
    # slider or an API client) delivers them as floats.
    article_length = int(article_length)
    title_length = int(title_length)
    summary_length = int(summary_length)

    print("Generating article.")
    generated = text_generator(
        query, max_length=article_length, num_return_sequences=1
    )
    article = generated[0]["generated_text"]
    print(f"{article = }")

    print("Generating the title.")
    # The article may be longer than the title model's tokenizer limit, so ask
    # the pipeline to truncate the input instead of feeding it at full length.
    titles = title_generator(
        article,
        max_length=title_length,
        num_return_sequences=1,
        truncation=True,
    )
    title = titles[0]["generated_text"]
    print(f"{title = }")

    print("Generating the summary.")
    # Same over-length concern as for the title model; min_length is capped so
    # it never exceeds the requested maximum.
    summaries = summarizer(
        article,
        max_length=summary_length,
        min_length=min(30, summary_length),
        do_sample=False,
        truncation=True,
    )
    summary = summaries[0]["summary_text"]
    print(f"{summary = }")

    print("Generating the cover image.")
    # The summary text doubles as the image prompt.
    image = image_generator(
        summary, num_inference_steps=40, guidance_scale=7.5, width=512, height=512
    ).images[0]

    return title, summary, article, image
# Page layout: a header, the topic box, the trigger button, then a two-column
# results area (title + body on the left, cover image + summary on the right).
with gr.Blocks() as iface:
    gr.Markdown("# Blog Post Generator")
    gr.Markdown(
        "Enter a topic, and I'll generate a blog post with a title, cover image, and optional summary!"
    )

    with gr.Row():
        input_prompt = gr.Textbox(
            lines=2,
            placeholder="Enter your blog post topic...",
        )

    with gr.Row():
        generate_button = gr.Button("Generate Blog Post", size="sm")

    with gr.Row():
        # Left column (wider): title and article body.
        with gr.Column(scale=2):
            with gr.Blocks() as title_block:
                gr.Markdown("## Title")
                # Length control is tucked into a collapsed accordion.
                with gr.Accordion("Options", open=False):
                    title_length = gr.Slider(
                        label="Title Length",
                        minimum=10,
                        maximum=50,
                        value=30,
                        step=5,
                    )
                title_output = gr.Textbox(label="Title")

            with gr.Blocks() as body_block:
                gr.Markdown("## Body")
                with gr.Accordion("Options", open=False):
                    article_length = gr.Slider(
                        label="Article Length",
                        minimum=100,
                        maximum=1000,
                        value=500,
                        step=50,
                    )
                article_output = gr.Textbox(label="Article", lines=10)

        # Right column (narrower): cover image and summary.
        with gr.Column(scale=1):
            with gr.Blocks() as image_block:
                gr.Markdown("## Cover Image")
                image_output = gr.Image(label="Cover Image")

            with gr.Blocks() as summary_block:
                gr.Markdown("## Summary")
                with gr.Accordion("Options", open=False):
                    summary_length = gr.Slider(
                        label="Summary Length",
                        minimum=30,
                        maximum=200,
                        value=100,
                        step=10,
                    )
                summary_output = gr.Textbox(label="Summary", lines=5)

    # Input order matches generate_blog_post's parameters; output order matches
    # the tuple it returns.
    job = generate_button.click(
        generate_blog_post,
        inputs=[input_prompt, article_length, title_length, summary_length],
        outputs=[title_output, summary_output, article_output, image_output],
    )

iface.launch()
|