MegaTronX committed on
Commit fdbb7d2 · verified · 1 Parent(s): 8f3655b

Update joycaption.py

Files changed (1)
  1. joycaption.py +132 -130
joycaption.py CHANGED
@@ -267,144 +267,146 @@ load_text_model(MODEL_PATH, None, LOAD_IN_NF4, True)
 
 @spaces.GPU()
 @torch.inference_mode()
-@demo.queue()
 def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_length: Union[str, int], extra_options: list[str], name_input: str, custom_prompt: str,
                     max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> tuple[str, str]:
-    global tokenizer, text_model, image_adapter, pixtral_model, pixtral_processor, text_model_client, use_inference_client
-    torch.cuda.empty_cache()
-    gc.collect()
-
-    # 'any' means no length specified
-    length = None if caption_length == "any" else caption_length
-
-    if isinstance(length, str):
-        try:
-            length = int(length)
-        except ValueError:
-            pass
-
-    # Build prompt
-    if length is None:
-        map_idx = 0
-    elif isinstance(length, int):
-        map_idx = 1
-    elif isinstance(length, str):
-        map_idx = 2
-    else:
-        raise ValueError(f"Invalid caption length: {length}")
-
-    prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
-
-    # Add extra options
-    if len(extra_options) > 0:
-        prompt_str += " " + " ".join(extra_options)
-
-    # Add name, length, word_count
-    prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)
-
-    if custom_prompt.strip() != "":
-        prompt_str = custom_prompt.strip()
-
-    # For debugging
-    print(f"Prompt: {prompt_str}")
-
-    # Pixtral
-    if model_name in PIXTRAL_PATHS:
-        print(f"pixtral_model: {type(pixtral_model)}") #
-        print(f"pixtral_processor: {type(pixtral_processor)}") #
-        input_images = [input_image.convert("RGB")]
-        input_prompt = "[INST]Caption this image:\n[IMG][/INST]"
-        inputs = pixtral_processor(images=input_images, text=input_prompt, return_tensors="pt").to(device)
-        generate_ids = pixtral_model.generate(**inputs, max_new_tokens=max_new_tokens)
-        output = pixtral_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        return input_prompt, output.strip()
-
-    # Preprocess image
-    # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
-    #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
-    image = input_image.resize((384, 384), Image.LANCZOS)
-    pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
-    pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
-    pixel_values = pixel_values.to(device)
-
-    # Embed image
-    # This results in Batch x Image Tokens x Features
-    with torch.amp.autocast_mode.autocast(device, enabled=True):
-        vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
-        image_features = vision_outputs.hidden_states
-        embedded_images = image_adapter(image_features)
-        embedded_images = embedded_images.to(device)
-
-    # Build the conversation
-    convo = [
-        {
-            "role": "system",
-            "content": "You are a helpful image captioner.",
-        },
-        {
-            "role": "user",
-            "content": prompt_str,
-        },
-    ]
-
-    # Format the conversation
-    convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
-    assert isinstance(convo_string, str)
-
-    # Tokenize the conversation
-    # prompt_str is tokenized separately so we can do the calculations below
-    convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
-    prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
-    assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
-    convo_tokens = convo_tokens.squeeze(0) # Squeeze just to make the following easier
-    prompt_tokens = prompt_tokens.squeeze(0)
-
-    # Calculate where to inject the image
-    eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
-    assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
-
-    preamble_len = eot_id_indices[1] - prompt_tokens.shape[0] # Number of tokens before the prompt
-
-    # Embed the tokens
-    convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
-
-    # Construct the input
-    input_embeds = torch.cat([
-        convo_embeds[:, :preamble_len], # Part before the prompt
-        embedded_images.to(dtype=convo_embeds.dtype), # Image
-        convo_embeds[:, preamble_len:], # The prompt and anything after it
-    ], dim=1).to(device)
-
-    input_ids = torch.cat([
-        convo_tokens[:preamble_len].unsqueeze(0),
-        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long), # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
-        convo_tokens[preamble_len:].unsqueeze(0),
-    ], dim=1).to(device)
-    attention_mask = torch.ones_like(input_ids)
-
-    # Debugging
-    #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
-
-    text_model.to(device)
-    generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
-                                       do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)
-
-    # Trim off the prompt
-    generate_ids = generate_ids[:, input_ids.shape[1]:]
-    if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
-        generate_ids = generate_ids[:, :-1]
-
-    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
-
-    return prompt_str, caption.strip()
-
-
-# https://huggingface.co/docs/transformers/v4.44.2/main_classes/text_generation#transformers.FlaxGenerationMixin.generate
-# https://github.com/huggingface/transformers/issues/6535
-# https://zenn.dev/hijikix/articles/8c445f4373fdcc ja
-# https://github.com/ggerganov/llama.cpp/discussions/7712
-# https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility
-# https://huggingface.co/docs/huggingface_hub/v0.24.6/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
+    try:
+        global tokenizer, text_model, image_adapter, pixtral_model, pixtral_processor, text_model_client, use_inference_client
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        # 'any' means no length specified
+        length = None if caption_length == "any" else caption_length
+
+        if isinstance(length, str):
+            try:
+                length = int(length)
+            except ValueError:
+                pass
+
+        # Build prompt
+        if length is None:
+            map_idx = 0
+        elif isinstance(length, int):
+            map_idx = 1
+        elif isinstance(length, str):
+            map_idx = 2
+        else:
+            raise ValueError(f"Invalid caption length: {length}")
+
+        prompt_str = CAPTION_TYPE_MAP[caption_type][map_idx]
+
+        # Add extra options
+        if len(extra_options) > 0:
+            prompt_str += " " + " ".join(extra_options)
+
+        # Add name, length, word_count
+        prompt_str = prompt_str.format(name=name_input, length=caption_length, word_count=caption_length)
+
+        if custom_prompt.strip() != "":
+            prompt_str = custom_prompt.strip()
+
+        # For debugging
+        print(f"Prompt: {prompt_str}")
+
+        # Pixtral
+        if model_name in PIXTRAL_PATHS:
+            print(f"pixtral_model: {type(pixtral_model)}") #
+            print(f"pixtral_processor: {type(pixtral_processor)}") #
+            input_images = [input_image.convert("RGB")]
+            input_prompt = "[INST]Caption this image:\n[IMG][/INST]"
+            inputs = pixtral_processor(images=input_images, text=input_prompt, return_tensors="pt").to(device)
+            generate_ids = pixtral_model.generate(**inputs, max_new_tokens=max_new_tokens)
+            output = pixtral_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+            return input_prompt, output.strip()
+
+        # Preprocess image
+        # NOTE: I found the default processor for so400M to have worse results than just using PIL directly
+        #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
+        image = input_image.resize((384, 384), Image.LANCZOS)
+        pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
+        pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
+        pixel_values = pixel_values.to(device)
+
+        # Embed image
+        # This results in Batch x Image Tokens x Features
+        with torch.amp.autocast_mode.autocast(device, enabled=True):
+            vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
+            image_features = vision_outputs.hidden_states
+            embedded_images = image_adapter(image_features)
+            embedded_images = embedded_images.to(device)
+
+        # Build the conversation
+        convo = [
+            {
+                "role": "system",
+                "content": "You are a helpful image captioner.",
+            },
+            {
+                "role": "user",
+                "content": prompt_str,
+            },
+        ]
+
+        # Format the conversation
+        convo_string = tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = True)
+        assert isinstance(convo_string, str)
+
+        # Tokenize the conversation
+        # prompt_str is tokenized separately so we can do the calculations below
+        convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
+        prompt_tokens = tokenizer.encode(prompt_str, return_tensors="pt", add_special_tokens=False, truncation=False)
+        assert isinstance(convo_tokens, torch.Tensor) and isinstance(prompt_tokens, torch.Tensor)
+        convo_tokens = convo_tokens.squeeze(0) # Squeeze just to make the following easier
+        prompt_tokens = prompt_tokens.squeeze(0)
+
+        # Calculate where to inject the image
+        eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
+        assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
+
+        preamble_len = eot_id_indices[1] - prompt_tokens.shape[0] # Number of tokens before the prompt
+
+        # Embed the tokens
+        convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to(device))
+
+        # Construct the input
+        input_embeds = torch.cat([
+            convo_embeds[:, :preamble_len], # Part before the prompt
+            embedded_images.to(dtype=convo_embeds.dtype), # Image
+            convo_embeds[:, preamble_len:], # The prompt and anything after it
+        ], dim=1).to(device)
+
+        input_ids = torch.cat([
+            convo_tokens[:preamble_len].unsqueeze(0),
+            torch.zeros((1, embedded_images.shape[1]), dtype=torch.long), # Dummy tokens for the image (TODO: Should probably use a special token here so as not to confuse any generation algorithms that might be inspecting the input)
+            convo_tokens[preamble_len:].unsqueeze(0),
+        ], dim=1).to(device)
+        attention_mask = torch.ones_like(input_ids)
+
+        # Debugging
+        #print(f"Input to model: {repr(tokenizer.decode(input_ids[0]))}")
+
+        text_model.to(device)
+        generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens,
+                                           do_sample=True, suppress_tokens=None, top_p=top_p, temperature=temperature)
+
+        # Trim off the prompt
+        generate_ids = generate_ids[:, input_ids.shape[1]:]
+        if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
+            generate_ids = generate_ids[:, :-1]
+
+        caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+
+        return prompt_str, caption.strip()
+
+    except Exception as e:
+        print(e)
+    # https://huggingface.co/docs/transformers/v4.44.2/main_classes/text_generation#transformers.FlaxGenerationMixin.generate
+    # https://github.com/huggingface/transformers/issues/6535
+    # https://zenn.dev/hijikix/articles/8c445f4373fdcc ja
+    # https://github.com/ggerganov/llama.cpp/discussions/7712
+    # https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility
+    # https://huggingface.co/docs/huggingface_hub/v0.24.6/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation
 
 
  def is_repo_name(s):
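
The part of stream_chat_mod that this commit re-indents but does not change is the embedding splice: adapted image embeddings are concatenated between the chat preamble and the prompt, while zero-id placeholder tokens keep input_ids and the attention mask at the same sequence length as inputs_embeds. Below is a minimal, standalone sketch of that pattern using dummy tensors; the hidden size, sequence lengths, and token ids are invented for illustration and are not taken from the model or the commit.

import torch

hidden_size = 64
preamble_len = 5        # tokens before the prompt (up to the second <|eot_id|>)
prompt_len = 7          # tokens of the prompt and anything after it
num_image_tokens = 3    # length of the adapted image embedding sequence

# Stand-ins for tokenizer.encode, embed_tokens, and image_adapter outputs
convo_tokens = torch.randint(0, 1000, (preamble_len + prompt_len,))
convo_embeds = torch.randn(1, preamble_len + prompt_len, hidden_size)
embedded_images = torch.randn(1, num_image_tokens, hidden_size)

# Splice the image embeddings right after the preamble, before the prompt
input_embeds = torch.cat([
    convo_embeds[:, :preamble_len],
    embedded_images.to(dtype=convo_embeds.dtype),
    convo_embeds[:, preamble_len:],
], dim=1)

# Dummy zero ids pad input_ids to the same length as input_embeds
input_ids = torch.cat([
    convo_tokens[:preamble_len].unsqueeze(0),
    torch.zeros((1, num_image_tokens), dtype=torch.long),
    convo_tokens[preamble_len:].unsqueeze(0),
], dim=1)
attention_mask = torch.ones_like(input_ids)

# All three must agree on sequence length before calling generate(...)
assert input_embeds.shape[1] == input_ids.shape[1] == attention_mask.shape[1]
print(input_embeds.shape, input_ids.shape)  # torch.Size([1, 15, 64]) torch.Size([1, 15])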