Spaces:
Runtime error
Runtime error
Commit
·
2d0cabb
1
Parent(s):
2573f42
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import gradio as gr
|
3 |
+
import torch
|
4 |
+
from transformers import AutoModelForSequenceClassification
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
from transformers import pipeline
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
from matplotlib import pyplot as plt
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
from pytorch_pretrained_biggan import BigGAN, truncated_noise_sample, one_hot_from_names, one_hot_from_int
|
15 |
+
|
16 |
+
# Settings shared by the script. Only "model_name", "base_model_name" and
# "image_gen_model" are read in the code visible here; the remaining entries
# presumably describe how the classifier was trained -- confirm before relying
# on them.
config = dict(
    model_name="smangrul/Multimodal-Challenge",
    base_model_name="distilbert-base-uncased",
    image_gen_model="biggan-deep-128",
    max_length=20,
    freeze_text_model=True,
    freeze_image_gen_model=True,
    text_embedding_dim=768,
    class_embedding_dim=128,
)

# Truncation value handed to the noise sampler and the generator.
truncation = 0.4
|
27 |
+
|
28 |
+
# Flip to True to run both models on CUDA.
is_gpu = False
device = torch.device('cuda' if is_gpu else 'cpu')
print(device)

# Text classifier: weights come from config["model_name"], the tokenizer from
# config["base_model_name"]. The auth token is read from the environment and is
# None when the variable is not set.
model = AutoModelForSequenceClassification.from_pretrained(
    config["model_name"],
    use_auth_token=os.environ.get('huggingface-api-token'),
)
tokenizer = AutoTokenizer.from_pretrained(config["base_model_name"])
model.to(device).eval()

# Image generator, also switched to evaluation mode.
gan_model = BigGAN.from_pretrained(config["image_gen_model"])
gan_model.to(device).eval()
print("Models were loaded")
|
42 |
+
|
43 |
+
|
44 |
+
def generate_image(dense_class_vector=None, int_index=None, noise_seed_vector=None, truncation=0.4):
    """Generate one image with the module-level BigGAN ``gan_model``.

    The class conditioning comes from ``int_index`` when it is given (it is
    one-hot encoded and passed through ``gan_model.embeddings``); otherwise
    ``dense_class_vector`` is used directly.

    Args:
        dense_class_vector: Tensor or NumPy array holding 128 values. Only
            used when ``int_index`` is None.
        int_index: Integer class index. Takes precedence over
            ``dense_class_vector``.
        noise_seed_vector: Optional tensor. The integer part of its sum seeds
            the noise sampler, so the same tensor yields the same noise.
        truncation: Value passed to both ``truncated_noise_sample`` and the
            generator.

    Returns:
        A ``uint8`` NumPy array for the single generated image, with the
        channel axis moved last.

    Raises:
        ValueError: If neither ``dense_class_vector`` nor ``int_index`` is
            provided.
    """
    if int_index is None and dense_class_vector is None:
        raise ValueError("either dense_class_vector or int_index must be provided")

    # Build every input on the device that holds the generator's weights so
    # the function keeps working when the models are moved off the CPU.
    target_device = next(gan_model.parameters()).device

    seed = int(noise_seed_vector.sum().item()) if noise_seed_vector is not None else None
    noise_vector = truncated_noise_sample(truncation=truncation, batch_size=1, seed=seed)
    noise_vector = torch.from_numpy(noise_vector).to(target_device)

    # Inference only: no gradients are needed for the embedding or generator.
    with torch.no_grad():
        if int_index is not None:
            class_vector = one_hot_from_int([int_index], batch_size=1)
            class_vector = torch.from_numpy(class_vector).to(target_device)
            dense_class_vector = gan_model.embeddings(class_vector)
        else:
            if isinstance(dense_class_vector, np.ndarray):
                dense_class_vector = torch.tensor(dense_class_vector)
            # Match the noise dtype: a NumPy input is float64 by default and
            # would otherwise not agree with the float32 noise vector.
            dense_class_vector = dense_class_vector.reshape(1, 128).to(
                device=target_device, dtype=noise_vector.dtype
            )

        input_vector = torch.cat([noise_vector, dense_class_vector], dim=1)
        output = gan_model.generator(input_vector, truncation)

    # Move channels last, then rescale to the 0..255 byte range.
    # NOTE(review): assumes the generator output lies in [-1, 1] -- confirm.
    output = output.cpu().numpy().transpose((0, 2, 3, 1))
    output = ((output + 1.0) / 2.0) * 256
    output.clip(0, 255, out=output)
    return output[0].astype(np.uint8)
|
68 |
+
|
69 |
+
|
70 |
+
def print_image(numpy_array):
    """Display a NumPy uint8 array as an image using matplotlib."""
    plt.imshow(Image.fromarray(numpy_array))
    plt.show()
|
76 |
+
|
77 |
+
|
78 |
+
def text_to_image(text):
    """Turn a text query into a generated image.

    The text is tokenized and classified by the module-level ``model``; the
    index of the highest logit is used as the class for ``generate_image``.
    The token ids also serve as the noise seed, so the same text produces the
    same image.

    Args:
        text: The query string.

    Returns:
        A ``PIL.Image`` built from the generated array.
    """
    # The query is free-form user text, so cap it at the tokenizer's maximum
    # length instead of feeding an arbitrarily long sequence to the model.
    tokens = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        return_tensors='pt',
    ).to(device)

    with torch.no_grad():
        lm_output = model(tokens, return_dict=True)

    # Batch size is one, so take the first row of logits.
    pred_int_index = int(torch.argmax(lm_output.logits[0], dim=-1).item())
    print(pred_int_index)

    # Now generate an image (a numpy array)
    numpy_image = generate_image(
        int_index=pred_int_index,
        truncation=truncation,
        noise_seed_vector=tokens,
    )
    return Image.fromarray(numpy_image)
|
93 |
+
|
94 |
+
# Sample queries offered in the Gradio UI.
examples = [
    "a high resoltuion photo of a pizza from famous food magzine.",
    "this is a photo of my pet golden retriever.",
    "this is a photo of a trouble some street cat.",
    "a blur image of coral reef.",
    "a yellow taxi cab commonly found in USA.",
    "Once upon a time, there was a black ship full of pirates.",
    "a photo of a large castle.",
    "a sketch of an old Church",
]
|
102 |
+
|
103 |
+
if __name__ == '__main__':
    # One-line text box in, generated picture out.
    query_box = gr.inputs.Textbox(
        placeholder="Enter the text to generate an image",
        label="Text query",
        lines=1,
    )
    image_panel = gr.outputs.Image(type="auto", label="Generated Image")

    interFace = gr.Interface(
        fn=text_to_image,
        inputs=query_box,
        outputs=image_panel,
        verbose=True,
        examples=examples,
        title="Generate Image from Text",
        description="",
        theme="huggingface",
    )
    interFace.launch()
|