---
license: mit
pipeline_tag: text-generation
---

<div align="center">
<h1>UltraGist for Mistral-7B-Instruct-v0.2</h1>

[<a href="https://arxiv.org/abs/2405.16635">Paper</a>] [<a href="https://github.com/namespace-Pt/UltraGist">Github</a>]
</div>

UltraGist is a context compression method that can **flexibly**, **effectively**, and **efficiently** handle various context lengths and compression ratios. We apply UltraGist to [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2).

## Usage

```python
import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/ultragist-mistral-7b-inst"

tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    # load the entire model on the default gpu
    device_map={"": "cuda"},
    # you can manually set the compression ratio; otherwise the model automatically
    # chooses the most suitable compression ratio from [2,4,8,16,32]
    # ultragist_ratio=[8],
).eval()


with torch.no_grad():
    # load a long-context example (article + question)
    with open("data/nqa.json", encoding="utf-8") as f:
        example = json.load(f)
    content = f"Read this article:\n\n{example['context']}\n\nNow, answer the question based on the above context.\nQuestion:\n{example['input']}"
    messages = [{"role": "user", "content": content}]
    inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")

    # reset memory before a new compression task
    model.memory.reset()

    # directly call generate to progressively compress the context while generating next tokens;
    # the prompt tokens are sliced off so only the newly generated tokens are kept
    outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=40)[:, inputs["input_ids"].shape[1]:]
    print("*" * 20)
    print(f"Input size: {inputs['input_ids'].shape[1]}")
    print(f"Question: {example['input']}")
    print(f"Answers: {example['answers']}")
    print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
    print("*" * 20)

    # extract the compressed memory (including the generated tokens)
    compressed_memory = model.memory.get_memory()
    ultragist_size, raw_size, sink_size = model.memory.get_memory_size()
    print(f"UltraGist size: {ultragist_size}")
    print(f"Raw size: {raw_size}")
    print(f"Sink size: {sink_size}")
    print(f"Memory: {compressed_memory[0][0].shape}")
    print("*" * 20)
```
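
If you want to pin the compression ratio instead of letting the model choose one, the `ultragist_ratio` argument shown (commented out) above can be passed at load time. The snippet below is a minimal sketch of that variant, not an official recipe: the `long_text` placeholder and the "Summarize" prompt are made up for illustration, and the comment relating the reported sizes to the chosen ratio is an assumption.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "namespace-Pt/ultragist-mistral-7b-inst"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map={"": "cuda"},
    # fix the compression ratio to 8x instead of letting the model
    # pick one from [2,4,8,16,32] automatically
    ultragist_ratio=[8],
).eval()

with torch.no_grad():
    # placeholder long input; substitute any long document
    long_text = "..."
    messages = [{"role": "user", "content": f"Summarize the following text:\n\n{long_text}"}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to("cuda")

    # reset the memory before compressing a new context
    model.memory.reset()
    outputs = model.generate(**inputs, do_sample=False, max_new_tokens=64)[:, inputs["input_ids"].shape[1]:]
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

    # with a fixed 8x ratio, the UltraGist size is expected to be roughly the raw
    # context length divided by 8 (assumption, not verified here)
    ultragist_size, raw_size, sink_size = model.memory.get_memory_size()
    print(ultragist_size, raw_size, sink_size)
```

If `ultragist_ratio` is omitted, the automatic selection shown in the main example above remains the default behavior.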