import re
import time
import random
import json
from pathlib import Path

import torch
import requests
from safetensors.torch import save_file

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)
from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler,
)
from exl2_wrapper import ExLlamaV2ModuleWrapper

### START Settings

template = (
    '<|start_header_id|>system<|end_header_id|>\n\n'
    'You are a helpful AI assistant.<|eot_id|>\n'
    '<|start_header_id|>user<|end_header_id|>\n\n'
    '{instruction}<|eot_id|>\n'
    '<|start_header_id|>assistant<|end_header_id|>\n\n'
)
model_dir = '/path/to/Meta-Llama-3-8B-Instruct'
harmful_prompts_url = 'ADD_URL_HERE'
harmless_prompts_url = 'https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json'

### END Settings

torch.cuda._lazy_init()
torch.set_printoptions(precision = 5, sci_mode = False, linewidth = 150)

config = ExLlamaV2Config()
config.model_dir = model_dir
config.prepare()
config.max_seq_len = 2048

model = ExLlamaV2(config)
ExLlamaV2ModuleWrapper.wrap(model, False)
model._residual = []  # Enable residual capture

out_dir = Path(config.model_dir.replace('/', '_'))
out_dir.mkdir(exist_ok = True)
harmful_prompts_file = out_dir / 'harmful_prompts.json'
harmless_prompts_file = out_dir / 'harmless_prompts.json'
refused_residual_file = out_dir / 'refused_residual.pth'
allowed_residual_file = out_dir / 'allowed_residual.pth'
allowed_residual_mean_file = out_dir / 'allowed_residual_mean.pth'
suppress_dir_file = out_dir / 'suppress_dir.safetensors'

refused = []


def get_residual(prompts, num_tokens, silent, max_capture, capture_type):
    # Generate a short completion for each prompt and capture the per-layer
    # residual stream. A completion counts as a refusal if it starts with one
    # of the common refusal phrases matched below. Depending on capture_type,
    # keep residuals for refused prompts only, allowed prompts only, or all.
    global model, tokenizer, settings, refused, generator
    refused = []
    residuals = []
    print(f'Processing {len(prompts)} prompts')
    for idx, prompt in enumerate(prompts):
        if idx and not (idx % 100):
            print('', len(residuals))
        prompt = template.format(instruction = prompt)
        model._residual = []
        out = generator.generate_simple(prompt, settings, num_tokens, completion_only = True)
        refusal = re.match(r'^(I\'m not|I cannot|I can\'t|I\'m sorry|As an A|I apolog|I\'m (unable|really|here)|[1I], as|I must|I understand|It(\'s| is) important|Sorry|The (assistant|AI))', out)
        if capture_type is None or (capture_type == 'refused' and refusal) or (capture_type == 'allowed' and not refusal):
            residuals.append(model._residual[:])
        if refusal:
            refused.append(prompt)
        print('-' if refusal else '+', end = '', flush = True)
        if max_capture and len(residuals) >= max_capture:
            print('\nMax capture reached')
            break
        if not silent:
            print(out)
    if not len(residuals):
        return None
    print(f'\nCaptured {len(residuals)} residual streams')
    # Collate into one tensor per layer, shaped (num_captured, hidden_dim),
    # holding the last-token residual of every captured prompt.
    res = []
    for l in range(len(residuals[0])):
        res.append(torch.cat([t[l][0, -1, :].unsqueeze(0) for t in residuals], dim = 0))
    return res


if not harmful_prompts_file.exists():
    print('Downloading harmful prompts')
    res = requests.get(harmful_prompts_url)
    harmful_prompts = []
    for line in res.iter_lines():
        if line:
            harmful_prompts.append(json.loads(line.decode())['prompt'])
    with harmful_prompts_file.open('w') as f:
        json.dump(harmful_prompts, f)
    print('Done')
else:
    with harmful_prompts_file.open('r') as f:
        harmful_prompts = json.load(f)

print(" -- Loading model...")
t = time.time()
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)
t = time.time() - t
print(f" -- Loaded model in {t:.4f} seconds")

print(" -- Loading tokenizer...")
tokenizer = ExLlamaV2Tokenizer(config)

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0  # Greedy decoding for reproducible refusal detection

generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
with torch.inference_mode():

    if not refused_residual_file.exists():
        print('Building refused residual data')
        refused_residual = get_residual(harmful_prompts, 4, True, 2000, 'refused')
        torch.save(refused_residual, refused_residual_file)
    else:
        print('Loading refusal residual data')
        refused_residual = torch.load(refused_residual_file)
        print('Done')

    allowed_residual_mean = []
    if not allowed_residual_mean_file.exists():
        if not allowed_residual_file.exists():
            print('Building allowed residual data')
            if not harmless_prompts_file.exists():
                print('Downloading harmless prompts')
                res = requests.get(harmless_prompts_url)
                all_prompts = json.loads(res.content.decode('utf8'))
                # Keep only instructions that have no separate input field
                harmless_prompts = [i['instruction'] for i in all_prompts if i['input'] == '']
                with harmless_prompts_file.open('w') as f:
                    json.dump(harmless_prompts, f)
                print('Done')
            else:
                with harmless_prompts_file.open('r') as f:
                    harmless_prompts = json.load(f)
            allowed_residual = get_residual(harmless_prompts, 4, True, 2000, 'allowed')
            torch.save(allowed_residual, allowed_residual_file)
        else:
            print('Loading allowed residual data')
            allowed_residual = torch.load(allowed_residual_file)
            print('Done')
        print('Calculating mean allowed residual')
        for i in range(len(allowed_residual)):
            allowed_residual_mean.append(allowed_residual[i].mean(dim = 0))
        print('Done')
        torch.save(allowed_residual_mean, allowed_residual_mean_file)
    else:
        allowed_residual_mean = torch.load(allowed_residual_mean_file)

    if model._suppress_dir is None:
        model._suppress_dir = []

    # Iteratively estimate a per-layer refusal direction as the difference
    # between the mean refused and mean allowed residuals, fold it into the
    # wrapper's suppression directions, then re-sample harmful prompts to see
    # how many still refuse. Stop once refusals become rare.
    for o in range(6):
        print('Iteration', o)
        for i in range(len(refused_residual)):
            refusal_dir = refused_residual[i].mean(dim = 0) - allowed_residual_mean[i]
            refusal_dir = refusal_dir / refusal_dir.norm() if refusal_dir.norm() > 0.0001 else torch.zeros_like(refusal_dir)
            if len(model._suppress_dir) > i:
                # Running average with the direction from previous iterations
                model._suppress_dir[i] = (model._suppress_dir[i] + refusal_dir) / 2
            else:
                model._suppress_dir.append(refusal_dir)
        # Guard against sampling more prompts than are available
        refused_residual = get_residual(random.sample(harmful_prompts, min(2000, len(harmful_prompts))), 4, True, 50, 'refused')
        if not refused_residual or refused_residual[0].shape[0] < 30:
            break

    save_file({f'_suppress_dir_{layer}': tensor for layer, tensor in enumerate(model._suppress_dir)}, suppress_dir_file)
    torch.cuda.synchronize()
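
# For context, a minimal sketch of the per-layer intervention the wrapper in
# exl2_wrapper is assumed to apply once `model._suppress_dir` is populated.
# This is an assumption about the wrapper's behavior, not code taken from it:
# each hidden state h has its component along the unit refusal direction d
# removed, h' = h - (h . d) d, the usual directional-ablation projection.
# The function name below is illustrative only and is not called here.
def _ablate_direction(h: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    # h: (batch, seq_len, hidden_dim) residual stream at one layer
    # d: (hidden_dim,) unit-norm refusal direction for that layer
    # (h @ d) is the scalar projection per position; subtracting it along d
    # removes the refusal component and leaves the rest of h untouched. A
    # zero d (see torch.zeros_like above) makes this a no-op for that layer.
    return h - (h @ d).unsqueeze(-1) * d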