dami1996 commited on
Commit
f19162e
·
1 Parent(s): d68ded6

cover generation and ui changes

Browse files
Files changed (1) hide show
  1. app.py +84 -27
app.py CHANGED
@@ -1,67 +1,124 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import pipeline
 
4
 
5
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
6
 
7
  text_generator = pipeline(
8
- "text-generation", model="openchat/openchat-3.5-0106", device=device
9
- )
10
- summarizer = pipeline(
11
- "summarization", model="sshleifer/distilbart-cnn-12-6", device=device
12
  )
 
13
  title_generator = pipeline(
14
  "text2text-generation",
15
- model="fabiochiu/t5-small-medium-title-generation",
16
- device=device,
 
 
 
 
17
  )
 
18
 
19
 
20
- def generate_blog_post(query):
21
  print("Generating article.")
22
- article = text_generator(query, max_length=500, num_return_sequences=1)[0][
23
- "generated_text"
24
- ]
25
  print(f"{article = }")
26
 
27
  print("Generating the title.")
28
- title = title_generator(article, max_length=30, num_return_sequences=1)[0][
29
- "generated_text"
30
- ]
31
  print(f"{title = }")
32
 
33
  print("Generating the summary.")
34
- summary = summarizer(article, max_length=100, min_length=30, do_sample=False)[0][
35
- "summary_text"
36
- ]
 
 
 
37
  print(f"{summary = }")
38
 
39
- return title, summary, article
 
 
 
 
 
40
 
41
 
42
  with gr.Blocks() as iface:
43
  gr.Markdown("# Blog Post Generator")
44
  gr.Markdown(
45
- "Enter a topic, and I'll generate a blog post with a title, cover image, and summary!"
46
  )
47
 
48
  with gr.Row():
49
- topic_input = gr.Textbox(lines=2, placeholder="Enter your blog post topic...")
50
 
51
- generate_button = gr.Button("Generate Blog Post", size="sm")
 
52
 
53
  with gr.Row():
54
  with gr.Column(scale=2):
55
- title_output = gr.Textbox(label="Title")
56
- article_output = gr.Textbox(label="Article", lines=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  with gr.Column(scale=1):
59
- summary_output = gr.Textbox(label="Summary", lines=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- generate_button.click(
62
  generate_blog_post,
63
- inputs=topic_input,
64
- outputs=[title_output, summary_output, article_output],
 
 
 
 
 
65
  )
66
 
67
  iface.launch()
 
1
  import gradio as gr
2
  import torch
3
  from transformers import pipeline
4
+ from diffusers import StableDiffusionPipeline
5
 
6
+ ARTICLE_GENERATOR_MODEL = "gpt2"
7
+ SUMMARIZER_MODEL = "Falconsai/text_summarization"
8
+ TITLE_GENERATOR_MODEL = "czearing/article-title-generator"
9
+ IMAGE_GENERATOR_MODEL = "prompthero/openjourney-v4"
10
+
11
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
12
+ print(f"{DEVICE = }")
13
 
14
  text_generator = pipeline(
15
+ "text-generation", model=ARTICLE_GENERATOR_MODEL, device=DEVICE
 
 
 
16
  )
17
+ summarizer = pipeline("summarization", model=SUMMARIZER_MODEL, device=DEVICE)
18
  title_generator = pipeline(
19
  "text2text-generation",
20
+ model=TITLE_GENERATOR_MODEL,
21
+ device=DEVICE,
22
+ )
23
+ image_generator = StableDiffusionPipeline.from_pretrained(
24
+ IMAGE_GENERATOR_MODEL,
25
+ torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
26
  )
27
+ image_generator = image_generator.to(DEVICE)
28
 
29
 
30
+ def generate_blog_post(query, article_length, title_length, summary_length):
31
  print("Generating article.")
32
+ article = text_generator(query, max_length=article_length, num_return_sequences=1)[
33
+ 0
34
+ ]["generated_text"]
35
  print(f"{article = }")
36
 
37
  print("Generating the title.")
38
+ title = title_generator(article, max_length=title_length, num_return_sequences=1)[
39
+ 0
40
+ ]["generated_text"]
41
  print(f"{title = }")
42
 
43
  print("Generating the summary.")
44
+ summary = summarizer(
45
+ article,
46
+ max_length=summary_length,
47
+ min_length=min(30, summary_length),
48
+ do_sample=False,
49
+ )[0]["summary_text"]
50
  print(f"{summary = }")
51
 
52
+ print("Generating the cover image.")
53
+ image = image_generator(
54
+ summary, num_inference_steps=40, guidance_scale=7.5, width=512, height=512
55
+ ).images[0]
56
+
57
+ return title, summary, article, image
58
 
59
 
60
  with gr.Blocks() as iface:
61
  gr.Markdown("# Blog Post Generator")
62
  gr.Markdown(
63
+ "Enter a topic, and I'll generate a blog post with a title, cover image, and optional summary!"
64
  )
65
 
66
  with gr.Row():
67
+ input_prompt = gr.Textbox(lines=2, placeholder="Enter your blog post topic...")
68
 
69
+ with gr.Row():
70
+ generate_button = gr.Button("Generate Blog Post", size="sm")
71
 
72
  with gr.Row():
73
  with gr.Column(scale=2):
74
+ with gr.Blocks() as title_block:
75
+ gr.Markdown("## Title")
76
+
77
+ with gr.Accordion("Options", open=False):
78
+ title_length = gr.Slider(
79
+ minimum=10, maximum=50, value=30, step=5, label="Title Length"
80
+ )
81
+ title_output = gr.Textbox(label="Title")
82
+
83
+ with gr.Blocks() as body_block:
84
+ gr.Markdown("## Body")
85
+
86
+ with gr.Accordion("Options", open=False):
87
+ article_length = gr.Slider(
88
+ minimum=100,
89
+ maximum=1000,
90
+ value=500,
91
+ step=50,
92
+ label="Article Length",
93
+ )
94
+ article_output = gr.Textbox(label="Article", lines=10)
95
 
96
  with gr.Column(scale=1):
97
+ with gr.Blocks() as image_block:
98
+ gr.Markdown("## Cover Image")
99
+ image_output = gr.Image(label="Cover Image")
100
+
101
+ with gr.Blocks() as summary_block:
102
+ gr.Markdown("## Summary")
103
+ with gr.Accordion("Options", open=False):
104
+ summary_length = gr.Slider(
105
+ minimum=30,
106
+ maximum=200,
107
+ value=100,
108
+ step=10,
109
+ label="Summary Length",
110
+ )
111
+ summary_output = gr.Textbox(label="Summary", lines=5)
112
 
113
+ job = generate_button.click(
114
  generate_blog_post,
115
+ inputs=[
116
+ input_prompt,
117
+ article_length,
118
+ title_length,
119
+ summary_length,
120
+ ],
121
+ outputs=[title_output, summary_output, article_output, image_output],
122
  )
123
 
124
  iface.launch()