metadata
language:
- en
widget:
- text: >-
generate question: <hl> 42 <hl> is the answer to life, the universe and
everything. </s>
tags:
- seq2seq
license: cc-by-nc-sa-4.0
To use the model with a pipeline:
from transformers import pipeline
triplet_extractor = pipeline(
    'text2text-generation',
    model='Babelscape/rebel-large',
    tokenizer='Babelscape/rebel-large',
)
# Ask the pipeline for raw token ids instead of text and decode them with its
# tokenizer by hand: the special marker tokens must stay in the output so the
# triplets can be parsed from it.
extracted_text = triplet_extractor.tokenizer.decode(
    triplet_extractor(
        "Punta Cana is a resort town in the municipality of Higuey, in La Altagracia Province, the eastern most province of the Dominican Republic",
        return_tensors=True,
        return_text=False,
    )[0]["generated_token_ids"]
)
print(extracted_text)
# Function to parse the generated text and extract the triplets
def extract_triplets(text):
    """Parse REBEL's linearized output into (subject, relation, object) tuples.

    The parser expects text shaped like
    ``<triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation ...]``:
    tokens after ``<triplet>`` form the subject, tokens after ``<subj>`` form
    the object, and tokens after ``<obj>`` form the relation. A new
    ``<triplet>`` starts a new subject; a further ``<subj>`` reuses the
    current subject with a new object.

    Sequence-level tokens ``<s>``, ``</s>`` and ``<pad>`` (present when the
    output is decoded with ``skip_special_tokens=False``) are ignored, even
    when glued to neighbouring tokens such as ``</s><s><triplet>``.

    Args:
        text: the decoded model output.

    Returns:
        A list of ``(subject, relation, object)`` string tuples with
        surrounding whitespace removed. Tokens before the first marker are
        ignored. An incomplete trailing triplet (e.g. from a truncated
        generation) is dropped. Text without triplets yields ``[]``.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    # No marker seen yet: words before the first <triplet> are not collected.
    current = None

    # Replace with a space (not '') so a glued neighbour word is still split off.
    for special in ("<s>", "</s>", "<pad>"):
        text = text.replace(special, " ")

    for token in text.split():
        if token == "<triplet>":
            current = 't'
            # A pending relation means the previous triplet is complete.
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            # Another object for the same subject: flush the finished triplet.
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token

    # Only emit the last triplet if all three parts were actually produced.
    if subject.strip() and relation.strip() and object_.strip():
        triplets.append((subject.strip(), relation.strip(), object_.strip()))
    return triplets
# Turn the decoded model output into triplets and display them.
extracted_triplets = extract_triplets(extracted_text)
print(extracted_triplets)
Or, using the `transformers` library directly (without a pipeline):
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def extract_triplets(text):
    """Parse REBEL's linearized output into (subject, relation, object) tuples.

    The parser expects text shaped like
    ``<triplet> head <subj> tail <obj> relation [<subj> tail <obj> relation ...]``:
    tokens after ``<triplet>`` form the subject, tokens after ``<subj>`` form
    the object, and tokens after ``<obj>`` form the relation. A new
    ``<triplet>`` starts a new subject; a further ``<subj>`` reuses the
    current subject with a new object.

    Sequence-level tokens ``<s>``, ``</s>`` and ``<pad>`` (present when the
    output is decoded with ``skip_special_tokens=False``) are ignored, even
    when glued to neighbouring tokens such as ``</s><s><triplet>``.

    Args:
        text: the decoded model output.

    Returns:
        A list of ``(subject, relation, object)`` string tuples with
        surrounding whitespace removed. Tokens before the first marker are
        ignored. An incomplete trailing triplet (e.g. from a truncated
        generation) is dropped. Text without triplets yields ``[]``.
    """
    triplets = []
    subject, relation, object_ = '', '', ''
    # No marker seen yet: words before the first <triplet> are not collected.
    current = None

    # Replace with a space (not '') so a glued neighbour word is still split off.
    for special in ("<s>", "</s>", "<pad>"):
        text = text.replace(special, " ")

    for token in text.split():
        if token == "<triplet>":
            current = 't'
            # A pending relation means the previous triplet is complete.
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            # Another object for the same subject: flush the finished triplet.
            if relation != '':
                triplets.append((subject.strip(), relation.strip(), object_.strip()))
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token

    # Only emit the last triplet if all three parts were actually produced.
    if subject.strip() and relation.strip() and object_.strip():
        triplets.append((subject.strip(), relation.strip(), object_.strip()))
    return triplets
# Load the REBEL checkpoint ("Babelscape/rebel-large") and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")

# Beam search with 3 beams, returning all 3 hypotheses as separate outputs.
gen_kwargs = dict(
    max_length=256,
    length_penalty=0,
    num_beams=3,
    num_return_sequences=3,
)

# Sentence to extract relation triplets from.
text = 'Punta Cana is a resort town in the municipality of Higüey, in La Altagracia Province, the easternmost province of the Dominican Republic.'

# Tokenize into PyTorch tensors, truncating to at most 256 tokens.
model_inputs = tokenizer(
    text,
    return_tensors='pt',
    max_length=256,
    truncation=True,
    padding=True,
)

# Run generation with the inputs placed on the model's device.
generated_tokens = model.generate(
    model_inputs["input_ids"].to(model.device),
    attention_mask=model_inputs["attention_mask"].to(model.device),
    **gen_kwargs,
)

# Decode without skipping special tokens: the triplet markers must survive
# for extract_triplets to parse them.
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)

# Print the parsed triplets for each returned hypothesis.
for idx, sentence in enumerate(decoded_preds):
    print('Prediction triplets sentence {}'.format(idx))
    print(extract_triplets(sentence))