File size: 1,795 Bytes
fdad0f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Due to licensing restrictions on LLaMA, you need the original LLaMA-7B model to use this model.
# To decrypt the model weights, obtain the original LLaMA-7B checkpoint (not the Hugging Face version) and run:
# python decrypt.py [path-to-consolidated.00.pth] [path-to-our-model-folder]

import os
import sys
import glob
import numpy as np

def xor_files(seed_path, input_path, output_path, buffer_size=16*1024*1024):
    """Decrypt ``input_path`` into ``output_path`` by XOR-ing it with a seed file.

    The first ``buffer_size`` bytes of ``seed_path`` are used as the key. The
    input is read in chunks of ``buffer_size`` bytes. Every chunk is XOR-ed
    with the leading bytes of that same key.

    The result is first written to ``<output_path>.part``. It is renamed to
    ``output_path`` only after the whole input has been processed. An
    interrupted or failed run therefore never leaves a truncated
    ``output_path`` behind, which the "already decrypted" check would
    otherwise skip on every later run.

    Args:
        seed_path: File whose leading bytes are used as the XOR key.
        input_path: Encrypted file to read.
        output_path: Destination for the decrypted data. If it already exists,
            the function prints a message and does nothing.
        buffer_size: Chunk size in bytes; also the number of seed bytes read.

    Raises:
        ValueError: If the seed provides fewer bytes than an input chunk needs.
    """
    # Check if output file exists
    if os.path.exists(output_path):
        print('Skipping already decrypted file: ' + output_path)
        return
    print('Decrypting: ', input_path, ' to ', output_path)
    with open(seed_path, "rb") as seed_file:
        # Only the first buffer_size bytes of the seed are ever used.
        seed = np.frombuffer(seed_file.read(buffer_size), dtype=np.uint8)

    tmp_path = output_path + ".part"
    try:
        with open(input_path, "rb") as input_file, open(tmp_path, "wb") as output_file:
            while True:
                chunk = input_file.read(buffer_size)
                if not chunk:
                    break
                if len(chunk) > len(seed):
                    # Without this check numpy either raises an opaque
                    # broadcasting error or, for a 0/1-byte seed, silently
                    # broadcasts and writes wrong/empty data.
                    raise ValueError(
                        "Seed file %r is too short: need %d bytes, got %d"
                        % (seed_path, len(chunk), len(seed))
                    )
                data = np.frombuffer(chunk, dtype=np.uint8)
                output_file.write((data ^ seed[:len(chunk)]).tobytes())
        # Publish the fully decrypted file only once it is complete.
        os.replace(tmp_path, output_path)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file so a later run
        # retries this input instead of skipping it.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def main(seed_path, folder_path):
    """Decrypt every ``*.enc`` file in ``folder_path`` using ``seed_path`` as the key.

    Each ``name.enc`` is decrypted to ``name`` (``.enc`` suffix removed) in the
    same folder via ``xor_files``.

    Args:
        seed_path: Path to the original LLaMA-7B ``consolidated.00.pth`` file.
        folder_path: Folder containing the encrypted ``*.enc`` files.
    """
    # glob's result order depends on the filesystem; sort for a deterministic,
    # reproducible processing order.
    enc_files = sorted(glob.glob(os.path.join(folder_path, "*.enc")))
    if not enc_files:
        # Most likely a wrong folder path; say so instead of exiting silently.
        print('No .enc files found in: ' + folder_path)
        return

    for enc_file in enc_files:
        output_file = os.path.splitext(enc_file)[0]
        xor_files(seed_path, enc_file, output_file)

if __name__ == "__main__":
    # The script takes exactly two positional arguments: the seed checkpoint
    # and the folder holding the encrypted model files.
    if len(sys.argv) == 3:
        seed_path, folder_path = sys.argv[1], sys.argv[2]
        main(seed_path, folder_path)
    else:
        print("Usage: python decrypt.py <path-to-llama-7b-consolidated.00.pth-file> <our-model-folder>")
        sys.exit(1)