Upload batch-caption.py
scripts/batch-caption.py  CHANGED  (+12 -6)
@@ -45,9 +45,9 @@ parser.add_argument("--top-p", type=lambda x: none_or_type(x, float), default=0.
 parser.add_argument("--top-k", type=lambda x: none_or_type(x, int), default=None, help="Top-k sampling")
 parser.add_argument("--max-new-tokens", type=int, default=256, help="Maximum length of the generated caption (in tokens)")
 parser.add_argument("--num-workers", type=int, default=4, help="Number of workers loading images in parallel")
-parser.add_argument("--model", type=str, default="fancyfeast/llama-joycaption-alpha-two-hf-llava", help="Model to use")
-
-parser.add_argument("--nf4", action="store_true", default=
+#parser.add_argument("--model", type=str, default="fancyfeast/llama-joycaption-alpha-two-hf-llava", help="Model to use")
+parser.add_argument("--model", type=str, default="John6666/llama-joycaption-alpha-two-hf-llava-nf4", help="Model to use")
+parser.add_argument("--nf4", action="store_true", default=True, help="Use NF4 (default: bfloat16)")
 
 PIL.Image.MAX_IMAGE_PIXELS = 933120000  # Quiets Pillow from giving warnings on really large images (WARNING: Exposes a risk of DoS from malicious images)
 device = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -89,8 +89,12 @@ def main():
         bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True)
     assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
-    if IS_NF4:
-
+    if IS_NF4:
+        llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", quantization_config=nf4_config).eval()
+        # https://github.com/fpgaminer/joycaption/issues/3#issuecomment-2619253277
+        attention = llava_model.vision_tower.vision_model.head.attention
+        attention.out_proj = torch.nn.Linear(attention.embed_dim, attention.embed_dim, device=llava_model.device, dtype=torch.bfloat16)
+    else: llava_model = LlavaForConditionalGeneration.from_pretrained(args.model, torch_dtype="bfloat16", device_map="auto").eval()
     assert isinstance(llava_model, LlavaForConditionalGeneration)
 
     dataset = ImageDataset(prompts, image_paths, tokenizer, llava_model.config.image_token_index, llava_model.config.image_seq_length)

@@ -104,6 +108,7 @@ def main():
         vision_dtype = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight.dtype
         vision_device = llava_model.vision_tower.vision_model.embeddings.patch_embedding.weight.device
         language_device = llava_model.language_model.get_input_embeddings().weight.device
+        print(vision_device, vision_dtype, language_device)
 
         # Move to GPU
         pixel_values = batch['pixel_values'].to(vision_device, non_blocking=True)

@@ -336,4 +341,5 @@ if __name__ == "__main__":
 
     # https://github.com/huggingface/peft/issues/156
     # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1331
-    # https://github.com/huggingface/peft/issues/1831
+    # https://github.com/huggingface/peft/issues/1831
+    # https://github.com/fpgaminer/joycaption/issues/3