|  | """Module for tokenization utilities""" | 
					
						
						|  |  | 
					
						
						|  | import logging | 
					
						
						|  | import re | 
					
						
						|  | from typing import Dict, List | 
					
						
						|  |  | 
					
						
						|  | from termcolor import colored | 
					
						
						|  |  | 
					
						
# Module-level logger; named "axolotl" (not __name__) — presumably to share the project-wide logger.
LOG = logging.getLogger("axolotl")
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_dataset_labels( | 
					
						
						|  | dataset, | 
					
						
						|  | tokenizer, | 
					
						
						|  | num_examples=5, | 
					
						
						|  | text_only=False, | 
					
						
						|  | rl_mode=False, | 
					
						
						|  | ): | 
					
						
						|  |  | 
					
						
						|  | for idx in range(num_examples): | 
					
						
						|  | if not rl_mode: | 
					
						
						|  | check_example_labels(dataset[idx], tokenizer, text_only=text_only) | 
					
						
						|  | else: | 
					
						
						|  | check_rl_example_labels(dataset[idx], tokenizer, text_only=text_only) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_example_labels(example, tokenizer, text_only=False): | 
					
						
						|  |  | 
					
						
						|  | input_ids = example["input_ids"] | 
					
						
						|  | labels = example["labels"] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | colored_tokens = [] | 
					
						
						|  | for _, (input_id, label_id) in enumerate(zip(input_ids, labels)): | 
					
						
						|  | decoded_input_token = tokenizer.decode(input_id) | 
					
						
						|  |  | 
					
						
						|  | color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") | 
					
						
						|  | colored_token = colored(decoded_input_token, color) + ( | 
					
						
						|  | not text_only and colored(f"({label_id}, {input_id})", "white") or "" | 
					
						
						|  | ) | 
					
						
						|  | colored_tokens.append(colored_token) | 
					
						
						|  |  | 
					
						
						|  | delimiter = "" if text_only else " " | 
					
						
						|  | LOG.info(delimiter.join(colored_tokens)) | 
					
						
						|  | LOG.info("\n\n\n") | 
					
						
						|  |  | 
					
						
						|  | return " ".join(colored_tokens) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def color_token_for_rl_debug(decoded_token, encoded_token, color, text_only): | 
					
						
						|  | """Helper function to color tokens based on their type.""" | 
					
						
						|  | colored_text = colored(decoded_token, color) | 
					
						
						|  | return ( | 
					
						
						|  | colored_text | 
					
						
						|  | if text_only | 
					
						
						|  | else f"{colored_text}{colored(f'({encoded_token})', 'white')}" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only): | 
					
						
						|  | """Helper function to process and color tokens.""" | 
					
						
						|  | colored_tokens = [ | 
					
						
						|  | color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only) | 
					
						
						|  | for token in tokenizer.encode(tokens) | 
					
						
						|  | ] | 
					
						
						|  | return colored_tokens | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_rl_example_labels(example, tokenizer, text_only=False): | 
					
						
						|  | field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected" | 
					
						
						|  |  | 
					
						
						|  | input_tokens = example[field_prompt] | 
					
						
						|  | labels_chosen, labels_rejected = example[field_chosen], example[field_rejected] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | colored_tokens = process_tokens_for_rl_debug( | 
					
						
						|  | input_tokens, "yellow", tokenizer, text_only | 
					
						
						|  | ) | 
					
						
						|  | colored_chosens = process_tokens_for_rl_debug( | 
					
						
						|  | labels_chosen, "green", tokenizer, text_only | 
					
						
						|  | ) | 
					
						
						|  | colored_rejecteds = process_tokens_for_rl_debug( | 
					
						
						|  | labels_rejected, "red", tokenizer, text_only | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | delimiter = "" if text_only else " " | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n") | 
					
						
						|  | LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n") | 
					
						
						|  | LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n") | 
					
						
						|  |  | 
					
						
						|  | return delimiter.join(colored_tokens) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"] | 
					
						
						|  | GLAIVE_TO_SHAREGPT_ROLE = { | 
					
						
						|  | "SYSTEM": "system", | 
					
						
						|  | "USER": "human", | 
					
						
						|  | "ASSISTANT": "gpt", | 
					
						
						|  | "FUNCTION RESPONSE": "tool", | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]: | 
					
						
						|  | """ | 
					
						
						|  | Converts a ChatML formatted row to a list of messages in ShareGPT format. | 
					
						
						|  | Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | system_prompt = row.get("system") | 
					
						
						|  | if system_prompt: | 
					
						
						|  | system_prompt = system_prompt.removeprefix("SYSTEM: ") | 
					
						
						|  |  | 
					
						
						|  | chat_str = row["chat"] | 
					
						
						|  | chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s] | 
					
						
						|  |  | 
					
						
						|  | chat_msg_dicts = [ | 
					
						
						|  | {"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value} | 
					
						
						|  | for role, value in zip(chat_msgs[::2], chat_msgs[1::2]) | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | if system_prompt: | 
					
						
						|  | chat_msg_dicts = [ | 
					
						
						|  | {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt} | 
					
						
						|  | ] + chat_msg_dicts | 
					
						
						|  |  | 
					
						
						|  | return chat_msg_dicts | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def merge_consecutive_messages(messages): | 
					
						
						|  | """ | 
					
						
						|  | Merge consecutive messages from the same sender into a single message. | 
					
						
						|  | This can be useful with datasets that contain multiple consecutive tool calls. | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | merged_messages = [] | 
					
						
						|  | current_from = None | 
					
						
						|  | current_message = "" | 
					
						
						|  |  | 
					
						
						|  | for msg in messages: | 
					
						
						|  | if current_from == msg["from"]: | 
					
						
						|  | current_message += msg["value"] | 
					
						
						|  | else: | 
					
						
						|  | if current_from is not None: | 
					
						
						|  | merged_messages.append({"from": current_from, "value": current_message}) | 
					
						
						|  | current_from = msg["from"] | 
					
						
						|  | current_message = msg["value"] | 
					
						
						|  |  | 
					
						
						|  | if current_from is not None: | 
					
						
						|  | merged_messages.append({"from": current_from, "value": current_message}) | 
					
						
						|  |  | 
					
						
						|  | return merged_messages | 
					
						
						|  |  |