---
license: mit
pipeline_tag: text-generation
---

# Intro

[Activation Beacon](https://arxiv.org/abs/2401.03462) is a plug-in module for transformer-based LLMs that enables effective, efficient, and flexible compression of long contexts.
# Environment

```
pip install transformers
pip install flash-attn --no-build-isolation
```
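Both packages assume a CUDA-capable GPU, and `flash-attn` builds against the locally installed PyTorch/CUDA toolchain. An optional sanity check (nothing here is specific to this model; these are just the packages installed above) might look like:

```python
import torch
import transformers
import flash_attn  # fails to import if the flash-attn build does not match your torch/CUDA setup

# Confirm versions and that a GPU is visible before loading the model.
print("torch:", torch.__version__, "| cuda available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("flash-attn:", flash_attn.__version__)
```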
# Usage

```python
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/beacon-qwen-2-7b-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
)

model = model.cuda().eval()

with torch.no_grad():
    # short context
    messages = [{"role": "user", "content": "Tell me about yourself."}]
    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
    outputs = model.generate(**inputs, max_new_tokens=50)
    print(f"Input Length: {inputs['input_ids'].shape[1]}")
    print(f"Output: {repr(tokenizer.decode(outputs[0], skip_special_tokens=True))}")

    # reset memory before a new generation task
    model.memory.reset()

    # long context
    with open("infbench.json", encoding="utf-8") as f:
        example = json.load(f)
    messages = [{"role": "user", "content": example["context"]}]
    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
    # greedy decoding; slice off the prompt so only the newly generated tokens remain
    outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
    print("*" * 20)
    print(f"Input Length: {inputs['input_ids'].shape[1]}")
    print(f"Answers: {example['answer']}")
    print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
```
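The long-context example reads `infbench.json`, which is assumed to hold a single example with at least a `context` string and an `answer` field (as in InfiniteBench exports). For a quick smoke test without that dataset, a hypothetical stand-in file could be created like this (real examples are far longer):

```python
import json

# Hypothetical placeholder for infbench.json; the "context"/"answer" keys match what the
# usage script above expects, but the content here is made up for a smoke test only.
example = {
    "context": "The secret number is 42. " + "Filler sentence. " * 2000 + "\nWhat is the secret number?",
    "answer": ["42"],
}
with open("infbench.json", "w", encoding="utf-8") as f:
    json.dump(example, f)
```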
**NOTE**: It is normal to see a warning like `This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (32768). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.` You can safely ignore it.
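If you want to silence that reminder, one option (assuming it is emitted through the standard `transformers` logger, which is the usual behavior) is to lower the library's log verbosity before calling `generate`:

```python
from transformers.utils import logging

# Show only errors from transformers; warnings such as the max-length reminder are suppressed.
logging.set_verbosity_error()
```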