ginipick committed (verified)
Commit 1d8062c · 1 Parent(s): 5f97a8f

Update app.py

Files changed (1): app.py (+120 -112)
app.py CHANGED
@@ -4,14 +4,12 @@ from transformers import AutoConfig, AutoModelForCausalLM
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from janus.utils.io import load_pil_images
 from PIL import Image
-
 import numpy as np
 import os
 import time
 from Upsample import RealESRGAN
 import spaces # Import spaces for ZeroGPU compatibility
 
-
 # Load model and processor
 model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
@@ -31,16 +29,15 @@ cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # SR model
 sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
-sr_model.load_weights(f'weights/RealESRGAN_x2.pth', download=False)
+sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)
 
 @torch.inference_mode()
-@spaces.GPU(duration=120)
-# Multimodal Understanding function
+@spaces.GPU(duration=120)
 def multimodal_understanding(image, question, seed, top_p, temperature):
     # Clear CUDA cache before generating
     torch.cuda.empty_cache()
 
-    # set seed
+    # Set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
     torch.cuda.manual_seed(seed)
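
Note: the seed block above is the standard three-call pattern for reproducible sampling across torch's CPU RNG, numpy, and CUDA. A minimal, self-contained sketch of the same pattern (the set_seed helper name is illustrative, not part of app.py):

    import numpy as np
    import torch

    def set_seed(seed: int) -> None:
        # Seed torch's CPU RNG, numpy, and (when present) the CUDA RNG.
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    set_seed(42)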
@@ -54,12 +51,11 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
         {"role": "<|Assistant|>", "content": ""},
     ]
 
-    pil_images = [Image.fromarray(image)]
+    pil_images = [Image.fromarray(image)] if isinstance(image, np.ndarray) else [image]
     prepare_inputs = vl_chat_processor(
         conversations=conversation, images=pil_images, force_batchify=True
     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
 
-
     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
 
     outputs = vl_gpt.language_model.generate(
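
Note: the rewritten pil_images line guards against both input types Gradio can deliver: a numpy array (what gr.Image(type="numpy") returns) or an already-built PIL image. A minimal sketch of that normalization (the to_pil_list name is illustrative):

    import numpy as np
    from PIL import Image

    def to_pil_list(image):
        # Wrap the single input in the one-element list the processor expects.
        return [Image.fromarray(image)] if isinstance(image, np.ndarray) else [image]

    print(to_pil_list(np.zeros((8, 8, 3), dtype=np.uint8)))  # one PIL.Image in a list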
@@ -78,16 +74,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
     return answer
 
-
-def generate(input_ids,
-             width,
-             height,
-             temperature: float = 1,
-             parallel_size: int = 5,
-             cfg_weight: float = 5,
-             image_token_num_per_image: int = 576,
-             patch_size: int = 16):
-    # Clear CUDA cache before generating
+def generate(input_ids, width, height, temperature: float = 1,
+             parallel_size: int = 5, cfg_weight: float = 5,
+             image_token_num_per_image: int = 576, patch_size: int = 16):
     torch.cuda.empty_cache()
 
     tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
@@ -102,8 +91,8 @@ def generate(input_ids,
     for i in range(image_token_num_per_image):
         with torch.no_grad():
             outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-                                                  use_cache=True,
-                                                  past_key_values=pkv)
+                                                  use_cache=True,
+                                                  past_key_values=pkv)
             pkv = outputs.past_key_values
             hidden_states = outputs.last_hidden_state
             logits = vl_gpt.gen_head(hidden_states[:, -1, :])
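
Note: this loop decodes image tokens autoregressively, reusing the KV cache (past_key_values) so each step only feeds the newest token embedding. The batch holds parallel_size * 2 rows because every image is sampled with a conditional and an unconditional prompt for classifier-free guidance. The mixing step that follows gen_head falls outside this hunk; a minimal sketch of it, assuming the interleaved layout (even rows conditional, odd rows unconditional) used by the Janus reference code:

    import torch

    def cfg_sample(logits: torch.Tensor, cfg_weight: float, temperature: float) -> torch.Tensor:
        # Split the interleaved batch, push logits toward the conditional
        # branch, then sample one next token per image.
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        mixed = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(mixed / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    next_token = cfg_sample(torch.randn(10, 16384), cfg_weight=5.0, temperature=1.0)
    print(next_token.shape)  # torch.Size([5, 1]) for parallel_size=5

Higher cfg_weight pushes samples toward the prompt at the cost of diversity, which is why the UI exposes it as a slider.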
@@ -117,34 +106,22 @@ def generate(input_ids,
 
         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
         inputs_embeds = img_embeds.unsqueeze(dim=1)
-
 
-
     patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
-
+                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
     return generated_tokens.to(dtype=torch.int), patches
 
 def unpack(dec, width, height, parallel_size=5):
     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
-
     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
     visual_img[:, :, :] = dec
-
     return visual_img
 
-
-
 @torch.inference_mode()
-@spaces.GPU(duration=120) # Specify a duration to avoid timeout
-def generate_image(prompt,
-                   seed=None,
-                   guidance=5,
-                   t2i_temperature=1.0):
-    # Clear CUDA cache and avoid tracking gradients
+@spaces.GPU(duration=120)
+def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
     torch.cuda.empty_cache()
-    # Set the seed for reproducible results
     if seed is not None:
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)
@@ -156,11 +133,12 @@ def generate_image(prompt,
     with torch.no_grad():
         messages = [{'role': '<|User|>', 'content': prompt},
                     {'role': '<|Assistant|>', 'content': ''}]
-        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
-                                                                           sft_format=vl_chat_processor.sft_format,
-                                                                           system_prompt='')
+        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
+            conversations=messages,
+            sft_format=vl_chat_processor.sft_format,
+            system_prompt=''
+        )
         text = text + vl_chat_processor.image_start_tag
-
         input_ids = torch.LongTensor(tokenizer.encode(text))
         output, patches = generate(input_ids,
                                    width // 16 * 16,
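
Note: the width // 16 * 16 arguments round the requested resolution down to a multiple of the 16-pixel patch size (patch_size in generate()), since the vision decoder works on whole patches. The same arithmetic as a tiny helper (the name is illustrative):

    def round_to_patch(x: int, patch_size: int = 16) -> int:
        # Integer division floors: 390 -> 24 patches -> 384 pixels.
        return x // patch_size * patch_size

    assert round_to_patch(384) == 384
    assert round_to_patch(390) == 384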
@@ -173,95 +151,125 @@ def generate_image(prompt,
                                height // 16 * 16,
                                parallel_size=parallel_size)
 
-    # return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
     stime = time.time()
     ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
     print(f'upsample time: {time.time() - stime}')
     return ret_images
 
-
 @spaces.GPU(duration=60)
 def image_upsample(img: Image.Image) -> Image.Image:
     if img is None:
         raise Exception("Image not uploaded")
-
     width, height = img.size
-
     if width >= 5000 or height >= 5000:
         raise Exception("The image is too large.")
-
     global sr_model
     result = sr_model.predict(img.convert('RGB'))
     return result
-
 
-# Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown(value="# Multimodal Understanding")
-    with gr.Row():
-        image_input = gr.Image()
-        with gr.Column():
-            question_input = gr.Textbox(label="Question")
-            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
-            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
-            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
-
-    understanding_button = gr.Button("Chat")
-    understanding_output = gr.Textbox(label="Response")
-
-    examples_inpainting = gr.Examples(
-        label="Multimodal Understanding examples",
-        examples=[
-            [
-                "explain this meme",
-                "doge.png",
-            ],
-            [
-                "Convert the formula into latex code.",
-                "equation.png",
-            ],
-        ],
-        inputs=[question_input, image_input],
-    )
-
-    gr.Markdown(value="# Text-to-Image Generation")
-
-
-
-    with gr.Row():
-        cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
-        t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="temperature")
-
-    prompt_input = gr.Textbox(label="Prompt. (Prompt in more detail can help produce better images!")
-    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
-
-    generation_button = gr.Button("Generate Images")
-
-    image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
-
-    examples_t2i = gr.Examples(
-        label="Text to image generation examples.",
-        examples=[
-            "Master shifu racoon wearing drip attire as a street gangster.",
-            "The face of a beautiful girl",
-            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting,immortal,fluffy, shiny mane,Petals,fairyism,unreal engine 5 and Octane Render,highly detailed, photorealistic, cinematic, natural colors.",
-            "The image features an intricately designed eye set against a circular backdrop adorned with ornate swirl patterns that evoke both realism and surrealism. At the center of attention is a strikingly vivid blue iris surrounded by delicate veins radiating outward from the pupil to create depth and intensity. The eyelashes are long and dark, casting subtle shadows on the skin around them which appears smooth yet slightly textured as if aged or weathered over time.\n\nAbove the eye, there's a stone-like structure resembling part of classical architecture, adding layers of mystery and timeless elegance to the composition. This architectural element contrasts sharply but harmoniously with the organic curves surrounding it. Below the eye lies another decorative motif reminiscent of baroque artistry, further enhancing the overall sense of eternity encapsulated within each meticulously crafted detail. \n\nOverall, the atmosphere exudes a mysterious aura intertwined seamlessly with elements suggesting timelessness, achieved through the juxtaposition of realistic textures and surreal artistic flourishes. Each component\u2014from the intricate designs framing the eye to the ancient-looking stone piece above\u2014contributes uniquely towards creating a visually captivating tableau imbued with enigmatic allure.",
-        ],
-        inputs=prompt_input,
-    )
-
-    understanding_button.click(
-        multimodal_understanding,
-        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
-        outputs=understanding_output
-    )
-
-    generation_button.click(
-        fn=generate_image,
-        inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
-        outputs=image_output
-    )
-
-demo.launch(share=True)
+# Custom CSS for a sleek, modern and highly readable interface
+custom_css = """
+body {
+    background: #f0f2f5;
+    font-family: 'Segoe UI', sans-serif;
+    color: #333;
+}
+h1, h2, h3 {
+    font-weight: 600;
+}
+.gradio-container {
+    padding: 20px;
+}
+header {
+    text-align: center;
+    padding: 20px;
+    margin-bottom: 20px;
+}
+header h1 {
+    font-size: 3em;
+    color: #2c3e50;
+}
+.gr-button {
+    background-color: #3498db !important;
+    color: #fff !important;
+    border: none !important;
+    padding: 10px 20px !important;
+    border-radius: 5px !important;
+    font-size: 1em !important;
+}
+.gr-button:hover {
+    background-color: #2980b9 !important;
+}
+.gr-input, .gr-slider, .gr-number, .gr-textbox {
+    border-radius: 5px;
+}
+.gr-gallery-item {
+    border-radius: 10px;
+    overflow: hidden;
+    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
+}
+"""
+
+# Gradio Interface
+with gr.Blocks(css=custom_css, title="Multimodal & T2I Demo") as demo:
+    with gr.Column(variant="panel"):
+        gr.Markdown("<header><h1>Janus Multimodal Demo</h1></header>")
+
+    with gr.Tabs():
+        with gr.TabItem("Multimodal Understanding"):
+            gr.Markdown("### Chat with Images")
+            with gr.Row():
+                image_input = gr.Image(label="Upload Image", type="numpy", tool="editor")
+                with gr.Column():
+                    question_input = gr.Textbox(label="Question", placeholder="Enter your question about the image here...", lines=4)
+                    und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+                    top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top_p")
+                    temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
+            understanding_button = gr.Button("Chat", elem_id="understanding-button")
+            understanding_output = gr.Textbox(label="Response", lines=6)
+            with gr.Accordion("Examples", open=False):
+                gr.Examples(
+                    label="Multimodal Understanding Examples",
+                    examples=[
+                        ["explain this meme", "doge.png"],
+                        ["Convert the formula into LaTeX code.", "equation.png"],
+                    ],
+                    inputs=[question_input, image_input],
+                )
+            understanding_button.click(
+                multimodal_understanding,
+                inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+                outputs=understanding_output,
+            )
+
+        with gr.TabItem("Text-to-Image Generation"):
+            gr.Markdown("### Generate Images from Text")
+            with gr.Row():
+                prompt_input = gr.Textbox(label="Prompt", placeholder="Enter detailed prompt for image generation...", lines=4)
+            with gr.Row():
+                seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
+                cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
+                t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
+            generation_button = gr.Button("Generate Images", elem_id="generation-button")
+            image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
+            with gr.Accordion("Examples", open=False):
+                gr.Examples(
+                    label="Text-to-Image Examples",
+                    examples=[
+                        "Master shifu racoon wearing drip attire as a street gangster.",
+                        "The face of a beautiful girl",
+                        "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+                        "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting, immortal, fluffy, shiny mane, petals, fairyism, unreal engine 5 and Octane Render, highly detailed, photorealistic, cinematic, natural colors.",
+                        "An intricately designed eye with ornate swirl patterns, vivid blue iris, and classical architectural motifs, exuding mysterious timelessness."
+                    ],
+                    inputs=prompt_input,
+                )
+            generation_button.click(
+                fn=generate_image,
+                inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
+                outputs=image_output,
+            )
+
+    gr.Markdown("<footer style='text-align:center; padding:20px 0;'>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")
 
+demo.launch(share=True)
 
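Note: the rewrite replaces the old flat gr.Blocks layout with a tabbed interface styled by custom_css, wiring each button to its handler with .click(). A minimal, runnable sketch of that Blocks/Tabs wiring pattern (the echo handler is a stand-in for multimodal_understanding and generate_image):

    import gradio as gr

    def echo(question: str) -> str:
        # Stand-in handler; app.py wires its real functions here instead.
        return f"You asked: {question}"

    with gr.Blocks(title="Wiring sketch") as demo:
        with gr.Tabs():
            with gr.TabItem("Chat"):
                question = gr.Textbox(label="Question")
                response = gr.Textbox(label="Response")
                gr.Button("Chat").click(echo, inputs=question, outputs=response)

    if __name__ == "__main__":
        demo.launch()

One caveat worth flagging: gr.Image's tool parameter was removed in Gradio 4.x, so the tool="editor" argument in the new UI assumes a 3.x runtime.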