Spaces: Running on Zero

Update app.py
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, pipeline as translation_pipeline
 from janus.models import MultiModalityCausalLM, VLChatProcessor
 from janus.utils.io import load_pil_images
 from PIL import Image
@@ -9,15 +9,32 @@ import os
 import time
 from Upsample import RealESRGAN
 import spaces  # Import spaces for ZeroGPU compatibility
+import re
+
+# Initialize the translation pipeline (Korean → English)
+translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
+
+def translate_if_korean(prompt: str) -> str:
+    """Translate the prompt to English if it contains Korean."""
+    if re.search(r'[ㄱ-ㅎㅏ-ㅣ가-힣]', prompt):
+        try:
+            translation = translator(prompt)[0]['translation_text']
+            return translation
+        except Exception as e:
+            print(f"Translation error: {e}")
+            return prompt
+    return prompt
 
 # Load model and processor
 model_path = "deepseek-ai/Janus-Pro-7B"
 config = AutoConfig.from_pretrained(model_path)
 language_config = config.language_config
 language_config._attn_implementation = 'eager'
-vl_gpt = AutoModelForCausalLM.from_pretrained(
-    …
-    …
+vl_gpt = AutoModelForCausalLM.from_pretrained(
+    model_path,
+    language_config=language_config,
+    trust_remote_code=True
+)
 if torch.cuda.is_available():
     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
 else:
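A note on the new helper: `translate_if_korean` gates on a Hangul character class (jamo ㄱ-ㅎ/ㅏ-ㅣ plus composed syllables 가-힣) before invoking the Opus-MT ko→en pipeline, and falls back to the raw prompt on any failure. A minimal standalone sketch of the same pattern, assuming the Helsinki-NLP/opus-mt-ko-en weights can be downloaded in your environment (the sample prompts and printed translations are illustrative, not from this Space):

```python
import re
from transformers import pipeline

# Jamo (ㄱ-ㅎ, ㅏ-ㅣ) plus composed syllables (가-힣) cover typical Korean text
KOREAN_RE = re.compile(r'[ㄱ-ㅎㅏ-ㅣ가-힣]')

translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

def translate_if_korean(prompt: str) -> str:
    if not KOREAN_RE.search(prompt):
        return prompt                  # no Korean detected: pass through unchanged
    try:
        return translator(prompt)[0]["translation_text"]
    except Exception as e:             # on any failure, fall back to the raw prompt
        print(f"Translation error: {e}")
        return prompt

print(translate_if_korean("푸른 바다 위의 범선"))  # e.g. "A sailboat on the blue sea"
print(translate_if_korean("a red bicycle"))         # unchanged
```

Building the pipeline once at module import, as the commit does, avoids reloading the translation model on every generation request.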
@@ -28,16 +45,14 @@ tokenizer = vl_chat_processor.tokenizer
 cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # SR model
-sr_model = RealESRGAN(torch.device(…
+sr_model = RealESRGAN(torch.device(cuda_device), scale=2)
 sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)
 
 @torch.inference_mode()
 @spaces.GPU(duration=120)
 def multimodal_understanding(image, question, seed, top_p, temperature):
-    # …
+    # (snip) body of the existing multimodal understanding function, unchanged...
     torch.cuda.empty_cache()
-
-    # Set seed
     torch.manual_seed(seed)
     np.random.seed(seed)
     torch.cuda.manual_seed(seed)
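The super-resolution model is now built with an explicit device and scale. A usage sketch, assuming the bundled `Upsample.RealESRGAN` wrapper follows the common ai-forever Real-ESRGAN API where `predict()` takes and returns a PIL image (the `predict` call and the file names here are assumptions, not shown in this diff):

```python
import torch
from PIL import Image
from Upsample import RealESRGAN  # local module bundled with this Space

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sr_model = RealESRGAN(device, scale=2)  # x2 upscaler, matching RealESRGAN_x2.pth
sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)

lr_image = Image.open('generated.png').convert('RGB')  # e.g. a 384x384 generation
sr_image = sr_model.predict(lr_image)                  # assumed API: PIL in, 2x PIL out
sr_image.save('generated_x2.png')
```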
@@ -90,9 +105,11 @@ def generate(input_ids, width, height, temperature: float = 1,
     pkv = None
     for i in range(image_token_num_per_image):
         with torch.no_grad():
-            outputs = vl_gpt.language_model.model(
-                …
-                …
+            outputs = vl_gpt.language_model.model(
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                past_key_values=pkv
+            )
             pkv = outputs.past_key_values
             hidden_states = outputs.last_hidden_state
             logits = vl_gpt.gen_head(hidden_states[:, -1, :])
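The reconstructed call is the standard transformers KV-cache loop: each step feeds only the newest embedding and passes the returned `past_key_values` back in, so earlier positions are never re-encoded. A generic, runnable sketch of the same pattern with a small public model (gpt2 here is purely illustrative; Janus applies the identical loop to image-token embeddings via `inputs_embeds` instead of `input_ids`):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tok("The quick brown", return_tensors="pt").input_ids
inputs = input_ids   # first step sees the whole prompt
pkv = None
with torch.no_grad():
    for _ in range(5):
        out = model(input_ids=inputs, use_cache=True, past_key_values=pkv)
        pkv = out.past_key_values                          # cache grows by one step
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        inputs = next_token                                # feed only the new token
        input_ids = torch.cat([input_ids, next_token], dim=1)

print(tok.decode(input_ids[0]))
```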
@@ -107,8 +124,10 @@ def generate(input_ids, width, height, temperature: float = 1,
             img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
             inputs_embeds = img_embeds.unsqueeze(dim=1)
 
-    patches = vl_gpt.gen_vision_model.decode_code(
-        …
+    patches = vl_gpt.gen_vision_model.decode_code(
+        generated_tokens.to(dtype=torch.int),
+        shape=[parallel_size, 8, width // patch_size, height // patch_size]
+    )
     return generated_tokens.to(dtype=torch.int), patches
 
 def unpack(dec, width, height, parallel_size=5):
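The `shape` argument lays the flat token sequence back onto a latent grid before the vision decoder turns it into pixels. With the values this demo is usually run at (a 384x384 canvas and 16-px patches; both are assumptions here, not shown in this hunk), the arithmetic works out as:

```python
width = height = 384   # assumed: the demo's usual canvas size
patch_size = 16        # assumed: Janus' vision patch size
parallel_size = 5      # images generated per prompt

shape = [parallel_size, 8, width // patch_size, height // patch_size]
print(shape)                                            # [5, 8, 24, 24]
print((width // patch_size) * (height // patch_size))   # 576 tokens per image
```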
@@ -121,6 +140,9 @@ def unpack(dec, width, height, parallel_size=5):
 @torch.inference_mode()
 @spaces.GPU(duration=120)
 def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
+    # Translation: if the input prompt contains Korean, convert it to English
+    prompt = translate_if_korean(prompt)
+
     torch.cuda.empty_cache()
     if seed is not None:
         torch.manual_seed(seed)
@@ -140,16 +162,20 @@ def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
     )
     text = text + vl_chat_processor.image_start_tag
     input_ids = torch.LongTensor(tokenizer.encode(text))
-    output, patches = generate(
-        …
+    output, patches = generate(
+        input_ids,
+        width // 16 * 16,
+        height // 16 * 16,
+        cfg_weight=guidance,
+        parallel_size=parallel_size,
+        temperature=t2i_temperature
+    )
+    images = unpack(
+        patches,
+        width // 16 * 16,
+        height // 16 * 16,
+        parallel_size=parallel_size
+    )
 
     stime = time.time()
     ret_images = [image_upsample(Image.fromarray(images[i])) for i in range(parallel_size)]
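The `width // 16 * 16` idiom snaps an arbitrary requested dimension down to the nearest multiple of 16, so the canvas always divides evenly into 16-px patches. A quick check of what it does to a few inputs:

```python
# floor-divide then multiply: rounds down to a multiple of 16
for w in (384, 390, 400, 511):
    print(w, '->', w // 16 * 16)   # 384->384, 390->384, 400->400, 511->496
```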
@@ -231,8 +257,7 @@ with gr.Blocks(css=custom_css, title="Multimodal & T2I Demo") as demo:
     gr.Examples(
         label="Multimodal Understanding Examples",
         examples=[
-            ["explain this meme", "doge.png"],
-            ["이 이미지를 설명해줘", "korean_example.png"]
+            ["explain this meme", "doge.png"]
         ],
         inputs=[question_input, image_input],
     )
@@ -273,4 +298,3 @@ with gr.Blocks(css=custom_css, title="Multimodal & T2I Demo") as demo:
     gr.Markdown("<footer style='text-align:center; padding:20px 0;'>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")
 
 demo.launch(share=True)
-