bge-m3-sparse-experimental / sample-encoding-sparse.py
p0x0q's picture
スパースが得られるように
5760b3d
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, XLMRobertaModel
# カスタムレイヤーの定義
class SparseLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(SparseLinear, self).__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
# カスタムモデルの定義
class CustomXLMRobertaModel(XLMRobertaModel):
def __init__(self, config):
super(CustomXLMRobertaModel, self).__init__(config)
self.sparse_linear = SparseLinear(config.hidden_size, 1) # 適切な出力次元を設定
def forward(self, *args, **kwargs):
outputs = super(CustomXLMRobertaModel, self).forward(*args, **kwargs)
dense_embeddings = outputs.last_hidden_state
sparse_embeddings = self.sparse_linear(dense_embeddings)
return outputs, sparse_embeddings
# モデルとトークナイザーのロード
model_name = "." # ローカルディレクトリを指定
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoModel.from_pretrained(model_name).config
# マージされたモデルのロード
merged_model = CustomXLMRobertaModel.from_pretrained(model_name, config=config)
merged_model.load_state_dict(torch.load("merged_pytorch_model.bin"))
# テキストのエンコード
def encode_text(text):
inputs = tokenizer(text, return_tensors="pt")
outputs, sparse_embeddings = merged_model(**inputs)
return outputs, sparse_embeddings
# テキストのエンコード例
text = "こんにちは"
sparse_embeddings = encode_text(text)
print(sparse_embeddings)