ginipick committed
Commit 827cb17 · verified · 1 Parent(s): c26cedb

Update app.py

Files changed (1)
  1. app.py +50 -26
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, pipeline as translation_pipeline
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from janus.utils.io import load_pil_images
 from PIL import Image
@@ -9,15 +9,32 @@ import os
 import time
 from Upsample import RealESRGAN
 import spaces  # Import spaces for ZeroGPU compatibility
+import re
+
+# Initialize the translation pipeline (Korean → English)
+translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+
+def translate_if_korean(prompt: str) -> str:
+    """Translate the prompt to English if it contains Korean."""
+    if re.search(r'[ㄱ-ㅎㅏ-ㅣ가-힣]', prompt):
+        try:
+            translation = translator(prompt)[0]['translation_text']
+            return translation
+        except Exception as e:
+            print(f"Translation error: {e}")
+            return prompt
+    return prompt
 
 # Load model and processor
 model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
-                                              language_config=language_config,
-                                              trust_remote_code=True)
+vl_gpt = AutoModelForCausalLM.from_pretrained(
+    model_path,
+    language_config=language_config,
+    trust_remote_code=True
+)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:
@@ -28,16 +45,14 @@ tokenizer = vl_chat_processor.tokenizer
 cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # SR model
-sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
+sr_model = RealESRGAN(torch.device(cuda_device), scale=2)
 sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)
 
 @torch.inference_mode()
 @spaces.GPU(duration=120)
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # Clear CUDA cache before generating
+    # (omitted) existing multimodal understanding logic continues unchanged...
    torch.cuda.empty_cache()
-
-    # Set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
     torch.cuda.manual_seed(seed)
@@ -90,9 +105,11 @@ def generate(input_ids, width, height, temperature: float = 1,
     pkv = None
     for i in range(image_token_num_per_image):
         with torch.no_grad():
-            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
-                                                  use_cache=True,
-                                                  past_key_values=pkv)
+            outputs = vl_gpt.language_model.model(
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                past_key_values=pkv
+            )
         pkv = outputs.past_key_values
         hidden_states = outputs.last_hidden_state
         logits = vl_gpt.gen_head(hidden_states[:, -1, :])
@@ -107,8 +124,10 @@ def generate(input_ids, width, height, temperature: float = 1,
         img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
         inputs_embeds = img_embeds.unsqueeze(dim=1)
 
-    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
-                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
+    patches = vl_gpt.gen_vision_model.decode_code(
+        generated_tokens.to(dtype=torch.int),
+        shape=[parallel_size, 8, width // patch_size, height // patch_size]
+    )
     return generated_tokens.to(dtype=torch.int), patches
 
 def unpack(dec, width, height, parallel_size=5):
@@ -121,6 +140,9 @@ def unpack(dec, width, height, parallel_size=5):
 @torch.inference_mode()
 @spaces.GPU(duration=120)
 def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
+    # Translate the input prompt to English if it contains Korean
+    prompt = translate_if_korean(prompt)
+
     torch.cuda.empty_cache()
     if seed is not None:
         torch.manual_seed(seed)
@@ -140,16 +162,20 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
     )
     text = text + vl_chat_processor.image_start_tag
     input_ids = torch.LongTensor(tokenizer.encode(text))
-    output, patches = generate(input_ids,
-                               width // 16 * 16,
-                               height // 16 * 16,
-                               cfg_weight=guidance,
-                               parallel_size=parallel_size,
-                               temperature=t2i_temperature)
-    images = unpack(patches,
-                    width // 16 * 16,
-                    height // 16 * 16,
-                    parallel_size=parallel_size)
+    output, patches = generate(
+        input_ids,
+        width // 16 * 16,
+        height // 16 * 16,
+        cfg_weight=guidance,
+        parallel_size=parallel_size,
+        temperature=t2i_temperature
+    )
+    images = unpack(
+        patches,
+        width // 16 * 16,
+        height // 16 * 16,
+        parallel_size=parallel_size
+    )
 
     stime = time.time()
     ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
@@ -231,8 +257,7 @@ with gr.Blocks(css=custom_css, title="Multimodal & T2I Demo") as demo:
     gr.Examples(
         label="Multimodal Understanding Examples",
        examples=[
-            ["explain this meme", "doge.png"],
-            ["이 이미지를 설명해줘", "korean_example.png"]
+            ["explain this meme", "doge.png"]
         ],
         inputs=[question_input, image_input],
     )
@@ -273,4 +298,3 @@ with gr.Blocks(css=custom_css, title="Multimodal & T2I Demo") as demo:
     gr.Markdown("<footer style='text-align:center; padding:20px 0;'>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")
 
 demo.launch(share=True)
-
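
For reference, a minimal standalone sketch of the translation helper this commit adds, assuming the Helsinki-NLP/opus-mt-ko-en checkpoint can be downloaded; the sample prompts and the noted output are illustrative only:

import re
from transformers import pipeline

# Character class covering Hangul jamo (ㄱ-ㅎ, ㅏ-ㅣ) and precomposed syllables (가-힣)
HANGUL = re.compile(r'[ㄱ-ㅎㅏ-ㅣ가-힣]')
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

def translate_if_korean(prompt: str) -> str:
    if HANGUL.search(prompt):
        try:
            return translator(prompt)[0]['translation_text']
        except Exception as e:
            # Fall back to the original prompt so generation still runs
            print(f"Translation error: {e}")
    return prompt

print(translate_if_korean("고양이 그림"))     # Korean input is translated, e.g. "a cat drawing"
print(translate_if_korean("a cat drawing"))  # non-Korean input passes through unchanged

Falling back to the untranslated prompt on failure keeps generate_image usable even when the translation model cannot be loaded.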