---
library_name: transformers
license: apache-2.0
base_model: unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit
tags:
- generated_from_trainer
datasets:
- instruction_solution_to_thought_dataset.jsonl
- secemp9/instruction_solution_thought
model-index:
- name: outputs_solution_to_thought
  results: []
---
![image/png](https://cdn-uploads.huggingface.co/production/uploads/65986192b0c5357368bacbf8/-_THTLhEqxfXjuyh_jaFk.png)

# TraceBack 12b Release



TraceBack is what I came up with when I thought, "how can we scale reasoning trace data generation effectively?"

Turns out you do not need to depend on reasoning models (r1, o1, o3, etc.) to create reasoning traces!

It has many goals in mind, but mainly:
- enabling faster synthetic reasoning dataset generation: since this is a small model (smaller than r1, etc.), inference is faster, which makes generation easier to scale
- distilling synthetic traces for out-of-domain, non-verifiable problems
- converting the outputs/datasets of any non-reasoning model into a synthetic reasoning dataset when used as input

So far, the current proof of concept checks the boxes for goals 1 and 3, and I plan on scaling this further, since:
- it only uses Mistral Nemo 12b as the base model
- it was only trained for 2 epochs
- only 200k samples were used for finetuning (QLoRA); the dataset is at https://huggingface.co/datasets/secemp9/instruction_solution_thought
 
So there is still much room for improvement.

This was trained using both the instruction and the solution as input, with the output being a plausible, matching reasoning trace for that pair.

I believe this is the future of reasoning data generation. Stay tuned for an eval release.

Here is an inference example, using a ChatGPT instruction + solution as input:

# Inference Example
Here I use a simple example from ChatGPT, passing both the instruction and the solution as input to the model:
![image/png](https://cdn-uploads.huggingface.co/production/uploads/65986192b0c5357368bacbf8/rtuYmWGw8lk09AQi_dpX8.png)

# Dataset Example

The dataset format follows instruction + solution → reasoning trace pairs.
Sample conversation:
```
{
  "messages": [
    {
      "role": "user",
      "content": "Instruction:
      text_here

      Solution:
      text_here"
    },
    {
      "role": "assistant",
      "content": "text_here"
    }
  ]
}
```
which looks like:
![image/png](https://cdn-uploads.huggingface.co/production/uploads/65986192b0c5357368bacbf8/GdbZxeLSDsJmZDHJ8SN-g.png)
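
To make this concrete, here is a minimal sketch of loading the dataset from the Hub and inspecting one conversation. The `messages`/`role`/`content` field names come from the sample above; the split name is an assumption.

```python
from datasets import load_dataset

# Load the instruction + solution -> reasoning trace dataset
# (the "train" split name is an assumption)
ds = load_dataset("secemp9/instruction_solution_thought", split="train")

example = ds[0]
# Each example is a two-turn conversation: a user message containing the
# instruction and solution, and an assistant message with the reasoning trace.
for message in example["messages"]:
    print(f'{message["role"]}: {message["content"][:200]}')
```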

# Prompt Format

For the prompt format, I was really trying not to overengineer things, but I'm sure there is a better way to format this.

For now it's just:

Instruction:

Solution:

The model's output doesn't have any special formatting (for now); it's just the reasoning trace.
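
As a sketch, the whole prompt can be built with a tiny helper; `build_prompt` is a hypothetical name, but the format matches the inference examples below:

```python
def build_prompt(instruction: str, solution: str) -> str:
    # An "Instruction:" block followed by a "Solution:" block,
    # exactly as in the code examples below.
    return f"Instruction:\n{instruction}\n\nSolution:\n{solution}\n"

print(build_prompt("how many r in strawberry",
                   'There are **three** "r"s in "strawberry."'))
```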

# Code Example

- Using transformers:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the tokenizer and model
model_name = "secemp9/TraceBack-12b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Move the model to the desired device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# Define the messages
messages = [
    {"role": "user", "content": """Instruction:
how many r in strawberry


Solution:
There are **three** "r"s in "strawberry."
"""}
]

# Step 1: Apply chat template to get formatted text as a string
formatted_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Step 2: Tokenize the formatted text into a dictionary of tensors
inputs = tokenizer(formatted_text, return_tensors="pt").to(device)

# Generate the response
outputs = model.generate(**inputs, max_new_tokens=32000)

# Decode and print the output
generated_text = tokenizer.decode(outputs[0])
print(generated_text)
```
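
For long reasoning traces it can be nicer to watch tokens as they are produced. A hedged variant of the `generate` call above using transformers' built-in `TextStreamer` (everything else stays the same):

```python
from transformers import TextStreamer

# Stream decoded tokens to stdout as they are generated;
# skip_prompt=True hides the echoed input prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)
outputs = model.generate(**inputs, max_new_tokens=32000, streamer=streamer)
```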

- unsloth
```python
from unsloth import FastLanguageModel

# Load the model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained("secemp9/TraceBack-12b")

# Switch Unsloth into inference mode for faster generation
FastLanguageModel.for_inference(model)

# Define the messages (the instruction and its solution go in the user turn)
messages = [
    {"role": "user", "content": """Instruction:
how many r in strawberry


Solution:
There are **three** "r"s in "strawberry."
"""}
]

# Step 1: Apply chat template to get formatted text as a string
formatted_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Step 2: Tokenize the formatted text into a dictionary of tensors
inputs = tokenizer(formatted_text, return_tensors="pt").to(model.device)

# Generate the response
outputs = model.generate(**inputs, max_new_tokens=32000)

# Decode and print the output
generated_text = tokenizer.decode(outputs[0])
print(generated_text)
```
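
If VRAM is tight, Unsloth can also load the model in 4-bit. A minimal sketch; the `max_seq_length` value here is an assumption, pick one that fits your traces:

```python
model, tokenizer = FastLanguageModel.from_pretrained(
    "secemp9/TraceBack-12b",
    max_seq_length=32768,  # assumed context budget for long reasoning traces
    load_in_4bit=True,     # 4-bit quantized weights to cut memory usage
)
FastLanguageModel.for_inference(model)
```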
# Axolotl config

For this, I basically converted my Unsloth training code to an Axolotl config file, and used DeepSpeed for multi-GPU training. Configuration below:

config.yml
```yaml
# Base model configuration
base_model: unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit
load_in_4bit: true

# Dataset configuration
datasets:
  - path: instruction_solution_to_thought_dataset.jsonl
    type: chat_template

# Chat template
chat_template: chatml

# LoRA adapter configuration
adapter: lora
lora_r: 16
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - gate_proj
  - up_proj
  - down_proj

# Training hyperparameters
max_seq_length: 128000
micro_batch_size: 2
gradient_accumulation_steps: 8
learning_rate: 3e-5
num_epochs: 3
warmup_steps: 100
optimizer: adamw_8bit
weight_decay: 0.01
lr_scheduler_type: cosine
max_grad_norm: 1.0
output_dir: ./outputs_solution_to_thought
seed: 3407
merge_lora: true
hf_upload: true
hf_repo: secemp9/TraceBack-12b
xformers_attention:
flash_attention: True
bf16: true          # Enable BF16 mixed precision
# Multi-GPU training with DeepSpeed
deepspeed: deepspeed_configs/zero2.json

# Optional: Enable gradient checkpointing
gradient_checkpointing: true
```

deepspeed_configs/zero2.json
```json
{
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  },
  "bf16": {
    "enabled": true
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "weight_decay": "auto",
      "betas": [0.9, 0.999],
      "eps": 1e-8
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto"
    }
  },
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "steps_per_print": 10,
  "wandb": {
    "enabled": true
  }
}
```