Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoModelForSequenceClassification | |
from transformers import AutoTokenizer | |
from transformers import pipeline | |
import torch | |
import os | |
import numpy as np | |
from matplotlib import pyplot as plt | |
from PIL import Image | |
from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample, one_hot_from_names, one_hot_from_int | |
# Model identifiers and settings for the text -> image pipeline.
config = dict(
    model_name="smangrul/Multimodal-Challenge",  # fine-tuned text classifier checkpoint
    base_model_name="distilbert-base-uncased",  # tokenizer checkpoint
    image_gen_model="biggan-deep-512",  # BigGAN weights
    max_length=20,
    freeze_text_model=True,
    freeze_image_gen_model=True,
    text_embedding_dim=768,
    class_embedding_dim=128,
)
# Truncation value passed to BigGAN sampling by text_to_image.
truncation = 0.4
# Hard-wired to CPU; flip to True to run both models on CUDA.
is_gpu = False
device = torch.device('cuda' if is_gpu else 'cpu')
print(device)
# Text classifier whose argmax prediction text_to_image uses as the BigGAN
# class index. An optional access token is read from the
# 'huggingface-api-token' environment variable (None when unset).
model = AutoModelForSequenceClassification.from_pretrained(
    config["model_name"],
    use_auth_token=os.environ.get('huggingface-api-token'),
)
tokenizer = AutoTokenizer.from_pretrained(config["base_model_name"])
# nn.Module.to() and .eval() both return the module itself, so they chain.
model.to(device).eval()

gan_model = BigGAN.from_pretrained(config["image_gen_model"])
gan_model.to(device).eval()
print("Models were loaded")
def generate_image(dense_class_vector=None, int_index=None, noise_seed_vector=None, truncation=0.4):
    """Generate a single BigGAN image and return it as a uint8 numpy array.

    The class is selected by ``int_index`` when given, otherwise by
    ``dense_class_vector``.

    Args:
        dense_class_vector: Pre-computed class embedding (numpy array or
            tensor) with ``config["class_embedding_dim"]`` elements. Used only
            when ``int_index`` is None.
        int_index: Class index turned into a one-hot vector by
            ``one_hot_from_int`` (presumably an ImageNet class id in
            [0, 1000) - confirm against the BigGAN package).
        noise_seed_vector: Optional tensor whose element sum seeds the noise
            sampler, so the same input yields the same image. When None the
            noise is random.
        truncation: Truncation value used both for noise sampling and by the
            generator.

    Returns:
        A ``(H, W, C)`` uint8 numpy array for the first (only) batch item.

    Raises:
        ValueError: if neither ``int_index`` nor ``dense_class_vector`` is given.
    """
    if int_index is None and dense_class_vector is None:
        raise ValueError("Either int_index or dense_class_vector must be provided")

    seed = None
    if noise_seed_vector is not None:
        # numpy's RandomState only accepts seeds in [0, 2**32 - 1].
        seed = int(noise_seed_vector.sum().item()) % (2 ** 32)
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    # Move to the models' device; leaving it on CPU breaks the GPU path.
    noise_vector = torch.from_numpy(noise_vector).to(device)

    with torch.no_grad():
        if int_index is not None:
            class_vector = torch.from_numpy(one_hot_from_int([int_index], batch_size=1)).to(device)
            dense_class_vector = gan_model.embeddings(class_vector)
        else:
            # Match the noise dtype/device (a float64 numpy input would
            # otherwise not match the generator's float32 weights).
            dense_class_vector = torch.as_tensor(
                dense_class_vector, dtype=noise_vector.dtype, device=device
            ).reshape(1, config["class_embedding_dim"])
        input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)
        output = gan_model.generator(input_vector, truncation)

    # (N, C, H, W) -> (N, H, W, C)
    image = output.cpu().numpy().transpose((0, 2, 3, 1))
    # Rescale assuming generator output in [-1, 1] (TODO confirm), then clamp
    # into the uint8 range before casting.
    image = ((image + 1.0) / 2.0) * 256
    image.clip(0, 255, out=image)
    return np.asarray(np.uint8(image[0]), dtype=np.uint8)
def print_image(numpy_array):
    """Show an image array in a matplotlib window.

    Args:
        numpy_array: Image data (e.g. the uint8 array returned by
            ``generate_image``) that ``PIL.Image.fromarray`` can convert.
    """
    pil_image = Image.fromarray(numpy_array)
    plt.imshow(pil_image)
    plt.show()
def text_to_image(text):
    """Turn a text prompt into a generated image.

    The classifier ``model`` predicts a class index from the prompt. That
    index is passed to ``generate_image``, and the prompt's token ids seed
    the noise, so the same prompt always produces the same image.

    Args:
        text: The prompt string.

    Returns:
        A ``PIL.Image.Image`` built from the generated array.
    """
    # truncation=True caps the sequence at the tokenizer's maximum length, so
    # very long prompts no longer overflow the classifier's position embeddings.
    tokens = tokenizer.encode(
        text, add_special_tokens=True, truncation=True, return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        lm_output = model(tokens, return_dict=True)
    # argmax over a single row yields a 0-d tensor; .item() gives a plain int.
    pred_int_index = int(torch.argmax(lm_output.logits[0], dim=-1).item())
    print(pred_int_index)
    numpy_image = generate_image(int_index=pred_int_index,
                                 truncation=truncation,
                                 noise_seed_vector=tokens)
    return Image.fromarray(numpy_image)
# Sample prompts offered as clickable examples in the Gradio UI
# (spelling fixed: "resoltuion", "magzine", "trouble some", "blur image").
examples = [
    "a high resolution photo of a pizza from a famous food magazine.",
    "this is a photo of my pet golden retriever.",
    "this is a photo of a troublesome street cat.",
    "a blurry image of coral reef.",
    "a yellow taxi cab commonly found in USA.",
    "Once upon a time, there was a black ship full of pirates.",
    "a photo of a large castle.",
    "a sketch of an old Church",
]
if __name__ == '__main__':
    # Current Gradio releases removed the gr.inputs / gr.outputs namespaces
    # and the "auto" image type, so use the top-level components instead.
    # text_to_image returns a PIL image, hence type="pil". The obsolete
    # `verbose` argument and legacy "huggingface" theme string are dropped.
    interface = gr.Interface(
        fn=text_to_image,
        inputs=gr.Textbox(
            placeholder="Enter the text to generate an image",
            label="Text query",
            lines=1,
        ),
        outputs=gr.Image(type="pil", label="Generated Image"),
        examples=examples,
        title="Generate Image from Text",
        description="",
    )
    interface.launch()