zhb10086 committed on
Commit dd2ba51 · verified · 1 parent: 501f4d0

Update preprocessing_molmo.py

Files changed (1)
  1. preprocessing_molmo.py +16 -5
preprocessing_molmo.py CHANGED
@@ -97,22 +97,31 @@ class MolmoProcessor(ProcessorMixin):
             self._special_tokens = get_special_token_ids(self.tokenizer)
         return self._special_tokens
 
-    def get_tokens_input(self, prompt, message_format, always_start_with_space, out_text=None):
+    def get_tokens_input(self, prompt, message_format, always_start_with_space, out_text=None, pad_length=None):
         if message_format == "none" or message_format is None:
             pass
         elif message_format == "role":
             prompt = "User: " + prompt + " Assistant:"
         else:
             raise NotImplementedError(f"Message format {message_format} not implemented")
-
+
         if always_start_with_space:
             prompt = " " + prompt
-
+
         if out_text is not None:
             prompt = " ".join([prompt, out_text])
-
+
         tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
 
+        if pad_length is not None:
+            assert isinstance(pad_length, int)
+            if len(tokens) > pad_length:
+                tokens = tokens[:pad_length]
+
+            if len(tokens) < pad_length:
+                pad_token_id = self.tokenizer.pad_token_id or 0  # Use 0 if pad_token_id is not set
+                tokens = tokens + [pad_token_id] * (pad_length - len(tokens))
+
         return tokens
 
     def process(
@@ -120,6 +129,7 @@ class MolmoProcessor(ProcessorMixin):
         text: TextInput = None,
         images: ImageInput = None,
         out_text: TextInput = None,
+        pad_length: int = None,
         *,
         tokens: Optional[PreTokenizedInput] = None,
         out_tokens: Optional[PreTokenizedInput] = None,
@@ -136,7 +146,8 @@ class MolmoProcessor(ProcessorMixin):
             text,
             output_kwargs["text_kwargs"]["message_format"],
             output_kwargs["text_kwargs"]["always_start_with_space"],
-            out_text
+            out_text,
+            pad_length,
         )
 
         if out_tokens is not None:
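
For context, a minimal usage sketch of the new pad_length argument. Only the process(..., pad_length=...) keyword comes from this diff; the AutoProcessor loading step follows the usual Molmo trust_remote_code setup, and the repo id, image path, and length 128 are placeholder assumptions.

from PIL import Image
from transformers import AutoProcessor

# Placeholder repo id: substitute the repository this commit belongs to.
processor = AutoProcessor.from_pretrained("<molmo-repo-id>", trust_remote_code=True)

# With this change, process() forwards pad_length to get_tokens_input(), which
# truncates the encoded prompt to pad_length tokens, or right-pads it with
# tokenizer.pad_token_id (falling back to 0) until it reaches pad_length.
inputs = processor.process(
    text="Describe this image.",          # prompt text
    images=[Image.open("example.jpg")],   # hypothetical image path
    pad_length=128,                       # hypothetical fixed prompt length
)

Padding or truncating to a fixed pad_length presumably keeps prompt token sequences a uniform size, which simplifies batching examples together.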