---
license: apache-2.0
---
This model is continually pre-trained from meta-llama/Meta-Llama-3-8B with the memory-augmented structure proposed in MemoryLLM.
We equip each layer of Llama-3 with 12,800 memory tokens, which together form a memory pool of about 1.67B parameters.
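As a rough sanity check on that figure (assuming the standard Llama-3-8B backbone with 32 layers and hidden size 4096, which are not stated above), each memory token is one hidden-size vector per layer:

```python
# Back-of-the-envelope size of the memory pool.
# Assumption: Llama-3-8B backbone with 32 layers and hidden size 4096.
num_layers = 32
hidden_size = 4096
memory_tokens_per_layer = 12800

memory_pool_params = num_layers * memory_tokens_per_layer * hidden_size
print(f"{memory_pool_params / 1e9:.2f}B")  # ~1.68B, consistent with the ~1.67B figure above
```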
To use the model, first clone the MemoryLLM repository:
```shell
git clone git@github.com:wangyu-ustc/MemoryLLM.git
cd MemoryLLM
```
Then, from inside the cloned repository (which provides `modeling_memoryllm.py` and `configuration_memoryllm.py`), use the following code to load the model:
```python
from modeling_memoryllm import MemoryLLM
from configuration_memoryllm import MemoryLLMConfig
from transformers import AutoTokenizer

model = MemoryLLM.from_pretrained("YuWangX/memoryllm-8b-chat")
tokenizer = AutoTokenizer.from_pretrained("YuWangX/memoryllm-8b-chat")
```
## How to use the model
Inject a piece of context into the model, then query it, using the following script:
```python
model = model.cuda()

# Self-update the memory pool with the new context
ctx = "David likes eating apples."
model.inject_memory(
    tokenizer(ctx, return_tensors='pt', add_special_tokens=False).input_ids.cuda(),
    update_memory=True,
)

# Generation conditioned on the injected memory
messages = [
    {"role": "user", "content": "What fruits does David like?"},
]
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
inputs = inputs[:, 1:]  # remove the bos token
outputs = model.generate(input_ids=inputs.cuda(), max_new_tokens=20)
response = tokenizer.decode(outputs[0])
print(response)
```
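Because `inject_memory(..., update_memory=True)` updates the persistent memory pool in place, you can inject several contexts one after another and query them later. A minimal sketch using only the calls shown above (the extra context and question here are illustrative, not from the original card):

```python
# Inject several contexts sequentially; each call with update_memory=True
# writes the new context into the persistent memory pool.
contexts = [
    "David likes eating apples.",
    "Maria moved to Berlin last year.",  # illustrative example context
]
for ctx in contexts:
    input_ids = tokenizer(ctx, return_tensors='pt', add_special_tokens=False).input_ids.cuda()
    model.inject_memory(input_ids, update_memory=True)

# Query information from an earlier injection
messages = [{"role": "user", "content": "Where does Maria live?"}]
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
inputs = inputs[:, 1:]  # remove the bos token, as above
outputs = model.generate(input_ids=inputs.cuda(), max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```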