import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer, XLMRobertaModel


class SparseLinear(nn.Module):
    """Single linear layer that maps each token embedding to one sparse weight."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


class CustomXLMRobertaModel(XLMRobertaModel):
    """XLM-RoBERTa backbone extended with a sparse-projection head."""

    def __init__(self, config):
        super().__init__(config)
        self.sparse_linear = SparseLinear(config.hidden_size, 1)

    def forward(self, *args, **kwargs):
        outputs = super().forward(*args, **kwargs)
        dense_embeddings = outputs.last_hidden_state
        # Project every token's hidden state down to a single sparse weight.
        sparse_embeddings = self.sparse_linear(dense_embeddings)
        return sparse_embeddings


# Load the tokenizer, config, and backbone weights from the local model
# directory; the sparse head is randomly initialized at this point.
model_name = "."
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
model = CustomXLMRobertaModel.from_pretrained(model_name, config=config)

# Standalone copy of the sparse head; input_dim must match the backbone's
# config.hidden_size (1024 for XLM-RoBERTa-large).
input_dim = 1024
output_dim = 1
sparse_linear = SparseLinear(input_dim, output_dim)

# Load the sparse head checkpoint, which stores a bare nn.Linear state dict.
sparse_linear_path = "sparse_linear.pt"
sparse_linear_state_dict = torch.load(sparse_linear_path, weights_only=True)

# The checkpoint keys lack the "linear." prefix used by the SparseLinear
# wrapper, so remap e.g. "weight" -> "linear.weight" before loading.
sparse_linear_state_dict = {
    f"linear.{key}": value for key, value in sparse_linear_state_dict.items()
}
sparse_linear.load_state_dict(sparse_linear_state_dict)

# Copy the loaded weights into the model's sparse head, then save the merged
# backbone + sparse head as a single checkpoint.
model.sparse_linear.load_state_dict(sparse_linear.state_dict())
torch.save(model.state_dict(), "merged_pytorch_model.bin")
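
# Optional sanity check (a minimal sketch, not part of the original script):
# reload the merged checkpoint into a fresh instance and confirm that the
# forward pass yields one sparse weight per token. The sample text is arbitrary.
reloaded = CustomXLMRobertaModel.from_pretrained(model_name, config=config)
reloaded.load_state_dict(torch.load("merged_pytorch_model.bin", weights_only=True))
reloaded.eval()

inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    sparse_weights = reloaded(**inputs)
print(sparse_weights.shape)  # (batch_size, sequence_length, 1)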