mikegarts committed
Commit bb6c48a · 1 Parent(s): c508fe5

Create app.py

Files changed (1): app.py +129 -0
app.py ADDED
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from diffusers import StableDiffusionPipeline

# Hugging Face access token used to download the model weights.
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)

model_id = "runwayml/stable-diffusion-v1-5"
# model_id = "CompVis/stable-diffusion-v1-4"

has_cuda = torch.cuda.is_available()

if has_cuda:
    # On GPU, load the fp16 weights to halve memory use.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=READ_TOKEN)
    device = "cuda"
else:
    # On CPU, load the default fp32 weights; fp16 ops are poorly supported on CPU.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN)
    device = "cpu"

pipe = pipe.to(device)

# Replace the stock NSFW filter with a pass-through stub.
def safety_checker(images, clip_input):
    return images, [False] * len(images)

pipe.safety_checker = safety_checker
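
# Note (added for clarity): the stock diffusers safety checker returns
# (images, has_nsfw_concepts), where the second value is a per-image list of
# flags; the stub above keeps that shape while never flagging anything, so
# every generated image passes through unfiltered.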

SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)

summarizer = pipeline("summarization")


def break_until_dot(txt):
    """Truncate the text after its last complete sentence."""
    return txt.rsplit('.', 1)[0] + '.'
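
# Illustrative example (added, not in the commit):
#   break_until_dot('He rode. The night fell. And th') -> 'He rode. The night fell.'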


def generate(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        max_length=180,
        temperature=0.7,
        do_sample=True,
        num_return_sequences=1,  # only one sequence is consumed below
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return break_until_dot(decoded)
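
# Note (added): decoder-only generation returns the prompt concatenated with
# the continuation, so the story returned here starts with the user's prompt.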


def generate_story(prompt):
    story = generate(prompt=prompt)
    summary = summarizer(story, min_length=5, max_length=15)[0]['summary_text']
    summary = break_until_dot(summary)
    # Also reveal the initially hidden "Generate an image" button.
    return story, summary, gr.update(visible=True)


def on_change_event(app_state=None):
    # Polled by the UI; returns the latest intermediate image while a
    # generation is running, otherwise leaves the component unchanged.
    if app_state and app_state['running']:
        img = app_state['img']
        step = app_state['step']
        label = f'Reconstructed image from the latent state at step {step}'
        return gr.update(value=img, label=label)
    return None
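
# Design note (added): generate_image runs synchronously in one Gradio event
# handler while the diffusion callback writes intermediate frames into the
# shared gr.State dict; a second event polls that dict (every=1) to refresh
# the Image component. How reliably updates render during a long-running
# handler depends on the Gradio version and its queueing behavior.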


with gr.Blocks() as demo:

    def generate_image(prompt, inference_steps, app_state):
        app_state['running'] = True

        def callback(step, ts, latents):
            # Decode the intermediate latents back to image space so the UI can
            # show progress (0.18215 is the Stable Diffusion VAE scaling factor).
            latents = 1 / 0.18215 * latents
            res = pipe.vae.decode(latents).sample
            res = (res / 2 + 0.5).clamp(0, 1)  # map from [-1, 1] to [0, 1]
            res = res.cpu().permute(0, 2, 3, 1).detach().numpy()
            res = pipe.numpy_to_pil(res)[0]
            app_state['img'] = res
            app_state['step'] = step

        prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
        img = pipe(prompt, height=512, width=512, num_inference_steps=inference_steps,
                   callback=callback, callback_steps=5)
        app_state['running'] = False
        return gr.update(value=img.images[0], label='Generated image')

    app_state = gr.State({'img': None,
                          'step': 0,
                          'running': False})
    title = gr.Markdown('## Lord of the Rings app')
    description = gr.Markdown('#### A Lord of the Rings inspired app that combines text and image generation.'
                              ' The language modeling is done by fine-tuning distilgpt2 on the LOTR trilogy.'
                              f' The text2img model is {model_id}. The summarization is done using distilbart.')
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")

    bt_make_text = gr.Button("Generate text")
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU)", visible=False)

    img_description = gr.Markdown('Image generation takes some time,'
                                  ' but here you can see what is reconstructed from the latent state'
                                  ' of the diffuser every few steps.'
                                  ' Usually there is a significant improvement around step 15'
                                  ' that yields a much better result.')
    image = gr.Image(label='Illustration for your story', shape=(512, 512), show_label=True)

    inference_steps = gr.Slider(5, 30,
                                value=15,
                                step=1,
                                visible=True,
                                label="Num inference steps (more steps make a better image but take more time)")

    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)

    # Poll the shared state once a second to stream intermediate images to the UI.
    inference_steps.change(fn=on_change_event, inputs=[app_state], outputs=[image], every=1)


if READ_TOKEN:
    demo.queue().launch()
else:
    # No token configured (e.g., running outside the Space): expose a public
    # share link and enable debug output.
    demo.queue().launch(share=True, debug=True)
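
# Note (added): queue() enables Gradio's request queue, which long-running
# handlers such as generate_image rely on; events scheduled with every=1
# also require the queue to be enabled in Gradio versions of this era.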