File size: 2,133 Bytes
07816a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac38f7f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, XLMRobertaModel

# カスタムモデルの定義
class CustomXLMRobertaModel(XLMRobertaModel):
    def __init__(self, config):
        super(CustomXLMRobertaModel, self).__init__(config)
        self.sparse_linear = SparseLinear(config.hidden_size, 1)  # 適切な出力次元を設定

    def forward(self, *args, **kwargs):
        outputs = super(CustomXLMRobertaModel, self).forward(*args, **kwargs)
        dense_embeddings = outputs.last_hidden_state
        sparse_embeddings = self.sparse_linear(dense_embeddings)
        return sparse_embeddings

# カスタムレイヤーの定義
class SparseLinear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SparseLinear, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

# モデルとトークナイザーのロード
model_name = "."  # ローカルディレクトリを指定
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoModel.from_pretrained(model_name).config
model = CustomXLMRobertaModel.from_pretrained(model_name, config=config)

# カスタムレイヤーのインスタンスを作成
input_dim = 1024  # Denseベクトルの次元(モデルのhidden_sizeに合わせる)
output_dim = 1    # Sparseベクトルの次元(保存された重みに合わせる)
sparse_linear = SparseLinear(input_dim, output_dim)

# Sparse線形変換のロード
sparse_linear_path = "sparse_linear.pt"
sparse_linear_state_dict = torch.load(sparse_linear_path, weights_only=True)

# state_dictのキーを変換
sparse_linear_state_dict = {
    f"linear.{key}": value for key, value in sparse_linear_state_dict.items()
}

# カスタムレイヤーにstate_dictをロード
sparse_linear.load_state_dict(sparse_linear_state_dict)

# カスタムレイヤーの重みを元のモデルの重みに追加
model.sparse_linear.load_state_dict(sparse_linear.state_dict())

# マージされたモデルの保存
torch.save(model.state_dict(), "merged_pytorch_model.bin")