marianna13 committed on
Commit
ef8da88
·
1 Parent(s): 203ce72
llava/model/language_model/llava_mistral.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers import AutoConfig, AutoModelForCausalLM, \
+     MistralConfig, MistralModel, MistralForCausalLM
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+ class LlavaConfig(MistralConfig):
+     model_type = "llava_mistral"
+
+
+ class LlavaMistralModel(LlavaMetaModel, MistralModel):
+     config_class = LlavaConfig
+
+     def __init__(self, config: MistralConfig):
+         super(LlavaMistralModel, self).__init__(config)
+
+
+ class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
+     config_class = LlavaConfig
+
+     def __init__(self, config):
+         super(MistralForCausalLM, self).__init__(config)
+         self.model = LlavaMistralModel(config)
+
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+
+         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model/pipeline parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "images": kwargs.get("images", None),
+             }
+         )
+         return model_inputs
+
+ AutoConfig.register("llava_mistral", LlavaConfig)
+ AutoModelForCausalLM.register(LlavaConfig, LlavaMistralForCausalLM)
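The two `register` calls at the bottom are what make the new model discoverable through the Auto classes: once this module has been imported, any checkpoint whose `config.json` declares `model_type: "llava_mistral"` resolves to `LlavaMistralForCausalLM`. A minimal loading sketch, assuming the package is importable as `llava`; the checkpoint path and dtype choice are placeholders, not part of this commit:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Importing the module executes the AutoConfig/AutoModelForCausalLM.register calls.
import llava.model.language_model.llava_mistral  # noqa: F401

# Hypothetical checkpoint directory whose config.json has "model_type": "llava_mistral".
path = "/path/to/llava-mistral-checkpoint"
config = AutoConfig.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, config=config, torch_dtype=torch.float16)
```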
llava/model/language_model/llava_phi.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright 2023 Haotian Liu
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers import AutoConfig, AutoModelForCausalLM, \
+     PhiConfig, PhiModel, PhiForCausalLM
+
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+ class LlavaConfig(PhiConfig):
+     model_type = "phi-llava"
+
+
+ class LlavaPhiModel(LlavaMetaModel, PhiModel):
+     config_class = LlavaConfig
+
+     def __init__(self, config: PhiConfig):
+         super(LlavaPhiModel, self).__init__(config)
+
+
+ class LlavaPhiForCausalLM(PhiForCausalLM, LlavaMetaForCausalLM):
+     config_class = LlavaConfig
+
+     def __init__(self, config):
+         super(PhiForCausalLM, self).__init__(config)
+         self.model = LlavaPhiModel(config)
+
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_model(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         images: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, CausalLMOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+
+         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             # Flatten the tokens
+             loss_fct = CrossEntropyLoss()
+             shift_logits = shift_logits.view(-1, self.config.vocab_size)
+             shift_labels = shift_labels.view(-1)
+             # Enable model/pipeline parallelism
+             shift_labels = shift_labels.to(shift_logits.device)
+             loss = loss_fct(shift_logits, shift_labels)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+     ):
+         if past_key_values:
+             input_ids = input_ids[:, -1:]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "images": kwargs.get("images", None),
+             }
+         )
+         return model_inputs
+
+ AutoConfig.register("phi-llava", LlavaConfig)
+ AutoModelForCausalLM.register(LlavaConfig, LlavaPhiForCausalLM)
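In both wrappers, `prepare_inputs_for_generation` is what threads the extra `images` tensor through `generate()`: it is not a standard generation argument, so it travels in `**kwargs` and is re-attached to the model inputs at every decoding step. A usage sketch, assuming `model` is an already-loaded `LlavaPhiForCausalLM`; the token ids and image shape below are placeholders standing in for the usual LLaVA preprocessing:

```python
import torch

input_ids = torch.tensor([[1, 32000, 2]])   # placeholder prompt ids, incl. an image token
image_tensor = torch.randn(1, 3, 336, 336)  # placeholder CLIP-preprocessed image batch

with torch.inference_mode():
    # `images` is forwarded via **kwargs into prepare_inputs_for_generation,
    # which passes it on to forward() at each step.
    output_ids = model.generate(input_ids, images=image_tensor, max_new_tokens=32)
```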
llava/train/webdataset_utils.py ADDED
@@ -0,0 +1,461 @@
+ import ast
+ import json
+ import logging
+ import math
+ import os
+ import random
+ import sys
+ from dataclasses import dataclass
+ from io import BytesIO
+ from multiprocessing import Value
+
+ import braceexpand
+ import numpy as np
+ import pandas as pd
+ import torch
+ import webdataset as wds
+ from PIL import Image
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+ from webdataset.filters import _shuffle
+ from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
+ # from data_utils import get_normalized_weights_and_num_samples
+ from typing import List, Tuple
+
+
+ def get_normalized_weights_and_num_samples(
+     weights: List[float], num_samples: int
+ ) -> Tuple[List[float], List[int]]:
+     # Normalize weights
+     weight_sum = sum(weights)
+     assert weight_sum > 0.0
+     weights = [weight / weight_sum for weight in weights]
+     # Add 0.5% (the 1.005 factor) so in case the blending dataset does
+     # not uniformly distribute the number of samples, we still have
+     # samples left to feed to the network.
+     weighted_num_samples = []
+     for weight in weights:
+         weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
+     return weights, weighted_num_samples
+
+ class SharedEpoch:
+     def __init__(self, epoch: int = 0):
+         self.shared_epoch = Value('i', epoch)
+
+     def set_value(self, epoch):
+         self.shared_epoch.value = epoch
+
+     def get_value(self):
+         return self.shared_epoch.value
+
+
+ @dataclass
+ class DataInfo:
+     dataloader: DataLoader
+     shared_epoch: SharedEpoch = None
+
+     def set_epoch(self, epoch):
+         if self.shared_epoch is not None:
+             self.shared_epoch.set_value(epoch)
+
+ def expand_urls(urls, weights=None):
+     if weights is None:
+         expanded_urls = wds.shardlists.expand_urls(urls)
+         # keep only shards that actually exist on disk
+         expanded_urls = [url for url in expanded_urls if os.path.exists(url)]
+         return expanded_urls, None
+     if isinstance(urls, str):
+         urllist = urls.split("::")
+         weights = weights.split('::')
+         assert len(weights) == len(urllist),\
+             f"Expected the number of data components ({len(urllist)}) and weights ({len(weights)}) to match."
+         weights = [float(weight) for weight in weights]
+         all_urls, all_weights = [], []
+         for url, weight in zip(urllist, weights):
+             expanded_url = list(braceexpand.braceexpand(url))
+             # keep only shards that actually exist on disk
+             expanded_url = [u for u in expanded_url if os.path.exists(u)]
+             expanded_weights = [weight for _ in expanded_url]
+             all_urls.extend(expanded_url)
+             all_weights.extend(expanded_weights)
+         return all_urls, all_weights
+     else:
+         all_urls = list(urls)
+         return all_urls, weights
+
+
+ def get_dataset_size(shards):
+     shards_list = shards
+     dir_path = os.path.dirname(shards_list[0])
+     sizes_filename = os.path.join(dir_path, 'sizes.json')
+     len_filename = os.path.join(dir_path, '__len__')
+     if os.path.exists(sizes_filename):
+         sizes = json.load(open(sizes_filename, 'r'))
+         total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
+     elif os.path.exists(len_filename):
+         # FIXME this used to be eval(open(...)) but that seemed rather unsafe
+         total_size = ast.literal_eval(open(len_filename, 'r').read())
+     else:
+         total_size = None  # num samples undefined
+         # some common dataset sizes (at time of authors last download)
+         # CC3M (train): 2905954
+         # CC12M: 10968539
+         # LAION-400M: 407332084
+         # LAION-2B (english): 2170337258
+     num_shards = len(shards_list)
+     return total_size, num_shards
+
+ def count_samples(dataloader):
+     os.environ["WDS_EPOCH"] = "0"
+     n_elements, n_batches = 0, 0
+     for images, texts in dataloader:
+         n_batches += 1
+         n_elements += len(images)
+         assert len(images) == len(texts)
+     return n_elements, n_batches
+
+ def filter_no_caption_or_no_image(sample):
+     has_caption = ('txt' in sample)
+     has_image = ('png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample)
+     return has_caption and has_image
+
+
+ def log_and_continue(exn):
+     """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+     logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+     return True
+
+ def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+     """Return function over iterator that groups key, value pairs into samples.
+
+     :param keys: function that splits the key into key and extension (base_plus_ext)
+     :param lcase: convert suffixes to lower case (Default value = True)
+     """
+     current_sample = None
+     for filesample in data:
+         assert isinstance(filesample, dict)
+         fname, value = filesample["fname"], filesample["data"]
+         prefix, suffix = keys(fname)
+         if prefix is None:
+             continue
+         if lcase:
+             suffix = suffix.lower()
+         # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+         # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+         # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+         if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+             if valid_sample(current_sample):
+                 yield current_sample
+             current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
+         if suffixes is None or suffix in suffixes:
+             current_sample[suffix] = value
+     if valid_sample(current_sample):
+         yield current_sample
+
+
+ def tarfile_to_samples_nothrow(src, handler=log_and_continue):
+     # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+     streams = url_opener(src, handler=handler)
+     files = tar_file_expander(streams, handler=handler)
+     samples = group_by_keys_nothrow(files, handler=handler)
+     return samples
+
+
+ def pytorch_worker_seed(increment=0):
+     """Get the dataloader worker seed from PyTorch."""
+     worker_info = get_worker_info()
+     if worker_info is not None:
+         # favour using the seed already created for pytorch dataloader workers if it exists
+         seed = worker_info.seed
+         if increment:
+             # space out seed increments so they can't overlap across workers in different iterations
+             seed += increment * max(1, worker_info.num_workers)
+         return seed
+     # fallback to wds rank-based seed
+     return wds.utils.pytorch_worker_seed()
+
+
+ _SHARD_SHUFFLE_SIZE = 2000
+ _SHARD_SHUFFLE_INITIAL = 500
+ _SAMPLE_SHUFFLE_SIZE = 5000
+ _SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+ class detshuffle2(wds.PipelineStage):
+     def __init__(
+         self,
+         bufsize=1000,
+         initial=100,
+         seed=0,
+         epoch=-1,
+     ):
+         self.bufsize = bufsize
+         self.initial = initial
+         self.seed = seed
+         self.epoch = epoch
+
+     def run(self, src):
+         if isinstance(self.epoch, SharedEpoch):
+             epoch = self.epoch.get_value()
+         else:
+             # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
+             # situation as different workers may wrap at different times (or not at all).
+             self.epoch += 1
+             epoch = self.epoch
+         rng = random.Random()
+         if self.seed < 0:
+             # If seed is negative, use the worker's seed; this will be different across all nodes/workers
+             seed = pytorch_worker_seed(epoch)
+         else:
+             # This seed is deterministic AND the same across all nodes/workers in each epoch
+             seed = self.seed + epoch
+         rng.seed(seed)
+         return _shuffle(src, self.bufsize, self.initial, rng)
+
+
+ class ResampledShards2(IterableDataset):
+     """An iterable dataset yielding a list of urls."""
+
+     def __init__(
+         self,
+         urls,
+         weights=None,
+         nshards=sys.maxsize,
+         worker_seed=None,
+         deterministic=False,
+         epoch=-1,
+     ):
+         """Sample shards from the shard list with replacement.
+
+         :param urls: a list of URLs as a Python list or brace notation string
+         """
+         super().__init__()
+         urls, weights = expand_urls(urls, weights)
+         self.urls = urls
+         self.weights = weights
+         if self.weights is not None:
+             assert len(self.urls) == len(self.weights),\
+                 f"Number of urls {len(self.urls)} and weights {len(self.weights)} should match."
+         assert isinstance(self.urls[0], str)
+         self.nshards = nshards
+         self.rng = random.Random()
+         self.worker_seed = worker_seed
+         self.deterministic = deterministic
+         self.epoch = epoch
+
+     def __iter__(self):
+         """Return an iterator over the shards."""
+         if isinstance(self.epoch, SharedEpoch):
+             epoch = self.epoch.get_value()
+         else:
+             # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
+             # situation as different workers may wrap at different times (or not at all).
+             self.epoch += 1
+             epoch = self.epoch
+         if self.deterministic:
+             # reset seed w/ epoch if deterministic
+             if self.worker_seed is None:
+                 # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
+                 seed = pytorch_worker_seed(epoch)
+             else:
+                 seed = self.worker_seed() + epoch
+             self.rng.seed(seed)
+         for _ in range(self.nshards):
+             if self.weights is None:
+                 yield dict(url=self.rng.choice(self.urls))
+             else:
+                 yield dict(url=self.rng.choices(self.urls, weights=self.weights, k=1)[0])
+
+
+ def image_text_dict_collation_fn(samples):
+     """Custom collation_fn that batches (image, text) tuples into a dict."""
+     assert isinstance(samples[0], (list, tuple)), type(samples[0])
+     batched = list(zip(*samples))
+     result = dict()
+     for b in batched:
+         b = torch.stack(list(b))
+         if b.dim() >= 3:  # 3+ dims means an image batch
+             result['img'] = b
+         else:
+             result['text'] = b
+
+     return result
+
+ def decode_image(png_bytes):
+     return Image.open(BytesIO(png_bytes))
+
+
+ def process_sample(sample):
+
+     if "png" not in sample:
+         sample["png"] = b''
+     else:
+         sample["png"] = decode_image(sample["png"])
+
+     sample = {"png": sample["png"], "json": sample["json"]}
+     return sample
+
+ def get_wds_data(args, is_train, epoch=0, floor=False, wds_processor=None):
+     if args.data_path and (args.train_data_weights is None):
+         args.train_data_weights = [1.0] * len(args.data_path)
+
+     input_shards = args.data_path if is_train else args.valid_data_path
+     input_weights = args.train_data_weights if is_train else args.valid_data_weights
+
+     assert input_shards is not None
+
+     resampled = getattr(args, 'dataset_resampled', False)
+
+     num_shards = None
+     if is_train:
+         if args.train_num_samples is not None:
+             num_samples = args.train_num_samples
+         else:
+             num_samples, num_shards = get_dataset_size(input_shards)
+             if not num_samples:
+                 raise RuntimeError(
+                     'Currently, the number of dataset samples must be specified for the training dataset. '
+                     'Please specify it via `--train-num-samples` if no dataset length info is present.')
+     else:
+         # Eval will just exhaust the iterator if the size is not specified.
+         num_samples = args.val_num_samples or 0
+     weights, weighted_num_samples = get_normalized_weights_and_num_samples(input_weights, num_samples)
+     shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
+
+     if resampled:
+         complete_url_list = []
+         complete_weights = []
+         for urls, weight in zip(input_shards, weights):
+             current_url_list = expand_urls(urls)[0]
+             complete_url_list.extend(current_url_list)
+             per_url_weight = weight / len(current_url_list)
+             complete_weights.extend([per_url_weight] * len(current_url_list))
+         # pipeline = [ResampledShards2(
+         #     complete_url_list,
+         #     weights=complete_weights,
+         #     deterministic=True,
+         #     epoch=shared_epoch,
+         # )]
+         pipeline = [wds.SimpleShardList(complete_url_list)]
+     else:
+         # assert args.train_data_upsampling_factors is None,\
+         #     "--train_data_upsampling_factors is only supported when sampling with replacement (with --dataset-resampled)."
+         pipeline = [wds.SimpleShardList(input_shards)]
+
+     # at this point we have an iterator over all the shards
+     if is_train:
+         if not resampled:
+             pipeline.extend([
+                 detshuffle2(
+                     bufsize=_SHARD_SHUFFLE_SIZE,
+                     initial=_SHARD_SHUFFLE_INITIAL,
+                     # seed=args.seed,
+                     epoch=shared_epoch,
+                 ),
+                 wds.split_by_node,
+                 wds.split_by_worker,
+             ])
+         pipeline.extend([
+             # at this point, we have an iterator over the shards assigned to each worker at each node
+             # tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
+             wds.tarfile_to_samples(handler=log_and_continue),
+             wds.shuffle(
+                 bufsize=_SAMPLE_SHUFFLE_SIZE,
+                 initial=_SAMPLE_SHUFFLE_INITIAL,
+             ),
+         ])
+     else:
+         pipeline.extend([
+             wds.split_by_worker,
+             # at this point, we have an iterator over the shards assigned to each worker
+             wds.tarfile_to_samples(handler=log_and_continue),
+         ])
+     ### build preprocess_img and preprocess_text from args
+     # from .transforms import get_clip_transforms
+     # preprocess_img = get_clip_transforms(image_size=data_args.image_processor.crop_size)
+
+     # assert (
+     #     args.tokenizer.name in ['HFGPT2Tokenizer','HFGPT2TokenizerFast','HFTokenizer']
+     # ), f"Webdataset only supports HFTokenizer, HFGPT2Tokenizer or HFGPT2TokenizerFast"
+
+     # tokenize = args.tokenizer.tokenize
+
+
+     pipeline.extend([
+         # wds.select(filter_no_caption_or_no_image),
+         wds.decode("pilrgb", handler=log_and_continue),
+         wds.rename(image="jpg;png;jpeg;webp", text="json"),
+         wds.to_tuple("image", "text"),
+         wds.map(wds_processor)
+         # wds.map_dict(image=preprocess_img, text=lambda text: tokenize(text)[0]),
+
+         # wds.batched(args.batch_size, collation_fn=image_text_dict_collation_fn, partial=not is_train)
+     ])
+     # pipeline.extend([
+     #     wds.map(process_sample),
+     #     wds.rename(image="jpg;png;jpeg;webp", text="json"),
+     #     wds.to_tuple("image", "text")
+     #     wds.map(wds_processor)
+     # ])
+
+     dataset = wds.DataPipeline(*pipeline)
+
+     if is_train:
+         if not resampled:
+             num_shards = num_shards or len(expand_urls(input_shards)[0])
+             # assert num_shards >= args.num_workers * args.world_size, 'number of shards must be >= total workers'
+         # roll over and repeat a few samples to get same number of full batches on each node
+         round_fn = math.floor if floor else math.ceil
+         global_batch_size = args.batch_size * args.world_size
+         num_batches = round_fn(num_samples / global_batch_size)
+         num_workers = max(1, args.dataloader_num_workers)
+         num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
+         num_batches = num_worker_batches * num_workers
+         num_samples = num_batches * global_batch_size
+         dataset = dataset.with_epoch(num_worker_batches)  # each worker is iterating over this
+         # dataset = dataset.with_epoch(num_samples)
+     else:
+         # last batches are partial, eval is done on single (master) node
+         num_batches = math.ceil(num_samples / args.batch_size)
+
+     # dataloader = wds.WebLoader(
+     #     dataset,
+     #     batch_size=None,
+     #     shuffle=False,
+     #     num_workers=args.num_workers,
+     #     persistent_workers=not (args.num_workers == 0),  # set persistent_workers to false if num_workers is 0
+     # )
+
+     # FIXME not clear which approach is better, with_epoch before vs after dataloader?
+     # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
+     # if is_train:
+     #     # roll over and repeat a few samples to get same number of full batches on each node
+     #     global_batch_size = args.batch_size * args.world_size
+     #     num_batches = math.ceil(num_samples / global_batch_size)
+     #     num_workers = max(1, args.num_workers)
+     #     num_batches = math.ceil(num_batches / num_workers) * num_workers
+     #     num_samples = num_batches * global_batch_size
+     #     dataloader = dataloader.with_epoch(num_batches)
+     # else:
+     #     # last batches are partial, eval is done on single (master) node
+     #     num_batches = math.ceil(num_samples / args.batch_size)
+
+     # add meta-data to dataloader instance for convenience
+     # dataloader.num_batches = num_batches
+     # dataloader.num_samples = num_samples
+
+     return dataset
+
+
+ # def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
+ #     preprocess_train, preprocess_val = preprocess_fns
+ #     data = {}
+
+ #     if args.train_data or args.dataset_type == "synthetic":
+ #         data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+ #             args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)
+
+ #     if args.val_data:
+ #         data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
+ #             args, preprocess_val, is_train=False, tokenizer=tokenizer)
+
+ #     return data
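For reference, a minimal sketch of how `get_wds_data` might be driven. The `args` fields are only those the function actually reads, the shard pattern is a placeholder, and `wds_processor` stands in for the real tokenizer/image-transform hook; none of this is part of the commit itself:

```python
from types import SimpleNamespace

# Hypothetical args namespace; only the fields read by get_wds_data are filled in.
args = SimpleNamespace(
    data_path="shards/train-{000000..000099}.tar",  # placeholder shard pattern
    train_data_weights=None,      # defaults to uniform weights
    train_num_samples=100_000,    # required when no sizes.json/__len__ is present
    dataset_resampled=False,
    batch_size=8,
    world_size=1,
    dataloader_num_workers=2,
)

def wds_processor(sample):
    # `sample` is a (PIL image, decoded json dict) tuple after wds.decode/rename/to_tuple;
    # plug tokenization and image transforms in here.
    image, meta = sample
    return image, meta

dataset = get_wds_data(args, is_train=True, epoch=0, wds_processor=wds_processor)
for image, meta in dataset:
    break  # one preprocessed (image, metadata) pair per iteration
```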