Linear probe checkpoints for https://footprints.baulab.info
Paper: https://arxiv.org/abs/2406.20086

To load a Llama-2-7b checkpoint at layer 0 and target index -3:

```python
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download


class LinearModel(nn.Module):
    """A single linear layer mapping a hidden state to vocabulary logits."""

    def __init__(self, input_size, output_size, bias=False):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size, bias=bias)

    def forward(self, x):
        return self.fc(x)


# Example: Llama-2-7b probe at layer 0, predicting the token 3 positions back.
# The probe predicting the next token would be `layer0_tgtidx1.ckpt`.
checkpoint_path = hf_hub_download(
    repo_id="sfeucht/footprints",
    filename="llama-2-7b/layer0_tgtidx-3.ckpt",
)

# The hidden size (model_size) is 4096 for both models.
# vocab_size is 32000 for Llama-2-7b and 128256 for Llama-3-8b.
device = "cuda" if torch.cuda.is_available() else "cpu"
probe = LinearModel(4096, 32000).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
probe.load_state_dict(torch.load(checkpoint_path, map_location=device))
probe.eval()
```
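To probe other layers, target indices, or Llama-3-8b, it may help to factor the loading code into small helpers. The sketch below assumes the filename pattern generalizes as `<model>/layer{L}_tgtidx{T}.ckpt` (only `llama-2-7b/layer0_tgtidx-3.ckpt` and `layer0_tgtidx1.ckpt` are confirmed above) and that the Llama-3-8b folder is named `llama-3-8b`; both are assumptions. The helper names `probe_filename`, `build_probe`, and `top_k_predictions` are hypothetical. The sketch reuses the same `LinearModel` wrapper so the state-dict key stays `fc.weight` and `load_state_dict` still matches the checkpoints.

```python
import torch
import torch.nn as nn

HIDDEN_SIZE = 4096  # model_size for both Llama-2-7b and Llama-3-8b
# Vocabulary sizes from the note above.
# ASSUMPTION: the Llama-3-8b checkpoints live under a "llama-3-8b/" folder.
VOCAB_SIZES = {"llama-2-7b": 32000, "llama-3-8b": 128256}


class LinearModel(nn.Module):
    """Same wrapper as above, so state_dict keys match ("fc.weight")."""

    def __init__(self, input_size: int, output_size: int, bias: bool = False):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def probe_filename(model: str, layer: int, target_idx: int) -> str:
    """Hypothetical helper: build a repo path as <model>/layer{L}_tgtidx{T}.ckpt."""
    return f"{model}/layer{layer}_tgtidx{target_idx}.ckpt"


def build_probe(model: str) -> LinearModel:
    """Hypothetical helper: create an (untrained) probe shaped for `model`."""
    return LinearModel(HIDDEN_SIZE, VOCAB_SIZES[model])


@torch.no_grad()
def top_k_predictions(probe: LinearModel, hidden_states: torch.Tensor, k: int = 5) -> torch.Tensor:
    """Return the k highest-scoring token ids at each position.

    hidden_states: (seq_len, hidden_size) activations from the probed layer.
    Returns a (seq_len, k) tensor of token ids, best first.
    """
    # Match the probe's dtype and device (LLM activations are often fp16).
    logits = probe(hidden_states.to(probe.fc.weight))
    return logits.topk(k, dim=-1).indices
```

In practice, `hidden_states` would be the activations of the probed layer, for example one entry of the `hidden_states` tuple a Hugging Face `transformers` model returns when called with `output_hidden_states=True`. Which tuple index corresponds to the paper's "layer 0" is an assumption to verify against the paper; the returned ids can then be decoded with the matching tokenizer.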