---
license: apache-2.0
---

This model is continually pre-trained from meta-llama/Meta-Llama-3-8B with the memory-augmented architecture proposed in MemoryLLM.
We equip Llama-3 with 12,800 memory tokens in each layer, yielding a memory pool of 1.67B parameters.
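
As a quick sanity check on that figure (assuming the published Meta-Llama-3-8B configuration of 32 hidden layers and a hidden size of 4096, which are not stated on this card):

```python
# Memory pool size = memory tokens per layer x number of layers x hidden size.
# The 32 layers and hidden size of 4096 are the standard Llama-3-8B config values.
num_layers = 32
hidden_size = 4096
memory_tokens_per_layer = 12800

memory_pool_params = num_layers * memory_tokens_per_layer * hidden_size
print(f"{memory_pool_params:,}")  # 1,677,721,600, i.e. ~1.67B parameters
```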

To use the model, first clone the MemoryLLM repository:

```shell
git clone git@github.com:wangyu-ustc/MemoryLLM.git
cd MemoryLLM
```

Then simply use the following code to load the model:

```python
from modeling_memoryllm import MemoryLLM
from configuration_memoryllm import MemoryLLMConfig
from transformers import AutoTokenizer

model = MemoryLLM.from_pretrained("YuWangX/memoryllm-8b-chat")
tokenizer = AutoTokenizer.from_pretrained("YuWangX/memoryllm-8b-chat")
```
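
If GPU memory is tight, half precision helps fit the 8B backbone plus the memory pool on a single card. This is a minimal sketch: `torch_dtype` is the standard `transformers` `from_pretrained` keyword, and whether MemoryLLM forwards it unchanged is an assumption, as the card does not document loading options:

```python
import torch

# Assumption: MemoryLLM builds on transformers' PreTrainedModel, so the usual
# from_pretrained keyword arguments (such as torch_dtype) apply here as well.
model = MemoryLLM.from_pretrained(
    "YuWangX/memoryllm-8b-chat",
    torch_dtype=torch.bfloat16,  # half precision for the 8B backbone + 1.67B memory pool
)
```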

## How to use the model

Inject a piece of context into the model using the following script:

```python
model = model.cuda()

# Self-Update with the new context
ctx = "David likes eating apples."
model.inject_memory(
    tokenizer(ctx, return_tensors='pt', add_special_tokens=False).input_ids.cuda(),
    update_memory=True,
)
```
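
Further contexts can be injected in the same way. The sketch below assumes that successive `inject_memory` calls with `update_memory=True` keep folding new knowledge into the same memory pool, which is the behaviour the MemoryLLM design describes:

```python
# Assumption: each inject_memory(..., update_memory=True) call compresses the
# new context into the existing memory pool alongside earlier injections.
more_contexts = ["David works in a bakery.", "David lives in Boston."]
for ctx in more_contexts:
    model.inject_memory(
        tokenizer(ctx, return_tensors='pt', add_special_tokens=False).input_ids.cuda(),
        update_memory=True,
    )
```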

# Generation
messages = [{
    'role': 'user', "content": "What fruits does David like?",
}]

inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
inputs = inputs[:, 1:] # remove bos token

outputs = model.generate(input_ids=inputs.cuda(),
                         max_new_tokens=20)
response = tokenizer.decode(outputs[0])

outputs = model.generate(inputs=input_ids.cuda(), attention_mask=attention_mask.cuda(), max_new_tokens=10)
print(tokenizer.decode(outputs[0]))