Description

Experiments with an encoder-decoder model in which the encoder is alephbert-base and the decoder is a pruned mT5-base model.
It could be useful for generating negative and hard-negative samples for text-pair classification.
(For paraphrasing, it is better to use classical approaches rather than this one.)

Usage

git clone https://huggingface.co/imvladikon/alephbert-encoder-t5-decoder
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoModel
from transformers.modeling_outputs import BaseModelOutput
from datasets import load_dataset

# Use the GPU when one is available; otherwise fall back to the CPU so the
# example does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Encoder side: tokenizer and model are read from the `encoder/` directory of
# the cloned repository.
enc_checkpoint = "./alephbert-encoder-t5-decoder/encoder"
enc_tokenizer = AutoTokenizer.from_pretrained(enc_checkpoint)
encoder = AutoModel.from_pretrained(enc_checkpoint).to(device)

# Decoder side: a seq2seq model (with its own tokenizer) read from the
# `decoder/` directory of the cloned repository.
dec_checkpoint = "./alephbert-encoder-t5-decoder/decoder"
dec_tokenizer = AutoTokenizer.from_pretrained(dec_checkpoint)
decoder = AutoModelForSeq2SeqLM.from_pretrained(dec_checkpoint).to(device)


def encode(texts):
    """Embed ``texts`` with the encoder and return L2-normalized vectors.

    The texts are tokenized as one padded batch (truncated to at most 512
    tokens), moved to the encoder's device, and run through the encoder
    without gradient tracking. The encoder's ``pooler_output`` is then
    normalized along dimension 1.

    Args:
        texts: Input accepted by ``enc_tokenizer`` (the usage example passes
            a list of strings).

    Returns:
        The normalized ``pooler_output`` tensor; presumably shaped
        (batch, hidden_size) -- TODO confirm against the encoder config.
    """
    batch = enc_tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors='pt',
    ).to(encoder.device)

    with torch.no_grad():
        pooled = encoder(**batch).pooler_output
        # dim=1 is also the default of `normalize`; spelled out for clarity.
        return torch.nn.functional.normalize(pooled, dim=1)


def decode(embeddings, max_length=256, repetition_penalty=3.0, **kwargs):
    """Generate text from sentence embeddings with the decoder model.

    Each embedding gets an extra axis inserted at position 1 and is wrapped
    in a ``BaseModelOutput`` that is handed to ``decoder.generate`` as
    ``encoder_outputs``.

    Args:
        embeddings: Tensor of embeddings, e.g. the result of ``encode``.
            Assumed to be (batch, hidden_size) so that ``unsqueeze(1)``
            yields a length-1 sequence axis -- TODO confirm.
        max_length: Forwarded to ``decoder.generate``.
        repetition_penalty: Forwarded to ``decoder.generate``.
        **kwargs: Any additional keyword arguments for ``decoder.generate``
            (previously these were accepted but silently ignored).

    Returns:
        A list with one decoded string per generated sequence, with special
        tokens removed.
    """
    encoder_outputs = BaseModelOutput(last_hidden_state=embeddings.unsqueeze(1))
    out = decoder.generate(
        encoder_outputs=encoder_outputs,
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        **kwargs,
    )
    return [dec_tokenizer.decode(tokens, skip_special_tokens=True) for tokens in out]


# Switch both models to evaluation mode before running them.
encoder.eval()
decoder.eval()

# Hebrew sample sentence (a short weather forecast). The text must be real
# Hebrew: the encoder's tokenizer is loaded from a Hebrew (AlephBERT)
# checkpoint, and the previously garbled (mis-decoded) characters would not
# be meaningful input for it.
text = """
מחר יוסיף להיות מעונן חלקית ובמהלך היום יתחזקו הרוחות בדרום הארץ וייתכן אובך באזור.
""".strip()
batch = [text]
embeddings = encode(batch)

# Reuse the `decode` helper instead of repeating the generate/decode calls;
# only `max_length` differs from its defaults (repetition_penalty stays 3.0).
generated_texts = decode(embeddings, max_length=512)

# Print each input sentence followed by the text generated from its embedding.
for source_text, generated in zip(batch, generated_texts):
    print(source_text)
    print(generated)
    print('-----------')
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.