import os
import datasets

# A dictionary to store various prompt templates.
template_dict = {
    'default': 'Instruction: {instruction}\nInput: {input}\nAnswer: '
}

# A dictionary to store the LoRA module mapping for different models.
lora_module_dict = {
    'chatglm2': ['query_key_value'],
    'falcon': ['query_key_value'],
    'bloom': ['query_key_value'],
    'internlm': ['q_proj', 'k_proj', 'v_proj'],
    'llama2': ['q_proj', 'k_proj', 'v_proj'],
    'llama2-13b': ['q_proj', 'k_proj', 'v_proj'],
    'llama2-13b-nr': ['q_proj', 'k_proj', 'v_proj'],
    'qwen': ['c_attn'],
    'mpt': ['Wqkv'],
    'baichuan': ['q_proj', 'k_proj', 'v_proj'],
}
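
# Example of how this mapping might feed a PEFT LoRA config (a sketch; the
# hyperparameter values below are illustrative assumptions, not prescribed by
# this module):
#
#   from peft import LoraConfig, TaskType
#   lora_config = LoraConfig(
#       task_type=TaskType.CAUSAL_LM,
#       r=8,
#       lora_alpha=32,
#       lora_dropout=0.1,
#       target_modules=lora_module_dict['llama2'],
#   )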


def get_prompt(template, instruction, input_text):
    """
    Generates a prompt based on a predefined template, instruction, and input.

    Args:
        template (str): The key to select the prompt template from the predefined dictionary.
        instruction (str): The instruction text to be included in the prompt.
        input_text (str): The input text to be included in the prompt.

    Returns:
        str: The generated prompt.

    Raises:
        KeyError: If the provided template key is not found in the template dictionary.
    """
    if not instruction:
        return input_text

    if template not in template_dict:
        raise KeyError(f"Template '{template}' not found. Available templates: {', '.join(template_dict.keys())}")

    return template_dict[template].format(instruction=instruction, input=input_text)
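
# Example of the prompt produced by get_prompt (an illustrative sketch; the
# instruction and input strings are made up):
#
#   get_prompt('default', 'Classify the sentiment.', 'Shares of ACME rose 5% today.')
#   # -> 'Instruction: Classify the sentiment.\n'
#   #    'Input: Shares of ACME rose 5% today.\n'
#   #    'Answer: '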


def test_mapping(args, feature):
    """
    Generate a mapping for testing purposes by constructing a prompt based on the given instruction and input.

    Args:
        args (Namespace): A namespace object that holds various configurations, including the instruction template.
        feature (dict): A dictionary containing 'instruction' and 'input' fields used to construct the prompt.

    Returns:
        dict: A dictionary containing the generated prompt.

    Raises:
        ValueError: If 'instruction' or 'input' is not provided in the feature dictionary.
    """
    # Ensure 'instruction' and 'input' are present in the feature dictionary.
    if 'instruction' not in feature or 'input' not in feature:
        raise ValueError("Both 'instruction' and 'input' need to be provided in the feature dictionary.")

    # Construct the prompt using the provided instruction and input.
    prompt = get_prompt(
        args.instruct_template,
        feature['instruction'],
        feature['input']
    )

    return {
        "prompt": prompt,
    }
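
# Example of applying test_mapping over a Hugging Face dataset (a sketch; the
# SimpleNamespace stand-in for `args` and the `eval_dataset` variable are
# assumptions, not defined in this module):
#
#   from functools import partial
#   from types import SimpleNamespace
#   args = SimpleNamespace(instruct_template='default')
#   eval_dataset = eval_dataset.map(partial(test_mapping, args))
#   # Each example now carries a 'prompt' column built from its instruction/input.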


def tokenize(args, tokenizer, feature):
    """
    Tokenizes the input prompt and target/output for model training or evaluation.

    Args:
        args (Namespace): A namespace object containing various settings and configurations.
        tokenizer (Tokenizer): A tokenizer object used to convert text into tokens.
        feature (dict): A dictionary containing 'input', 'instruction', and 'output' fields.

    Returns:
        dict: A dictionary containing tokenized 'input_ids', 'labels', and a flag 'exceed_max_length'.
    """
    # Generate the prompt.
    prompt = get_prompt(
        args.instruct_template,
        feature['instruction'],
        feature['input']
    )

    # Tokenize the prompt.
    prompt_ids = tokenizer(
        prompt,
        padding=False,
        max_length=args.max_length,
        truncation=True
    )['input_ids']

    # Tokenize the target/output.
    target_ids = tokenizer(
        feature['output'].strip(),
        padding=False,
        max_length=args.max_length,
        truncation=True,
        add_special_tokens=False
    )['input_ids']

    # Combine the tokenized prompt and target output.
    input_ids = prompt_ids + target_ids

    # Check whether the combined length exceeds the maximum allowed length.
    exceed_max_length = len(input_ids) >= args.max_length

    # Add an end-of-sequence (EOS) token if it is not already present
    # and the sequence length is within the limit.
    if input_ids[-1] != tokenizer.eos_token_id and not exceed_max_length:
        input_ids.append(tokenizer.eos_token_id)

    # Create label IDs for training.
    # The labels start where the prompt ends; the prompt portion is padded out.
    label_ids = [tokenizer.pad_token_id] * len(prompt_ids) + input_ids[len(prompt_ids):]

    return {
        "input_ids": input_ids,
        "labels": label_ids,
        "exceed_max_length": exceed_max_length
    }
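
# Example of tokenizing a dataset for training (a sketch; the model name,
# max_length, and `args` fields are illustrative assumptions, and `train_dataset`
# is not defined in this module):
#
#   from functools import partial
#   from types import SimpleNamespace
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
#   tokenizer.pad_token = tokenizer.eos_token  # models without a pad token need one for the label padding above
#   args = SimpleNamespace(instruct_template='default', max_length=512)
#   train_dataset = train_dataset.map(partial(tokenize, args, tokenizer))
#   train_dataset = train_dataset.filter(lambda x: not x['exceed_max_length'])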


def parse_model_name(name, from_remote=False):
    """
    Parse the model name and return the appropriate path based on whether
    the model is to be fetched from a remote source or from a local source.

    Args:
        name (str): Name of the model.
        from_remote (bool): If True, return the remote path; otherwise return the local path.

    Returns:
        str: The appropriate path for the given model name.

    Raises:
        ValueError: If the model name is not one of the supported base models.
    """
    model_paths = {
        'chatglm2': ('THUDM/chatglm2-6b', 'base_models/chatglm2-6b'),
        'llama2': ('meta-llama/Llama-2-7b-hf', 'base_models/Llama-2-7b-hf'),
        'llama2-13b': ('meta-llama/Llama-2-13b-hf', 'base_models/Llama-2-13b-hf'),
        'llama2-13b-nr': ('NousResearch/Llama-2-13b-hf', 'base_models/Llama-2-13b-hf'),
        'falcon': ('tiiuae/falcon-7b', 'base_models/falcon-7b'),
        'internlm': ('internlm/internlm-7b', 'base_models/internlm-7b'),
        'qwen': ('Qwen/Qwen-7B', 'base_models/Qwen-7B'),
        'baichuan': ('baichuan-inc/Baichuan2-7B-Base', 'base_models/Baichuan2-7B-Base'),
        'mpt': ('cekal/mpt-7b-peft-compatible', 'base_models/mpt-7b-peft-compatible'),
        'bloom': ('bigscience/bloom-7b1', 'base_models/bloom-7b1')
    }

    if name in model_paths:
        return model_paths[name][0] if from_remote else model_paths[name][1]
    else:
        valid_model_names = ', '.join(model_paths.keys())
        raise ValueError(f"Undefined base model '{name}'. Valid model names are: {valid_model_names}")


def load_dataset(names, from_remote=False):
    """
    Load one or more datasets based on the provided names and source location.

    Args:
        names (str): A comma-separated list of dataset names. Each name may be followed by '*n' to indicate replication.
        from_remote (bool): If True, load the dataset from the Hugging Face Hub. Otherwise, load it from local disk.

    Returns:
        list: The loaded datasets (DatasetDict objects), each possibly replicated according to the input names.
    """
    # Split the dataset names on commas to handle multiple datasets.
    dataset_names = names.split(',')
    dataset_list = []

    for name in dataset_names:
        # Initialize the replication factor to 1.
        replication_factor = 1
        dataset_name = name

        # Check whether the dataset name includes a replication factor.
        if '*' in name:
            dataset_name, replication_factor = name.split('*')
            replication_factor = int(replication_factor)
            if replication_factor < 1:
                raise ValueError("Replication factor must be a positive integer.")

        # Construct the dataset path or name based on the source location.
        dataset_path_or_name = ('FinGPT/fingpt-' if from_remote else 'data/fingpt-') + dataset_name
        if not os.path.exists(dataset_path_or_name) and not from_remote:
            raise FileNotFoundError(f"The dataset path {dataset_path_or_name} does not exist.")

        # Load the dataset.
        try:
            tmp_dataset = (
                datasets.load_dataset(dataset_path_or_name)
                if from_remote else datasets.load_from_disk(dataset_path_or_name)
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load the dataset: {str(e)}")

        # Check for a 'test' split and create one from 'train' if necessary.
        if 'test' not in tmp_dataset:
            if 'train' in tmp_dataset:
                tmp_dataset = tmp_dataset['train']
                tmp_dataset = tmp_dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)
            else:
                raise ValueError("The dataset must contain a 'train' or 'test' split.")

        # Append the (possibly replicated) dataset to the list.
        dataset_list.extend([tmp_dataset] * replication_factor)

    return dataset_list
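
# Example (a sketch; the dataset names are illustrative and assume matching
# 'FinGPT/fingpt-*' repos exist on the Hub, or 'data/fingpt-*' folders locally):
#
#   dataset_list = load_dataset('sentiment-train*2,finred', from_remote=True)
#   # -> three entries: 'fingpt-sentiment-train' twice, then 'fingpt-finred',
#   #    each as a DatasetDict with 'train' and 'test' splits.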