File size: 4,008 Bytes
76e2719
 
11a93c7
f19162e
76e2719
f19162e
 
 
 
 
 
 
76e2719
11a93c7
f19162e
11a93c7
f19162e
11a93c7
 
f19162e
 
 
 
 
 
11a93c7
f19162e
76e2719
11a93c7
f19162e
11a93c7
f19162e
 
 
11a93c7
 
 
f19162e
 
 
11a93c7
 
 
f19162e
 
 
 
 
 
11a93c7
 
f19162e
 
 
 
 
 
11a93c7
 
 
 
 
f19162e
11a93c7
 
 
f19162e
11a93c7
f19162e
 
11a93c7
 
 
f19162e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11a93c7
 
f19162e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11a93c7
f19162e
11a93c7
f19162e
 
 
 
 
 
 
11a93c7
76e2719
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import torch
from transformers import pipeline
from diffusers import StableDiffusionPipeline

# Hugging Face Hub model identifiers, one per generation stage.
ARTICLE_GENERATOR_MODEL = "gpt2"  # article body (text-generation)
SUMMARIZER_MODEL = "Falconsai/text_summarization"  # summary (summarization)
TITLE_GENERATOR_MODEL = "czearing/article-title-generator"  # title (text2text-generation)
IMAGE_GENERATOR_MODEL = "prompthero/openjourney-v4"  # cover image (Stable Diffusion)

# Prefer the GPU whenever torch can see one; otherwise run everything on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"{DEVICE = }")

# Text pipelines are created once at import time and shared by every request.
text_generator = pipeline("text-generation", model=ARTICLE_GENERATOR_MODEL, device=DEVICE)
summarizer = pipeline("summarization", model=SUMMARIZER_MODEL, device=DEVICE)
title_generator = pipeline("text2text-generation", model=TITLE_GENERATOR_MODEL, device=DEVICE)

# The diffusion weights are loaded in float16 only when running on CUDA;
# on CPU they stay in float32. The pipeline is then moved to DEVICE.
image_generator = StableDiffusionPipeline.from_pretrained(
    IMAGE_GENERATOR_MODEL,
    torch_dtype=(torch.float16 if DEVICE == "cuda" else torch.float32),
).to(DEVICE)


def generate_blog_post(query, article_length, title_length, summary_length):
    """Generate a complete blog post (title, summary, body, cover image) for a topic.

    The article body is generated first from ``query``. The title and the
    summary are both derived from that article, and the summary is then used
    as the prompt for the cover image.

    Args:
        query: Topic entered by the user; used as the text-generation prompt.
        article_length: ``max_length`` passed to the article generator.
        title_length: ``max_length`` passed to the title generator.
        summary_length: ``max_length`` passed to the summarizer; its
            ``min_length`` is ``min(30, summary_length)``.

    Returns:
        ``(title, summary, article, image)``, in the order of the Gradio
        outputs wired to this function.

    Raises:
        gr.Error: If ``query`` is empty or whitespace-only. Gradio shows the
            message to the user instead of running the slow model chain.
    """
    # Without a topic the article is not conditioned on anything, and every
    # downstream stage (title, summary, diffusion image) would be wasted work.
    if query is None or not query.strip():
        raise gr.Error("Please enter a blog post topic.")

    # Slider values are numeric but not guaranteed to arrive as ints across
    # Gradio versions; the generation length arguments are token counts.
    article_length = int(article_length)
    title_length = int(title_length)
    summary_length = int(summary_length)

    print("Generating article.")
    article = text_generator(
        query, max_length=article_length, num_return_sequences=1
    )[0]["generated_text"]
    print(f"{article = }")

    # The article can be up to 1000 tokens (the slider maximum), which may
    # exceed the input limit of the title/summary tokenizers (presumably 512
    # for these T5-style models -- confirm). Let the tokenizer truncate to the
    # model's maximum input length instead of feeding an over-long sequence.
    print("Generating the title.")
    title = title_generator(
        article,
        max_length=title_length,
        num_return_sequences=1,
        truncation=True,
    )[0]["generated_text"]
    print(f"{title = }")

    print("Generating the summary.")
    summary = summarizer(
        article,
        max_length=summary_length,
        min_length=min(30, summary_length),
        do_sample=False,
        truncation=True,
    )[0]["summary_text"]
    print(f"{summary = }")

    print("Generating the cover image.")
    image = image_generator(
        summary, num_inference_steps=40, guidance_scale=7.5, width=512, height=512
    ).images[0]

    return title, summary, article, image


# Gradio UI: a topic box and a generate button on top, then two columns --
# title and article body on the left (scale 2), cover image and summary on
# the right (scale 1). Each text section keeps its length slider inside a
# collapsed "Options" accordion.
with gr.Blocks() as iface:
    gr.Markdown("# Blog Post Generator")
    gr.Markdown(
        "Enter a topic, and I'll generate a blog post with a title, cover image, and summary!"
    )

    with gr.Row():
        input_prompt = gr.Textbox(lines=2, placeholder="Enter your blog post topic...")

    with gr.Row():
        generate_button = gr.Button("Generate Blog Post", size="sm")

    # Sections are placed directly in their column. A nested gr.Blocks() here
    # would only flatten its children into the column anyway, while building
    # an extra (unused) Blocks app object.
    with gr.Row():
        with gr.Column(scale=2):
            # --- Title section ---
            gr.Markdown("## Title")
            with gr.Accordion("Options", open=False):
                title_length = gr.Slider(
                    minimum=10, maximum=50, value=30, step=5, label="Title Length"
                )
            title_output = gr.Textbox(label="Title")

            # --- Body section ---
            gr.Markdown("## Body")
            with gr.Accordion("Options", open=False):
                article_length = gr.Slider(
                    minimum=100,
                    maximum=1000,
                    value=500,
                    step=50,
                    label="Article Length",
                )
            article_output = gr.Textbox(label="Article", lines=10)

        with gr.Column(scale=1):
            # --- Cover image section ---
            gr.Markdown("## Cover Image")
            image_output = gr.Image(label="Cover Image")

            # --- Summary section ---
            gr.Markdown("## Summary")
            with gr.Accordion("Options", open=False):
                summary_length = gr.Slider(
                    minimum=30,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Summary Length",
                )
            summary_output = gr.Textbox(label="Summary", lines=5)

    # Input order must match generate_blog_post's parameters; output order
    # must match its (title, summary, article, image) return tuple.
    generate_button.click(
        generate_blog_post,
        inputs=[
            input_prompt,
            article_length,
            title_length,
            summary_length,
        ],
        outputs=[title_output, summary_output, article_output, image_output],
    )

iface.launch()