Update joycaption.py
joycaption.py  CHANGED  (+132, -130)
@@ -267,144 +267,146 @@ load_text_model(MODEL_PATH, None, LOAD_IN_NF4, True)
 
 @spaces.GPU()
 @torch.inference_mode()
-@demo.queue()
 def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_length: Union[str, int], extra_options: list[str], name_input: str, custom_prompt: str,
                     max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> tuple[str, str]:
 
+    try:
+        global tokenizer, text_model, image_adapter, pixtral_model, pixtral_processor, text_model_client, use_inference_client
+        torch.cuda.empty_cache()
+        gc.collect()
 
+        # 'any' means no length specified
+        length = None if caption_length == "any" else caption_length
 
+        if isinstance(length, str):
+            try:
+                length = int(length)
+            except ValueError:
+                pass
+
+        # Build prompt
+        if length is None:
+            map_idx = 0
+        elif isinstance(length, int):
+            map_idx = 1
+        elif isinstance(length, str):
+            map_idx = 2
+        else:
+            raise ValueError(f"Invalid caption length: {length}")
+
+        prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
+
+        # Add extra options
+        if len(extra_options) > 0:
+            prompt_str += " " + " ".join(extra_options)
+
+        # Add name, length, word_count
+        prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)
 
+        if custom_prompt.strip() != "":
+            prompt_str = custom_prompt.strip()
+
+        # For debugging
+        print(f"Prompt: {prompt_str}")
+
+        # Pixtral
+        if model_name in PIXTRAL_PATHS:
+            print(f"pixtral_model: {type(pixtral_model)}") #
+            print(f"pixtral_processor: {type(pixtral_processor)}") #
+            input_images = [input_image.convert("RGB")]
+            input_prompt = "[INST]Caption this image:\n[IMG][/INST]"
+            inputs = pixtral_processor(images=input_images, text=input_prompt, return_tensors="pt").to(device)
+            generate_ids = pixtral_model.generate(**inputs, max_new_tokens=max_new_tokens)
+            output = pixtral_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+            return input_prompt, output.strip()
+
+        # Preprocess image
+        # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
+        #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
+        image = input_image.resize((384, 384), Image.LANCZOS)
+        pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+        pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+        pixel_values = pixel_values.to(device)
+
+        # Embed image
+        # This results in Batch x Image Tokens x Features
+        with torch.amp.autocast_mode.autocast(device, enabled=True):
+            vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+            image_features = vision_outputs.hidden_states
+            embedded_images = image_adapter(image_features)
+            embedded_images = embedded_images.to(device)
+
+        # Build the conversation
+        convo = [
+            {
+                "role": "system",
+                "content": "You are a helpful image captioner.",
+            },
+            {
+                "role": "user",
+                "content": prompt_str,
+            },
+        ]
+
+        # Format the conversation
+        convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
+        assert isinstance(convo_string, str)
+
+        # Tokenize the conversation
+        # prompt_str is tokenized separately so we can do the calculations below
+        convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
+        prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
+        assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
+        convo_tokens = convo_tokens.squeeze(0)   # Squeeze just to make the following easier
+        prompt_tokens = prompt_tokens.squeeze(0)
+
+        # Calculate where to inject the image
+        eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
+        assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
+
+        preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]   # Number of tokens before the prompt
 
+        # Embed the tokens
+        convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
 
+        # Construct the input
+        input_embeds = torch.cat([
+            convo_embeds[:, :preamble_len],   # Part before the prompt
+            embedded_images.to(dtype=convo_embeds.dtype),   # Image
+            convo_embeds[:, preamble_len:],   # The prompt and anything after it
+        ], dim=1).to(device)
+
+        input_ids = torch.cat([
+            convo_tokens[:preamble_len].unsqueeze(0),
+            torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),   # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
+            convo_tokens[preamble_len:].unsqueeze(0),
+        ], dim=1).to(device)
+        attention_mask = torch.ones_like(input_ids)
+
+        # Debugging
+        #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
+
+        text_model.to(device)
+        generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
+                                           do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)
+
+        # Trim off the prompt
+        generate_ids = generate_ids[:, input_ids.shape[1]:]
+        if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+            generate_ids = generate_ids[:, :-1]
+
+        caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+
+        return prompt_str, caption.strip()
+
+    except Exception as e:
+        print(e)
+    # https://huggingface.co/docs/transformers/v4.44.2/main_classes/text_generation#transformers.FlaxGenerationMixin.generate
+    # https://github.com/huggingface/transformers/issues/6535
+    # https://zenn.dev/hijikix/articles/8c445f4373fdcc ja
+    # https://github.com/ggerganov/llama.cpp/discussions/7712
+    # https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility
+    # https://huggingface.co/docs/huggingface_hub/v0.24.6/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
 
 
 def is_repo_name(s):
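For reference, here is a minimal standalone sketch (not part of the commit) of the image-injection step used above: the chat template is tokenized once, the second <|eot_id|> marks the end of the preamble, and the image embeddings are spliced in at that point. All tensors, sizes, and index values below are dummies chosen purely for illustration.

```python
import torch

# Dummy stand-ins; in joycaption.py these come from the tokenizer, the vision tower,
# and the image adapter. The sizes are made up for this sketch.
hidden_size = 16
convo_tokens = torch.arange(12)                    # tokenized chat template (1-D)
prompt_tokens = torch.arange(4)                    # tokenized prompt_str (1-D)
eot_id_indices = [2, 9]                            # positions of the two <|eot_id|> tokens
convo_embeds = torch.randn(1, convo_tokens.shape[0], hidden_size)   # embed_tokens(convo_tokens)
embedded_images = torch.randn(1, 3, hidden_size)   # Batch x ImageTokens x Features

# Number of tokens before the user prompt: second <|eot_id|> minus the prompt length
preamble_len = eot_id_indices[1] - prompt_tokens.shape[0]

# Splice the image embeddings between the preamble and the prompt
input_embeds = torch.cat([
    convo_embeds[:, :preamble_len],
    embedded_images.to(dtype=convo_embeds.dtype),
    convo_embeds[:, preamble_len:],
], dim=1)

# Keep input_ids the same length by padding the image span with dummy token ids
input_ids = torch.cat([
    convo_tokens[:preamble_len].unsqueeze(0),
    torch.zeros((1, embedded_images.shape[1]), dtype=torch.long),
    convo_tokens[preamble_len:].unsqueeze(0),
], dim=1)

assert input_embeds.shape[1] == input_ids.shape[1] == convo_tokens.shape[0] + embedded_images.shape[1]
```

The zero ids only keep input_ids aligned with input_embeds for generate(); the model conditions on the spliced embeddings, and the TODO in the diff notes that a dedicated placeholder token would be cleaner.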
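For context, a hypothetical call to the updated function. The argument names and defaults are taken from the signature above, while "Descriptive" as a CAPTION_TYPE_MAP key and example.jpg are assumptions for illustration. Because the new wrapper only prints a caught exception, a failed call now returns None instead of a (prompt, caption) tuple, which the sketch checks for.

```python
from PIL import Image

# Assumes joycaption.py's globals (tokenizer, text_model, image_adapter, ...) have
# already been loaded via load_text_model(MODEL_PATH, None, LOAD_IN_NF4, True).
image = Image.open("example.jpg")          # stand-in input image

result = stream_chat_mod(
    input_image=image,
    caption_type="Descriptive",            # assumed key of CAPTION_TYPE_MAP
    caption_length="any",                  # "any" means no length constraint
    extra_options=[],
    name_input="",
    custom_prompt="",
)

if result is not None:                     # None means the try/except above caught an error
    prompt_used, caption = result
    print(prompt_used)
    print(caption)
```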