import os
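# install transformers from source; presumably ImageGPT support had not yet
# landed in a stable release when this Space was written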
os.system('pip install git+https://github.com/huggingface/transformers --upgrade')
import gradio as gr
from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalLM
import torch
import numpy as np
import requests
from PIL import Image
feature_extractor = ImageGPTFeatureExtractor.from_pretrained("openai/imagegpt-medium")
model = ImageGPTForCausalLM.from_pretrained("openai/imagegpt-medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
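# Background: ImageGPT models an image as a sequence of tokens. The feature
# extractor resizes each image to 32x32 and maps every RGB pixel to the nearest
# of 512 color clusters, giving 1024 tokens per image; one extra token id
# (vocab_size - 1) serves as the start-of-sequence token.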
# load image examples
urls = ['https://avatars.githubusercontent.com/u/326577?v=4',
        'https://upload.wikimedia.org/wikipedia/commons/thumb/6/6e/Football_%28soccer_ball%29.svg/1200px-Football_%28soccer_ball%29.svg.png',
        'https://i.imgflip.com/4/4t0m5.jpg',
        ]
for idx, url in enumerate(urls):
    image = Image.open(requests.get(url, stream=True).raw)
    image.save(f"image_{idx}.png")
def process_image(image):
    # encode the same image 8 times, shape (8, 1024): each copy is resized
    # to 32x32 and every pixel is mapped to one of 512 color-cluster tokens
    batch_size = 8
    encoding = feature_extractor([image for _ in range(batch_size)], return_tensors="pt")

    # create primers: keep only the top n_px_crop rows of each image as
    # conditioning tokens; the model completes the remaining rows
    samples = encoding.pixel_values.numpy()
    n_px = feature_extractor.size  # 32 for imagegpt-medium
    clusters = feature_extractor.clusters  # (512, 3) RGB centroids in [-1, 1]
    n_px_crop = 16
    primers = samples.reshape(-1, n_px * n_px)[:, :n_px_crop * n_px]

    # prepend the start-of-sequence token (vocab_size - 1) and sample the
    # rest of the pixels (no beam search, top-k sampling)
    context = np.concatenate((np.full((batch_size, 1), model.config.vocab_size - 1), primers), axis=1)
    context = torch.tensor(context).to(device)
    output = model.generate(input_ids=context, max_length=n_px * n_px + 1, temperature=1.0, do_sample=True, top_k=40)

    # decode back to images: drop the SOS token, then map each cluster token
    # to its RGB centroid and rescale from [-1, 1] to [0, 255]
    samples = output[:, 1:].cpu().detach().numpy()
    samples_img = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples]

    # stack the 8 completions into a 2x4 grid
    row1 = np.hstack(samples_img[:4])
    row2 = np.hstack(samples_img[4:])
    result = np.vstack([row1, row2])

    # return as a PIL Image
    completion = Image.fromarray(result)
    return completion
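# Quick local sanity check (hypothetical snippet, not used by the Space):
#
#   img = Image.open("image_0.png")
#   process_image(img).save("completion_0.png")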
title = "Interactive demo: ImageGPT"
description = "Demo for OpenAI's ImageGPT: Generative Pretraining from Pixels. To use it, simply upload an image or pick one of the example images below and click 'submit'. Results will show up after a short wait."
article = "<p style='text-align: center'><a href='https://proceedings.mlr.press/v119/chen20s.html'>Generative Pretraining from Pixels</a> | <a href='https://openai.com/blog/image-gpt/'>Official blog</a></p>"
examples =[f"image_{idx}.png" for idx in range(len(urls))]
iface = gr.Interface(fn=process_image,
                     inputs=gr.inputs.Image(type="pil"),
                     outputs=gr.outputs.Image(type="pil"),
                     title=title,
                     description=description,
                     article=article,
                     examples=examples,
                     enable_queue=True)
iface.launch(debug=True)
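# Note: this uses the pre-3.0 Gradio API. On newer Gradio versions the
# equivalents would be gr.Image(type="pil") for both inputs and outputs,
# and iface.queue() in place of the enable_queue argument.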