Daemontatox committed on
Commit 8e36800 · verified · 1 Parent(s): ab9b588

Update app.py

Files changed (1)
  1. app.py +84 -96
app.py CHANGED
@@ -4,15 +4,15 @@ from transformers import AutoConfig, AutoModelForCausalLM
  from janus.models import MultiModalityCausalLM, VLChatProcessor
  from janus.utils.io import load_pil_images
  from PIL import Image
-
  import numpy as np
  import os
  import time
  from Upsample import RealESRGAN
  import spaces # Import spaces for ZeroGPU compatibility
 
-
  # Load model and processor
  model_path = "deepseek-ai/Janus-Pro-7B"
  config = AutoConfig.from_pretrained(model_path)
  language_config = config.language_config
@@ -29,22 +29,25 @@ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
  tokenizer = vl_chat_processor.tokenizer
  cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
- # SR model
  sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
  sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
 
  @torch.inference_mode()
- @spaces.GPU(duration=120)
- # Multimodal Understanding function
  def multimodal_understanding(image, question, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
      # Clear CUDA cache before generating
      torch.cuda.empty_cache()
-
-     # set seed
      torch.manual_seed(seed)
      np.random.seed(seed)
      torch.cuda.manual_seed(seed)
 
      conversation = [
          {
              "role": "<|User|>",
@@ -54,12 +57,12 @@ def multimodal_understanding(image, question, seed, top_p, temperature, progress
          {"role": "<|Assistant|>", "content": ""},
      ]
 
-     pil_images = [Image.fromarray(image)]
      prepare_inputs = vl_chat_processor(
          conversations=conversation, images=pil_images, force_batchify=True
      ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
 
-
      inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
      outputs = vl_gpt.language_model.generate(
@@ -78,7 +81,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature, progress
      answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
      return answer
 
-
  def generate(input_ids,
               width,
               height,
@@ -88,7 +93,6 @@ def generate(input_ids,
               image_token_num_per_image: int = 576,
               patch_size: int = 16,
               progress=gr.Progress(track_tqdm=True)):
-     # Clear CUDA cache before generating
      torch.cuda.empty_cache()
 
      tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
@@ -103,8 +107,8 @@ def generate(input_ids,
      for i in range(image_token_num_per_image):
          with torch.no_grad():
              outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-                                                   use_cache=True,
-                                                   past_key_values=pkv)
              pkv = outputs.past_key_values
              hidden_states = outputs.last_hidden_state
              logits = vl_gpt.gen_head(hidden_states[:, -1, :])
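The loop in this hunk uses the standard incremental-decoding pattern: run the model once, then feed back only the newest token embedding together with the growing past_key_values cache (the two arguments removed here are re-added, re-indented, on the new side). A minimal, self-contained miniature of the same pattern, using GPT-2 purely as a stand-in for vl_gpt.language_model (an assumption for illustration only, not the app's model):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tok = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

    ids = tok("The quick brown", return_tensors="pt").input_ids
    pkv, step_input = None, ids
    with torch.no_grad():
        for _ in range(5):
            out = model(input_ids=step_input, use_cache=True, past_key_values=pkv)
            pkv = out.past_key_values                       # cache grows one step per token
            next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)
            step_input = next_id                            # feed only the new token
    print(tok.decode(ids[0]))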
@@ -118,35 +122,26 @@ def generate(input_ids,
 
          img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
          inputs_embeds = img_embeds.unsqueeze(dim=1)
-
 
-
      patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-                                                   shape=[parallel_size, 8, width // patch_size, height // patch_size])
-
      return generated_tokens.to(dtype=torch.int), patches
 
  def unpack(dec, width, height, parallel_size=5):
      dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
      dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
      visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
      visual_img[:, :, :] = dec
-
      return visual_img
 
-
-
  @torch.inference_mode()
- @spaces.GPU(duration=120) # Specify a duration to avoid timeout
  def generate_image(prompt,
                     seed=None,
                     guidance=5,
                     t2i_temperature=1.0,
                     progress=gr.Progress(track_tqdm=True)):
-     # Clear CUDA cache and avoid tracking gradients
      torch.cuda.empty_cache()
-     # Set the seed for reproducible results
      if seed is not None:
          torch.manual_seed(seed)
          torch.cuda.manual_seed(seed)
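For the defaults visible in this hunk, the token count and the decoder grid agree: with patch_size = 16 and width = height = 384, decode_code reconstructs a grid of 384 // 16 = 24 patches per side, and 24 × 24 = 576 is exactly the image_token_num_per_image default.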
@@ -154,13 +149,13 @@ def generate_image(prompt,
      width = 384
      height = 384
      parallel_size = 4
-
      with torch.no_grad():
          messages = [{'role': '<|User|>', 'content': prompt},
                      {'role': '<|Assistant|>', 'content': ''}]
          text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-                                                                            sft_format=vl_chat_processor.sft_format,
-                                                                            system_prompt='')
          text = text + vl_chat_processor.image_start_tag
 
          input_ids = torch.LongTensor(tokenizer.encode(text))
@@ -174,13 +169,11 @@ def generate_image(prompt,
                               width // 16 * 16,
                               height // 16 * 16,
                               parallel_size=parallel_size)
-
-         # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
          stime = time.time()
          ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
          print(f'upsample time: {time.time() - stime}')
-         return ret_images
-
 
  @spaces.GPU(duration=60)
  def image_upsample(img: Image.Image) -> Image.Image:
@@ -188,87 +181,82 @@ def image_upsample(img: Image.Image) -> Image.Image:
      raise Exception("Image not uploaded")
 
      width, height = img.size
-
      if width >= 5000 or height >= 5000:
          raise Exception("The image is too large.")
 
      global sr_model
      result = sr_model.predict(img.convert('RGB'))
      return result
-
 
- # Gradio interface
  css = '''
  .gradio-container {max-width: 960px !important}
  '''
- with gr.Blocks(css=css) as demo:
-     gr.Markdown("# Janus Pro 7B")
-     with gr.Tab("Multimodal Understanding"):
-         gr.Markdown(value="## Multimodal Understanding")
-         image_input = gr.Image()
-         with gr.Column():
-             question_input = gr.Textbox(label="Question")
-
-         understanding_button = gr.Button("Chat")
-         understanding_output = gr.Textbox(label="Response")
-
-         with gr.Accordion("Advanced options", open=False):
-             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-             top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-             temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
-         examples_inpainting = gr.Examples(
-             label="Multimodal Understanding examples",
-             examples=[
-                 [
-                     "explain this meme",
-                     "doge.png",
-                 ],
-                 [
-                     "Convert the formula into latex code.",
-                     "equation.png",
-                 ],
-             ],
-             inputs=[question_input, image_input],
-         )
-
-     with gr.Tab("Text-to-Image Generation"):
-         gr.Markdown(value="## Text-to-Image Generation")
 
-         prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!)")
-
-         generation_button = gr.Button("Generate Images")
 
-         image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)
 
-         with gr.Accordion("Advanced options", open=False):
-             with gr.Row():
-                 cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-                 t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
-                 seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
-
-         examples_t2i = gr.Examples(
-             label="Text to image generation examples.",
-             examples=[
-                 "Master shifu racoon wearing drip attire as a street gangster.",
-                 "The face of a beautiful girl",
-                 "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-                 "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
-                 "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
-             ],
-             inputs=prompt_input,
-         )
 
-     understanding_button.click(
-         multimodal_understanding,
-         inputs=[image_input, question_input, und_seed_input, top_p, temperature],
-         outputs=understanding_output
      )
 
-     generation_button.click(
-         fn=generate_image,
-         inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
-         outputs=image_output
      )
 
  demo.launch(share=True)
 
  from janus.models import MultiModalityCausalLM, VLChatProcessor
  from janus.utils.io import load_pil_images
  from PIL import Image
  import numpy as np
  import os
  import time
  from Upsample import RealESRGAN
  import spaces # Import spaces for ZeroGPU compatibility
 
+ # ---------------------------
  # Load model and processor
+ # ---------------------------
  model_path = "deepseek-ai/Janus-Pro-7B"
  config = AutoConfig.from_pretrained(model_path)
  language_config = config.language_config
 
  tokenizer = vl_chat_processor.tokenizer
  cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
+ # SR (Super Resolution) model
  sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
  sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
 
+ # ---------------------------
+ # Multimodal Understanding Function
+ # ---------------------------
  @torch.inference_mode()
+ @spaces.GPU(duration=120)
  def multimodal_understanding(image, question, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
      # Clear CUDA cache before generating
      torch.cuda.empty_cache()
+
+     # Set seed for reproducibility
      torch.manual_seed(seed)
      np.random.seed(seed)
      torch.cuda.manual_seed(seed)
 
+     # Prepare conversation – note the use of a placeholder for the image.
      conversation = [
          {
              "role": "<|User|>",
 
          {"role": "<|Assistant|>", "content": ""},
      ]
 
+     # The chat processor expects PIL images.
+     pil_images = [Image.fromarray(np.array(image))] if not isinstance(image, Image.Image) else [image]
      prepare_inputs = vl_chat_processor(
          conversations=conversation, images=pil_images, force_batchify=True
      ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
 
      inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
      outputs = vl_gpt.language_model.generate(
 
      answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
      return answer
 
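The new pil_images line makes the handler tolerant of both input types Gradio can deliver: the numpy array a default gr.Image hands over, or an already-constructed PIL image. A tiny self-contained check of that normalization idiom (NumPy and Pillow only, no model needed; the array shape is arbitrary):

    import numpy as np
    from PIL import Image

    arr = np.zeros((64, 64, 3), dtype=np.uint8)   # ndarray, as gr.Image delivers by default
    pil = Image.fromarray(arr)                     # an already-constructed PIL image
    for image in (arr, pil):
        pil_images = [Image.fromarray(np.array(image))] if not isinstance(image, Image.Image) else [image]
        assert isinstance(pil_images[0], Image.Image)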
+ # ---------------------------
+ # Image Generation Functions
+ # ---------------------------
  def generate(input_ids,
               width,
               height,
 
               image_token_num_per_image: int = 576,
               patch_size: int = 16,
               progress=gr.Progress(track_tqdm=True)):
      torch.cuda.empty_cache()
 
      tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
 
      for i in range(image_token_num_per_image):
          with torch.no_grad():
              outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
+                                                   use_cache=True,
+                                                   past_key_values=pkv)
              pkv = outputs.past_key_values
              hidden_states = outputs.last_hidden_state
              logits = vl_gpt.gen_head(hidden_states[:, -1, :])
 
          img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
          inputs_embeds = img_embeds.unsqueeze(dim=1)
 
      patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
+                                                   shape=[parallel_size, 8, width // patch_size, height // patch_size])
      return generated_tokens.to(dtype=torch.int), patches
 
  def unpack(dec, width, height, parallel_size=5):
      dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
      dec = np.clip((dec + 1) / 2 * 255, 0, 255)
      visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
      visual_img[:, :, :] = dec
      return visual_img
 
  @torch.inference_mode()
+ @spaces.GPU(duration=120)
  def generate_image(prompt,
                     seed=None,
                     guidance=5,
                     t2i_temperature=1.0,
                     progress=gr.Progress(track_tqdm=True)):
      torch.cuda.empty_cache()
      if seed is not None:
          torch.manual_seed(seed)
          torch.cuda.manual_seed(seed)
 
      width = 384
      height = 384
      parallel_size = 4
+
      with torch.no_grad():
          messages = [{'role': '<|User|>', 'content': prompt},
                      {'role': '<|Assistant|>', 'content': ''}]
          text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                            sft_format=vl_chat_processor.sft_format,
+                                                                            system_prompt='')
          text = text + vl_chat_processor.image_start_tag
 
          input_ids = torch.LongTensor(tokenizer.encode(text))
 
                               width // 16 * 16,
                               height // 16 * 16,
                               parallel_size=parallel_size)
+         # Upsample the generated images
          stime = time.time()
          ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
          print(f'upsample time: {time.time() - stime}')
+         return ret_images # returns a list
 
  @spaces.GPU(duration=60)
  def image_upsample(img: Image.Image) -> Image.Image:
 
      raise Exception("Image not uploaded")
 
      width, height = img.size
      if width >= 5000 or height >= 5000:
          raise Exception("The image is too large.")
 
      global sr_model
      result = sr_model.predict(img.convert('RGB'))
      return result
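Because sr_model is constructed with scale=2, each 384 × 384 generation leaves image_upsample at 768 × 768, the same target size as the commented-out LANCZOS resize in the old code.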
 
 
+ # A helper function to generate a single image (the first result) from a description.
+ def generate_single_image(prompt, seed, guidance, t2i_temperature):
+     images = generate_image(prompt, seed, guidance, t2i_temperature)
+     # Return the first image (if available)
+     return images[0] if images else None
+
+ # ---------------------------
+ # Chat About Generated Image
+ # ---------------------------
+ # This function uses the generated image and a chat question.
+ def chat_about_image(generated_image, chat_text, seed, top_p, temperature, chat_history):
+     if generated_image is None:
+         return chat_history, "Please generate an image first by entering a description above."
+     response = multimodal_understanding(generated_image, chat_text, seed, top_p, temperature)
+     chat_history.append((chat_text, response))
+     return chat_history, ""
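chat_about_image returns the updated history plus a string for the textbox: empty on success, so the input box clears. A dry-run of that contract with the model call stubbed out (both stub names are hypothetical, for illustration only):

    def fake_understanding(image, question, seed, top_p, temperature):
        return f"stub answer to: {question}"

    def chat_about_image_stub(generated_image, chat_text, seed, top_p, temperature, chat_history):
        if generated_image is None:
            return chat_history, "Please generate an image first by entering a description above."
        chat_history.append((chat_text, fake_understanding(generated_image, chat_text, seed, top_p, temperature)))
        return chat_history, ""

    history, msg = chat_about_image_stub("img", "What is shown?", 42, 0.95, 0.1, [])
    assert history == [("What is shown?", "stub answer to: What is shown?")] and msg == ""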
+
+ # ---------------------------
+ # Gradio Interface
+ # ---------------------------
  css = '''
  .gradio-container {max-width: 960px !important}
  '''
 
+ with gr.Blocks(css=css, title="Janus Pro 7B Image Generation and Chat") as demo:
+     gr.Markdown("# Janus Pro 7B: Image Generation and Conversation")
+     gr.Markdown("Enter an image description below to have the model generate an image. Once generated, you can chat about the image and ask questions.")
 
+     # States to store the generated image and the chat history.
+     state_image = gr.State(None)
+     state_history = gr.State([])
 
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("### Step 1. Generate an Image from Description")
+             description_input = gr.Textbox(label="Image Description", placeholder="Describe the image you want...")
+             with gr.Accordion("Advanced Generation Options", open=False):
+                 gen_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                 guidance_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+                 t2i_temperature_input = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
+             generate_button = gr.Button("Generate Image")
+             image_output = gr.Image(label="Generated Image", interactive=False)
+         with gr.Column():
+             gr.Markdown("### Step 2. Chat about the Image")
+             gr.Markdown("Ask questions or discuss the generated image below. (If no image has been generated yet, please do so in Step 1.)")
+             with gr.Accordion("Advanced Chat Options", open=False):
+                 chat_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                 top_p_input = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+                 chat_temperature_input = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
+             chatbox = gr.Chatbot(label="Conversation")
+             chat_input = gr.Textbox(label="Your Message", placeholder="Enter your question or comment here...")
+             send_button = gr.Button("Send")
 
+     # When the user clicks the "Generate Image" button:
+     generate_button.click(
+         fn=generate_single_image,
+         inputs=[description_input, gen_seed_input, guidance_input, t2i_temperature_input],
+         outputs=image_output
+     ).then(
+         fn=lambda img: img, # pass through the generated image
+         inputs=image_output,
+         outputs=state_image
      )
 
+     # When the user sends a chat message, update the conversation.
+     send_button.click(
+         fn=chat_about_image,
+         inputs=[state_image, chat_input, chat_seed_input, top_p_input, chat_temperature_input, state_history],
+         outputs=[chatbox, chat_input],
      )
 
  demo.launch(share=True)
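The event wiring above relies on Gradio's .then() chaining to copy the freshly generated image into a gr.State that later callbacks read. A minimal, self-contained sketch of that same pattern (assuming only that a recent gradio is installed; the component names here are illustrative, not from the app):

    import gradio as gr

    with gr.Blocks() as mini:
        state = gr.State(None)
        text_in = gr.Textbox(label="Input")
        text_out = gr.Textbox(label="Echo")
        run = gr.Button("Run")
        run.click(fn=lambda s: s.upper(), inputs=text_in, outputs=text_out).then(
            fn=lambda x: x,  # pass the visible output through into state
            inputs=text_out,
            outputs=state,
        )

    if __name__ == "__main__":
        mini.launch()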