YuWangX committed
Commit 3ab81fe · verified · 1 Parent(s): bac5710

Update README.md

Files changed (1):
  1. README.md +49 -3

README.md CHANGED
@@ -1,3 +1,49 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ ---
+
+
+ This model is continually pre-trained from [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) with the structure proposed in [MemoryLLM](https://arxiv.org/abs/2402.04624).
+
+ To use the model, first clone the repository:
+ ```bash
+ git clone git@github.com:wangyu-ustc/MemoryLLM.git
+ cd MemoryLLM
+ ```
+ Then simply use the following code to load the model:
+ ```python
+ from modeling_memoryllm import MemoryLLM
+ from configuration_memoryllm import MemoryLLMConfig
+ from transformers import AutoTokenizer
+
+ model = MemoryLLM.from_pretrained("YuWangX/memoryllm-8b-chat")
+ tokenizer = AutoTokenizer.from_pretrained("YuWangX/memoryllm-8b-chat")
+ ```
+
+ ### How to use the model
+
+ Inject a piece of context into the model using the following script:
+ ```python
+ model = model.cuda()
+
+ # Self-Update with the new context
+ ctx = "David likes eating apples."
+ model.inject_memory(tokenizer(ctx, return_tensors='pt', add_special_tokens=False).input_ids.cuda(), update_memory=True)
+
+ # Generation
+ messages = [{
+     "role": "user", "content": "What fruits does David like?",
+ }]
+
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
+ inputs = inputs[:, 1:] # remove bos token
+
+ outputs = model.generate(input_ids=inputs.cuda(), max_new_tokens=20)
+ response = tokenizer.decode(outputs[0])
+ print(response)
+ ```
+
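+ Below is a minimal sketch of injecting several contexts in a row, assuming `inject_memory(..., update_memory=True)` can be called repeatedly so that later questions can draw on all previously injected contexts (the extra context and question here are illustrative):
+ ```python
+ # Sketch: repeated self-updates, using only the calls shown above.
+ contexts = [
+     "David likes eating apples.",
+     "Alice works as a nurse in Seattle.",  # illustrative context
+ ]
+ for ctx in contexts:
+     input_ids = tokenizer(ctx, return_tensors='pt', add_special_tokens=False).input_ids.cuda()
+     model.inject_memory(input_ids, update_memory=True)  # assumed to accumulate each context into the memory
+
+ messages = [{"role": "user", "content": "Where does Alice work?"}]
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
+ inputs = inputs[:, 1:]  # remove bos token, as above
+ outputs = model.generate(input_ids=inputs.cuda(), max_new_tokens=20)
+ print(tokenizer.decode(outputs[0]))
+ ```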