lotrtest / app.py
mikegarts's picture
Update app.py
3d41673
raw
history blame
5.34 kB
import time
import os
import PIL
import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
from diffusers import StableDiffusionPipeline
# Hugging Face access token taken from the environment; None when unset.
READ_TOKEN = os.getenv('HF_ACCESS_TOKEN')

# Checkpoint used for the illustrations.
# Alternative that was tried: "CompVis/stable-diffusion-v1-4"
model_id = "runwayml/stable-diffusion-v1-5"

# Decide on the device first, then derive the loading options from it.
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"

# With CUDA the "fp16" revision is requested and loaded as float16;
# without it the default revision and dtype are used.
_pipe_load_kwargs = {"torch_dtype": torch.float16, "revision": "fp16"} if has_cuda else {}
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN, **_pipe_load_kwargs)
pipe.to(device)
def safety_checker(images, clip_input):
    """Pass-through stand-in for the pipeline's NSFW safety checker.

    Args:
        images: Batch of generated images to "check".
        clip_input: Input meant for a real checker; ignored here.

    Returns:
        A ``(images, flags)`` tuple: the images unchanged plus one ``False``
        flag per image, i.e. nothing is ever reported as NSFW.
    """
    # The pipeline output documents the flags as a per-image list of booleans,
    # so return a list of the right length rather than a bare ``False``.
    return images, [False] * len(images)
# Identifier handed to ``from_pretrained`` for both the tokenizer and the
# causal language model.
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)

# No model is named here, so the summarization pipeline uses its default one.
summarizer = pipeline("summarization")

# Replace the pipeline's own checker with the module-level ``safety_checker``.
pipe.safety_checker = safety_checker
#######################################################
#######################################################
def break_until_dot(txt):
    """Cut *txt* after its last full stop so it ends on a whole sentence.

    Everything following the final ``'.'`` is dropped. When *txt* contains
    no ``'.'`` at all, a ``'.'`` is simply appended to the whole text.
    """
    head, dot, _ = txt.rpartition('.')
    if not dot:
        # No full stop anywhere: keep the complete text.
        head = txt
    return head + '.'
def generate(prompt):
    """Continue *prompt* with the language model and return the text.

    Args:
        prompt: Text the generated passage should start from.

    Returns:
        The decoded first generated sequence, passed through
        ``break_until_dot`` so that it ends on a full stop.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_length=120,
        min_length=50,
        temperature=0.7,
        # Only the first sequence is decoded below, so sampling more than one
        # would just multiply the generation cost for nothing.
        num_return_sequences=1,
        do_sample=True,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return break_until_dot(decoded)
def generate_story(prompt):
    """Produce a story for *prompt* together with a short summary of it.

    Returns:
        A 3-tuple of the story text, its summary cut at the last full stop,
        and a Gradio update that makes the receiving component visible.
    """
    story_text = generate(prompt=prompt)
    summaries = summarizer(story_text, min_length=5, max_length=15)
    short_summary = break_until_dot(summaries[0]['summary_text'])
    reveal = gr.update(visible=True)
    return story_text, short_summary, reveal
def on_change_event(app_state):
    """Build the image-component update for the current generation state.

    Args:
        app_state: Dict with the keys ``'running'``, ``'img'`` and ``'step'``;
            a falsy value (empty or ``None``) is tolerated.

    Returns:
        While a generation is running and an intermediate image is stored,
        a Gradio update carrying that image and a step-specific label;
        otherwise an update that only sets a placeholder label.
    """
    has_preview = bool(app_state and app_state['running'] and app_state['img'])
    if not has_preview:
        return gr.update(label='Illustration will appear here soon')

    step = app_state['step']
    preview = app_state['img']
    label = f'Reconstructed image from the latent state at step {step}. It will get better :)'
    print(f'Updating the image:! {app_state}')
    return gr.update(value=preview, label=label)
with gr.Blocks() as demo:
    def generate_image(prompt, inference_steps, app_state):
        """Render an illustration for *prompt*, storing previews while it runs.

        Args:
            prompt: Text to illustrate; a fixed style suffix is appended.
            inference_steps: Number of inference steps (coerced to ``int``).
            app_state: Dict whose ``'running'``, ``'img'`` and ``'step'``
                entries are updated so ``on_change_event`` can show progress.

        Returns:
            A Gradio update carrying the final image and its label.
        """
        def callback(step, ts, latents):
            """Decode the current latents to a PIL image and store it in *app_state*."""
            print(f'In Callback on {step} {ts} !')
            # 0.18215 is the latent scaling constant used with the Stable
            # Diffusion VAE; it is undone before decoding.
            latents = 1 / 0.18215 * latents
            res = pipe.vae.decode(latents).sample
            # Shift/scale the decoder output and clamp it into [0, 1].
            res = (res / 2 + 0.5).clamp(0, 1)
            # Move axis 1 last (presumably channels-first -> channels-last)
            # before converting to a NumPy array.
            res = res.cpu().permute(0, 2, 3, 1).detach().numpy()
            res = pipe.numpy_to_pil(res)[0]
            app_state['img'] = res
            app_state['step'] = step
            print(f'In Callback on {app_state} Done!')

        prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
        app_state['running'] = True
        try:
            img = pipe(
                prompt,
                height=512,
                width=512,
                # A slider value may arrive as a float; the step count must be an int.
                num_inference_steps=int(inference_steps),
                callback=callback,
                callback_steps=3,
            )
        finally:
            # Always clear the progress flags, even when the pipeline raises,
            # so the state never stays stuck in "running" with a stale preview.
            app_state['running'] = False
            app_state['img'] = None
        return gr.update(value=img.images[0], label='Generated image')

    # State shared between generate_image and on_change_event.
    app_state = gr.State({'img': None,
                          'step': 0,
                          'running': False})

    # Components are created in display order.
    title = gr.Markdown('## Lord of the rings app')
    description = gr.Markdown('#### A Lord of the rings inspired app that combines text and image generation.'
                              ' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy.'
                              f' The text2img model is {model_id}. The summarization is done using distilbart.')
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")
    bt_make_text = gr.Button("Generate text")
    # Hidden at first; generate_story returns an update that makes it visible.
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU).", visible=False)
    img_description = gr.Markdown('Image generation takes some time'
                                  ' but here you can see here what is generated from the latent state of the diffuser every few steps.'
                                  ' Usually there is a significant improvement around step 12 that yields much better result')
    image = gr.Image(label='Illustration for your story', show_label=True)
    inference_steps = gr.Slider(5, 30,
                                value=20,
                                step=1,
                                visible=True,
                                label="Num inference steps (more steps makes a better image but takes more time)")

    # Event wiring.
    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)
    eventslider = gr.Slider(visible=False)
    # After page load, on_change_event is re-run periodically (every=10) to
    # refresh the image component from app_state.
    dep = demo.load(on_change_event, app_state, image, every=10)
    # NOTE(review): a change on the hidden slider cancels the polling job above
    # and starts its own; nothing visible in this file triggers it - confirm
    # it is still needed.
    eventslider.change(fn=on_change_event, inputs=[app_state], outputs=[image], every=10, cancels=[dep])
# With an access token configured the app launches with default options;
# without one, a public share link and debug mode are switched on.
_launch_kwargs = {} if READ_TOKEN else {"share": True, "debug": True}
demo.queue().launch(**_launch_kwargs)