Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -4,15 +4,15 @@ from transformers import AutoConfig, AutoModelForCausalLM
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from janus.utils.io import load_pil_images
 from PIL import Image
-
 import numpy as np
 import os
 import time
 from Upsample import RealESRGAN
 import spaces  # Import spaces for ZeroGPU compatibility
 
-
+# ---------------------------
 # Load model and processor
+# ---------------------------
 model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
@@ -29,22 +29,25 @@ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
 tokenizer = vl_chat_processor.tokenizer
 cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-# SR model
+# SR (Super Resolution) model
 sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
 sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
 
+# ---------------------------
+# Multimodal Understanding Function
+# ---------------------------
 @torch.inference_mode()
-@spaces.GPU(duration=120)
-# Multimodal Understanding function
+@spaces.GPU(duration=120)
 def multimodal_understanding(image, question, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
-
-    #
+
+    # Set seed for reproducibility
     torch.manual_seed(seed)
     np.random.seed(seed)
     torch.cuda.manual_seed(seed)
 
+    # Prepare conversation – note the use of a placeholder for the image.
     conversation = [
         {
             "role": "<|User|>",
@@ -54,12 +57,12 @@ def multimodal_understanding(image, question, seed, top_p, temperature, progress
         {"role": "<|Assistant|>", "content": ""},
     ]
 
-
+    # The chat processor expects PIL images.
+    pil_images = [Image.fromarray(np.array(image))] if not isinstance(image, Image.Image) else [image]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
 
-
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
     outputs = vl_gpt.language_model.generate(
@@ -78,7 +81,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature, progress
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
     return answer
 
-
+# ---------------------------
+# Image Generation Functions
+# ---------------------------
 def generate(input_ids,
              width,
             height,
@@ -88,7 +93,6 @@ def generate(input_ids,
             image_token_num_per_image: int = 576,
             patch_size: int = 16,
             progress=gr.Progress(track_tqdm=True)):
-    # Clear CUDA cache before generating
     torch.cuda.empty_cache()
 
     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
@@ -103,8 +107,8 @@
     for i in range(image_token_num_per_image):
         with torch.no_grad():
             outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-
-
+                                                  use_cache=True,
+                                                  past_key_values=pkv)
             pkv = outputs.past_key_values
             hidden_states = outputs.last_hidden_state
             logits = vl_gpt.gen_head(hidden_states[:, -1, :])
@@ -118,35 +122,26 @@
 
             img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
             inputs_embeds = img_embeds.unsqueeze(dim=1)
-
 
-
     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-
-
+                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
     return generated_tokens.to(dtype=torch.int), patches
 
 def unpack(dec, width, height, parallel_size=5):
     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
     visual_img[:, :, :] = dec
-
     return visual_img
 
-
-
 @torch.inference_mode()
-@spaces.GPU(duration=120)
+@spaces.GPU(duration=120)
 def generate_image(prompt,
                    seed=None,
                    guidance=5,
                    t2i_temperature=1.0,
                    progress=gr.Progress(track_tqdm=True)):
-    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
-    # Set the seed for reproducible results
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
@@ -154,13 +149,13 @@ def generate_image(prompt,
     width = 384
     height = 384
     parallel_size = 4
-
+
     with torch.no_grad():
         messages = [{'role': '<|User|>', 'content': prompt},
                     {'role': '<|Assistant|>', 'content': ''}]
         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-
-
+                                                                           sft_format=vl_chat_processor.sft_format,
+                                                                           system_prompt='')
         text = text + vl_chat_processor.image_start_tag
 
         input_ids = torch.LongTensor(tokenizer.encode(text))
@@ -174,13 +169,11 @@
                         width // 16 * 16,
                         height // 16 * 16,
                         parallel_size=parallel_size)
-
-        # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
+        # Upsample the generated images
         stime = time.time()
         ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
         print(f'upsample time: {time.time() - stime}')
-        return ret_images
-
+        return ret_images  # returns a list
 
 @spaces.GPU(duration=60)
 def image_upsample(img: Image.Image) -> Image.Image:
@@ -188,87 +181,82 @@ def image_upsample(img: Image.Image) -> Image.Image:
         raise Exception("Image not uploaded")
 
     width, height = img.size
-
     if width >= 5000 or height >= 5000:
         raise Exception("The image is too large.")
 
     global sr_model
     result = sr_model.predict(img.convert('RGB'))
     return result
-
 
-#
+# A helper function to generate a single image (the first result) from a description.
+def generate_single_image(prompt, seed, guidance, t2i_temperature):
+    images = generate_image(prompt, seed, guidance, t2i_temperature)
+    # Return the first image (if available)
+    return images[0] if images else None
+
+# ---------------------------
+# Chat About Generated Image
+# ---------------------------
+# This function uses the generated image and a chat question.
+def chat_about_image(generated_image, chat_text, seed, top_p, temperature, chat_history):
+    if generated_image is None:
+        return chat_history, "Please generate an image first by entering a description above."
+    response = multimodal_understanding(generated_image, chat_text, seed, top_p, temperature)
+    chat_history.append((chat_text, response))
+    return chat_history, ""
+
+# ---------------------------
+# Gradio Interface
+# ---------------------------
 css = '''
 .gradio-container {max-width: 960px !important}
 '''
-with gr.Blocks(css=css) as demo:
-    gr.Markdown("# Janus Pro 7B")
-    with gr.Tab("Multimodal Understanding"):
-        gr.Markdown(value="## Multimodal Understanding")
-        image_input = gr.Image()
-        with gr.Column():
-            question_input = gr.Textbox(label="Question")
-
-            understanding_button = gr.Button("Chat")
-            understanding_output = gr.Textbox(label="Response")
-
-        with gr.Accordion("Advanced options", open=False):
-            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
-        examples_inpainting = gr.Examples(
-            label="Multimodal Understanding examples",
-            examples=[
-                [
-                    "explain this meme",
-                    "doge.png",
-                ],
-                [
-                    "Convert the formula into latex code.",
-                    "equation.png",
-                ],
-            ],
-            inputs=[question_input, image_input],
-        )
-
-    with gr.Tab("Text-to-Image Generation"):
-        gr.Markdown(value="## Text-to-Image Generation")
 
-
-
-
+with gr.Blocks(css=css, title="Janus Pro 7B – Image Generation and Chat") as demo:
+    gr.Markdown("# Janus Pro 7B: Image Generation and Conversation")
+    gr.Markdown("Enter an image description below to have the model generate an image. Once generated, you can chat about the image and ask questions.")
 
-
+    # States to store the generated image and the chat history.
+    state_image = gr.State(None)
+    state_history = gr.State([])
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("### Step 1. Generate an Image from Description")
+            description_input = gr.Textbox(label="Image Description", placeholder="Describe the image you want...")
+            with gr.Accordion("Advanced Generation Options", open=False):
+                gen_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                guidance_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+                t2i_temperature_input = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
+            generate_button = gr.Button("Generate Image")
+            image_output = gr.Image(label="Generated Image", interactive=False)
+        with gr.Column():
+            gr.Markdown("### Step 2. Chat about the Image")
+            gr.Markdown("Ask questions or discuss the generated image below. (If no image has been generated yet, please do so in Step 1.)")
+            with gr.Accordion("Advanced Chat Options", open=False):
+                chat_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                top_p_input = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+                chat_temperature_input = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
+            chatbox = gr.Chatbot(label="Conversation")
+            chat_input = gr.Textbox(label="Your Message", placeholder="Enter your question or comment here...")
+            send_button = gr.Button("Send")
 
-
-
-
-
+    # When the user clicks the "Generate Image" button:
+    generate_button.click(
+        fn=generate_single_image,
+        inputs=[description_input, gen_seed_input, guidance_input, t2i_temperature_input],
+        outputs=image_output
+    ).then(
+        fn=lambda img: img,  # pass through the generated image
+        inputs=image_output,
+        outputs=state_image
     )
 
-
-
-
-
+    # When the user sends a chat message, update the conversation.
+    send_button.click(
+        fn=chat_about_image,
+        inputs=[state_image, chat_input, chat_seed_input, top_p_input, chat_temperature_input, state_history],
+        outputs=[chatbox, chat_input],
     )
 
 demo.launch(share=True)
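
The core UI pattern this commit introduces is a generate-then-chat flow: the generated image is copied into a gr.State through a chained .click().then(), and the chat handler reads that state alongside a gr.Chatbot history. The sketch below is a minimal, self-contained illustration of that wiring only, not part of the Space itself: the Janus calls are replaced by stub functions (fake_generate_image, fake_chat_about_image are hypothetical stand-ins), so it runs without the model or a GPU.

# Minimal sketch of the gr.State + .click().then() + gr.Chatbot wiring (assumed stubs).
import gradio as gr
from PIL import Image

def fake_generate_image(prompt, seed):
    # Stand-in for generate_single_image(): returns a solid placeholder image.
    return Image.new("RGB", (384, 384), (200, 200, 200))

def fake_chat_about_image(image, message, history):
    # Stand-in for chat_about_image(): guard on missing image, then append a turn.
    if image is None:
        return history, "Please generate an image first."
    history.append((message, f"(stub answer about the image for: {message})"))
    return history, ""

with gr.Blocks() as demo:
    state_image = gr.State(None)    # carries the generated image between steps
    state_history = gr.State([])    # carries the running chat history

    prompt = gr.Textbox(label="Image Description")
    seed = gr.Number(label="Seed", precision=0, value=42)
    generate_btn = gr.Button("Generate Image")
    image_out = gr.Image(label="Generated Image", interactive=False)

    chatbox = gr.Chatbot(label="Conversation")
    chat_msg = gr.Textbox(label="Your Message")
    send_btn = gr.Button("Send")

    # Same pattern as the commit: generate first, then copy the result into state.
    generate_btn.click(fn=fake_generate_image, inputs=[prompt, seed], outputs=image_out).then(
        fn=lambda img: img, inputs=image_out, outputs=state_image
    )
    # The chat handler reads the stored image and history from state.
    send_btn.click(
        fn=fake_chat_about_image,
        inputs=[state_image, chat_msg, state_history],
        outputs=[chatbox, chat_msg],
    )

if __name__ == "__main__":
    demo.launch()

As in the commit's chat_about_image, the history list held in gr.State is mutated in place and only the Chatbot is listed as an output, so the conversation persists across turns without updating the state component explicitly.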