John6666 committed
Commit 6b7230b · verified · Parent(s): 0f4a1a7

Upload batch-caption.py

Files changed (1)
  1. scripts/batch-caption.py +12 -6
scripts/batch-caption.py CHANGED
@@ -45,9 +45,9 @@ parser.add_argument("--top-p", type=lambda x: none_or_type(x, float), default=0.
  parser.add_argument("--top-k", type=lambda x: none_or_type(x, int), default=None, help="Top-k sampling")
  parser.add_argument("--max-new-tokens", type=int, default=256, help="Maximum length of the generated caption (in tokens)")
  parser.add_argument("--num-workers", type=int, default=4, help="Number of workers loading images in parallel")
- parser.add_argument("--model", type=str, default="fancyfeast/llama-joycaption-alpha-two-hf-llava", help="Model to use")
- #parser.add_argument("--model", type=str, default="John6666/llama-joycaption-alpha-two-hf-llava-nf4", help="Model to use")
- parser.add_argument("--nf4", action="store_true", default=False, help="Use NF4 (default: bfloat16)")
+ #parser.add_argument("--model", type=str, default="fancyfeast/llama-joycaption-alpha-two-hf-llava", help="Model to use")
+ parser.add_argument("--model", type=str, default="John6666/llama-joycaption-alpha-two-hf-llava-nf4", help="Model to use")
+ parser.add_argument("--nf4", action="store_true", default=True, help="Use NF4 (default: bfloat16)")
 
  PIL.Image.MAX_IMAGE_PIXELS = 933120000 # Quiets Pillow from giving warnings on really large images (WARNING: Exposes a risk of DoS from malicious images)
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
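Note on the new defaults: --model now points at the pre-quantized NF4 repo and --nf4 defaults to True. A minimal, self-contained sketch of just these two options (reproduced from the hunk above) shows what happens when the script is run with no flags:

import argparse

# Reproduces only the two changed arguments to illustrate the new defaults.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="John6666/llama-joycaption-alpha-two-hf-llava-nf4", help="Model to use")
parser.add_argument("--nf4", action="store_true", default=True, help="Use NF4 (default: bfloat16)")

args = parser.parse_args([])  # simulate running with no command-line flags
print(args.model)  # John6666/llama-joycaption-alpha-two-hf-llava-nf4
print(args.nf4)    # True -- the NF4 path is now the default
# Caveat: with action="store_true" and default=True the flag can never be turned off
# from the command line; disabling NF4 would need e.g. argparse.BooleanOptionalAction.

So pointing --model back at fancyfeast/llama-joycaption-alpha-two-hf-llava is possible, but the NF4 branch would still be taken because of the caveat above.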
@@ -89,8 +89,12 @@ def main():
  bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
  tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
  assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
- if IS_NF4: llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, quantization_config=nf4_config, torch_dtype="bfloat16", device_map=device)
- else: llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", device_map=device)
+ if IS_NF4:
+     llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", quantization_config=nf4_config).eval()
+     # https://github.com/fpgaminer/joycaption/issues/3#issuecomment-2619253277
+     attention = llava_model.vision_tower.vision_model.head.attention
+     attention.out_proj = torch.nn.Linear(attention.embed_dim, attention.embed_dim, device=llava_model.device, dtype=torch.bfloat16)
+ else: llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", device_map="auto").eval()
  assert isinstance(llava_model, LlavaForConditionalGeneration)
 
  dataset = ImageDataset(prompts, image_paths, tokenizer, llava_model.config.image_token_index, llava_model.config.image_seq_length)
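The hunk above only shows the tail of nf4_config; its construction is outside this diff. A plausible reconstruction consistent with the visible kwargs (the first two arguments are assumptions, not taken from the script) would be:

import torch
from transformers import BitsAndBytesConfig

# Assumed NF4 setup: only bnb_4bit_use_double_quant and bnb_4bit_compute_dtype are
# visible in the diff context; load_in_4bit and bnb_4bit_quant_type are guesses at a
# standard 4-bit NF4 configuration.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

The added out_proj swap follows the linked joycaption issue comment: bitsandbytes trips over the out_proj of the SigLIP pooling head's nn.MultiheadAttention, and since LLaVA normally takes its vision features from an intermediate hidden state rather than the pooled head output, replacing that projection with a freshly initialized bfloat16 Linear presumably just neutralizes a layer the captioner never uses.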
@@ -104,6 +108,7 @@ def main():
  vision_dtype = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
  vision_device = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight.device
  language_device = llava_model.language_model.get_input_embeddings().weight.device
+ print(vision_device, vision_dtype, language_device)
 
  # Move to GPU
  pixel_values = batch['pixel_values'].to(vision_device, non_blocking=True)
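As a companion to the added debug print, a hypothetical diagnostic (not part of the script) can confirm which layers bitsandbytes actually replaced with 4-bit modules after loading with nf4_config:

import bitsandbytes as bnb

# Diagnostic sketch: assumes llava_model was loaded with quantization_config=nf4_config.
quantized = [name for name, module in llava_model.named_modules()
             if isinstance(module, bnb.nn.Linear4bit)]
print(f"{len(quantized)} Linear4bit modules; first few: {quantized[:3]}")
# After the out_proj swap above, vision_tower.vision_model.head.attention.out_proj
# is a plain bfloat16 Linear and will not appear in this list.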
@@ -336,4 +341,5 @@ if __name__ == "__main__":
 
  # https://github.com/huggingface/peft/issues/156
  # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1331
- # https://github.com/huggingface/peft/issues/1831
+ # https://github.com/huggingface/peft/issues/1831
+ # https://github.com/fpgaminer/joycaption/issues/3
 
 