# Retrieved from a Hugging Face file view: "decrypt.py" (1.8 kB), uploaded by ff670
# in commit fdad0f3 ("Upload decrypt.py").
# Due to licensing restrictions on LLaMA, you need the original LLaMA-7B model to use this model.
# To decrypt the model weights, obtain the original LLaMA-7B checkpoint (not the Hugging Face version) and run:
#   python decrypt.py [path-to-consolidated.00.pth] [path-to-our-model-folder]
# Each *.enc file in the model folder is decrypted next to itself, with the .enc extension removed.
import os
import sys
import glob
import numpy as np
def xor_files(seed_path, input_path, output_path, buffer_size=16*1024*1024):
    """XOR-decrypt ``input_path`` into ``output_path`` with a key taken from ``seed_path``.

    The key is the first ``buffer_size`` bytes of the seed file. The input is
    processed in ``buffer_size`` chunks, and every chunk is XORed with the start
    of that same key. In other words, the key repeats with period ``buffer_size``,
    so ``buffer_size`` must match the value used when the file was encrypted.

    If ``output_path`` already exists, the file is skipped. Decrypted bytes are
    first written to ``output_path + ".part"``. That file is renamed to
    ``output_path`` only after the whole input has been processed, so a failed
    or interrupted run never leaves a truncated file that a later run would
    mistake for a finished one.

    Args:
        seed_path: File whose leading bytes form the XOR key.
        input_path: Encrypted input file.
        output_path: Destination for the decrypted bytes.
        buffer_size: Chunk size in bytes, which is also the key length
            (default 16 MiB).

    Raises:
        ValueError: If the seed file is shorter than an input chunk that must
            be decrypted.
        OSError: On I/O errors opening or reading the inputs, or writing the output.
    """
    # Check if output file exists
    if os.path.exists(output_path):
        print('Skipping already decrypted file: ' + output_path)
        return
    print('Decrypting: ', input_path, ' to ', output_path)

    with open(seed_path, "rb") as seed_file:
        # Only the first buffer_size bytes of the seed are used as the key.
        key = np.frombuffer(seed_file.read(buffer_size), dtype=np.uint8)

    tmp_path = os.fspath(output_path) + ".part"
    try:
        with open(input_path, "rb") as input_file, open(tmp_path, "wb") as output_file:
            while True:
                chunk = input_file.read(buffer_size)
                if not chunk:
                    break
                chunk_len = len(chunk)
                if chunk_len > len(key):
                    # Without this check numpy fails with an opaque broadcast error.
                    raise ValueError(
                        "seed file %r provides only %d key bytes, but %d are needed "
                        "to decrypt %r" % (seed_path, len(key), chunk_len, input_path)
                    )
                # Every chunk restarts at key offset 0 (key period == buffer_size).
                plain = np.frombuffer(chunk, dtype=np.uint8) ^ key[:chunk_len]
                output_file.write(plain.tobytes())
        # Publish the result only once it is complete; os.replace overwrites any
        # existing destination (atomically on POSIX).
        os.replace(tmp_path, output_path)
    except BaseException:
        # Best-effort cleanup so no partial output survives; the original
        # error is always re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def main(seed_path, folder_path):
    """Decrypt every ``*.enc`` file found directly inside ``folder_path``.

    Each ``name.enc`` is decrypted with ``xor_files`` into ``name`` (the ``.enc``
    extension is stripped) in the same folder, using ``seed_path`` as the key
    source. Files are processed in sorted order so runs are reproducible. If
    nothing matches, a message is printed; otherwise a wrong folder path would
    silently do nothing.

    Args:
        seed_path: Path to the file whose leading bytes form the XOR key.
        folder_path: Folder searched (non-recursively) for ``*.enc`` files.
    """
    enc_files = sorted(glob.glob(os.path.join(folder_path, "*.enc")))
    if not enc_files:
        print('No *.enc files found in:', folder_path)
        return
    for enc_file in enc_files:
        # "model.bin.enc" -> "model.bin"
        output_file = os.path.splitext(enc_file)[0]
        xor_files(seed_path, enc_file, output_file)
if __name__ == "__main__":
    if len(sys.argv) != 3:
        # Usage errors go to stderr so they don't mix with normal output.
        print("Usage: python decrypt.py <path-to-llama-7b-consolidated.00.pth-file> <our-model-folder>", file=sys.stderr)
        sys.exit(1)
    seed_path = sys.argv[1]
    folder_path = sys.argv[2]
    # Fail fast with a clear message rather than a traceback partway through
    # (or a silent no-op when the folder path is wrong).
    if not os.path.isfile(seed_path):
        print("Seed file not found: " + seed_path, file=sys.stderr)
        sys.exit(1)
    if not os.path.isdir(folder_path):
        print("Model folder not found: " + folder_path, file=sys.stderr)
        sys.exit(1)
    main(seed_path, folder_path)