'
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+ output += f"""
+
+ """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+ output += f"""
+
+ """
+
+ output += "
"
+
+ return output
+
+
+def generate_cai_chat_html(history, name1, name2, style, reset_cache=False):
+ output = f''
+
+ # We use ?name2 and ?time.time() to force the browser to reset caches
+ img_bot = f'
' if Path("cache/pfp_character.png").exists() else ''
+ img_me = f'
' if Path("cache/pfp_me.png").exists() else ''
+
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+ output += f"""
+
+
+ {img_bot}
+
+
+
+ {name2}
+
+
+ {row[1]}
+
+
+
+ """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+ output += f"""
+
+
+ {img_me}
+
+
+
+ {name1}
+
+
+ {row[0]}
+
+
+
+ """
+
+ output += "
"
+ return output
+
+
+def generate_chat_html(history, name1, name2, reset_cache=False):
+ output = f''
+
+ for i, _row in enumerate(history[::-1]):
+ row = [convert_to_markdown(entry) for entry in _row]
+
+ output += f"""
+
+ """
+
+ if len(row[0]) == 0: # don't display empty user messages
+ continue
+
+ output += f"""
+
+ """
+
+ output += "
"
+ return output
+
+
+def chat_html_wrapper(history, name1, name2, mode, style, reset_cache=False):
+ if mode == 'instruct':
+ return generate_instruct_html(history)
+ elif style == 'wpp':
+ return generate_chat_html(history, name1, name2)
+ else:
+ return generate_cai_chat_html(history, name1, name2, style, reset_cache)
diff --git a/modules/llama_attn_hijack.py b/modules/llama_attn_hijack.py
new file mode 100644
index 0000000000000000000000000000000000000000..e953f523d6c54581af1a30deb8b922f85b3e557a
--- /dev/null
+++ b/modules/llama_attn_hijack.py
@@ -0,0 +1,171 @@
+import logging
+import math
+import sys
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import transformers.models.llama.modeling_llama
+
+import modules.shared as shared
+
+if shared.args.xformers:
+ try:
+ import xformers.ops
+ except Exception:
+ logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
+
+
+def hijack_llama_attention():
+ if shared.args.xformers:
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+ logging.info("Replaced attention with xformers_attention")
+ elif shared.args.sdp_attention:
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
+ logging.info("Replaced attention with sdp_attention")
+
+
+def xformers_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # We only apply xformers optimizations if we don't need to output the whole attention matrix
+ if not output_attentions:
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+ else:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
+
+
+def sdp_attention_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # We only apply sdp attention if we don't need to output the whole attention matrix
+ if not output_attentions:
+ attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, attn_weights, past_key_value
diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b0344cd1919cee4b82e709e43768053446ecff
--- /dev/null
+++ b/modules/llamacpp_model.py
@@ -0,0 +1,87 @@
+'''
+Based on
+https://github.com/abetlen/llama-cpp-python
+
+Documentation:
+https://abetlen.github.io/llama-cpp-python/
+'''
+
+import logging
+import re
+
+from llama_cpp import Llama, LlamaCache
+
+from modules import shared
+from modules.callbacks import Iteratorize
+import os
+
+
+class LlamaCppModel:
+ def __init__(self):
+ self.initialized = False
+
+ def __del__(self):
+ self.model.__del__()
+
+ @classmethod
+ def from_pretrained(self, path):
+ result = self()
+
+ cache_capacity = 0
+ if shared.args.cache_capacity is not None:
+ if 'GiB' in shared.args.cache_capacity:
+ cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000 * 1000
+ elif 'MiB' in shared.args.cache_capacity:
+ cache_capacity = int(re.sub('[a-zA-Z]', '', shared.args.cache_capacity)) * 1000 * 1000
+ else:
+ cache_capacity = int(shared.args.cache_capacity)
+
+ logging.info("Cache capacity is " + str(cache_capacity) + " bytes")
+
+ params = {
+ 'model_path': str(path),
+ 'n_ctx': 2048,
+ 'seed': 0,
+ 'n_threads': 8,
+ 'n_batch': shared.args.n_batch,
+ 'use_mmap': not shared.args.no_mmap,
+ 'use_mlock': shared.args.mlock,
+ 'n_gpu_layers': shared.args.n_gpu_layers
+ }
+ self.model = Llama(**params)
+ if cache_capacity > 0:
+ self.model.set_cache(LlamaCache(capacity_bytes=cache_capacity))
+
+ # This is ugly, but the model and the tokenizer are the same object in this library.
+ return result, result
+
+ def encode(self, string):
+ if type(string) is str:
+ string = string.encode()
+ return self.model.tokenize(string)
+
+ def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+ context = context if type(context) is str else context.decode()
+ completion_chunks = self.model.create_completion(
+ prompt=context,
+ max_tokens=token_count,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ repeat_penalty=repetition_penalty,
+ stream=True
+ )
+ output = ""
+ for completion_chunk in completion_chunks:
+ text = completion_chunk['choices'][0]['text']
+ output += text
+ if callback:
+ callback(text)
+ return output
+
+ def generate_with_streaming(self, **kwargs):
+ with Iteratorize(self.generate, kwargs, callback=None) as generator:
+ reply = ''
+ for token in generator:
+ reply += token
+ yield reply
diff --git a/modules/logging_colors.py b/modules/logging_colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c9714f7cd08f88f30335dfc0b7a694879414a68
--- /dev/null
+++ b/modules/logging_colors.py
@@ -0,0 +1,109 @@
+# Copied from https://stackoverflow.com/a/1336640
+
+import logging
+import platform
+
+
+def add_coloring_to_emit_windows(fn):
+ # add methods we need to the class
+ def _out_handle(self):
+ import ctypes
+ return ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
+ out_handle = property(_out_handle)
+
+ def _set_color(self, code):
+ import ctypes
+
+ # Constants from the Windows API
+ self.STD_OUTPUT_HANDLE = -11
+ hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
+ ctypes.windll.kernel32.SetConsoleTextAttribute(hdl, code)
+
+ setattr(logging.StreamHandler, '_set_color', _set_color)
+
+ def new(*args):
+ FOREGROUND_BLUE = 0x0001 # text color contains blue.
+ FOREGROUND_GREEN = 0x0002 # text color contains green.
+ FOREGROUND_RED = 0x0004 # text color contains red.
+ FOREGROUND_INTENSITY = 0x0008 # text color is intensified.
+ FOREGROUND_WHITE = FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED
+ # winbase.h
+ # STD_INPUT_HANDLE = -10
+ # STD_OUTPUT_HANDLE = -11
+ # STD_ERROR_HANDLE = -12
+
+ # wincon.h
+ # FOREGROUND_BLACK = 0x0000
+ FOREGROUND_BLUE = 0x0001
+ FOREGROUND_GREEN = 0x0002
+ # FOREGROUND_CYAN = 0x0003
+ FOREGROUND_RED = 0x0004
+ FOREGROUND_MAGENTA = 0x0005
+ FOREGROUND_YELLOW = 0x0006
+ # FOREGROUND_GREY = 0x0007
+ FOREGROUND_INTENSITY = 0x0008 # foreground color is intensified.
+
+ # BACKGROUND_BLACK = 0x0000
+ # BACKGROUND_BLUE = 0x0010
+ # BACKGROUND_GREEN = 0x0020
+ # BACKGROUND_CYAN = 0x0030
+ # BACKGROUND_RED = 0x0040
+ # BACKGROUND_MAGENTA = 0x0050
+ BACKGROUND_YELLOW = 0x0060
+ # BACKGROUND_GREY = 0x0070
+ BACKGROUND_INTENSITY = 0x0080 # background color is intensified.
+
+ levelno = args[1].levelno
+ if (levelno >= 50):
+ color = BACKGROUND_YELLOW | FOREGROUND_RED | FOREGROUND_INTENSITY | BACKGROUND_INTENSITY
+ elif (levelno >= 40):
+ color = FOREGROUND_RED | FOREGROUND_INTENSITY
+ elif (levelno >= 30):
+ color = FOREGROUND_YELLOW | FOREGROUND_INTENSITY
+ elif (levelno >= 20):
+ color = FOREGROUND_GREEN
+ elif (levelno >= 10):
+ color = FOREGROUND_MAGENTA
+ else:
+ color = FOREGROUND_WHITE
+ args[0]._set_color(color)
+
+ ret = fn(*args)
+ args[0]._set_color(FOREGROUND_WHITE)
+ # print "after"
+ return ret
+ return new
+
+
+def add_coloring_to_emit_ansi(fn):
+ # add methods we need to the class
+ def new(*args):
+ levelno = args[1].levelno
+ if (levelno >= 50):
+ color = '\x1b[31m' # red
+ elif (levelno >= 40):
+ color = '\x1b[31m' # red
+ elif (levelno >= 30):
+ color = '\x1b[33m' # yellow
+ elif (levelno >= 20):
+ color = '\x1b[32m' # green
+ elif (levelno >= 10):
+ color = '\x1b[35m' # pink
+ else:
+ color = '\x1b[0m' # normal
+ args[1].msg = color + args[1].msg + '\x1b[0m' # normal
+ # print "after"
+ return fn(*args)
+ return new
+
+
+if platform.system() == 'Windows':
+ # Windows does not support ANSI escapes and we are using API calls to set the console color
+ logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
+else:
+ # all non-Windows platforms are supporting ANSI escapes so we use them
+ logging.StreamHandler.emit = add_coloring_to_emit_ansi(logging.StreamHandler.emit)
+ # log = logging.getLogger()
+ # log.addFilter(log_filter())
+ # //hdlr = logging.StreamHandler()
+ # //hdlr.setFormatter(formatter())
diff --git a/modules/models.py b/modules/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d00ef8708fb9e3e8544da58b6a32b6926144678
--- /dev/null
+++ b/modules/models.py
@@ -0,0 +1,262 @@
+import gc
+import json
+import logging
+import os
+import re
+import time
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+import transformers
+from accelerate import infer_auto_device_map, init_empty_weights
+from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
+ AutoModelForSeq2SeqLM, AutoTokenizer,
+ BitsAndBytesConfig, LlamaTokenizer)
+
+import modules.shared as shared
+from modules import llama_attn_hijack
+
+transformers.logging.set_verbosity_error()
+
+local_rank = None
+if shared.args.deepspeed:
+ import deepspeed
+ from transformers.deepspeed import (HfDeepSpeedConfig,
+ is_deepspeed_zero3_enabled)
+
+ from modules.deepspeed_parameters import generate_ds_config
+
+ # Distributed setup
+ local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ torch.cuda.set_device(local_rank)
+ deepspeed.init_distributed()
+ ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
+ dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
+
+
+# Some models require special treatment in various parts of the code.
+# This function detects those models
+def find_model_type(model_name):
+ path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+ if not path_to_model.exists():
+ return 'None'
+
+ model_name_lower = model_name.lower()
+ if 'rwkv-' in model_name_lower:
+ return 'rwkv'
+ elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
+ return 'llamacpp'
+ elif re.match('.*ggml.*\.bin', model_name_lower):
+ return 'llamacpp'
+ elif 'chatglm' in model_name_lower:
+ return 'chatglm'
+ elif 'galactica' in model_name_lower:
+ return 'galactica'
+ elif 'llava' in model_name_lower:
+ return 'llava'
+ elif 'oasst' in model_name_lower:
+ return 'oasst'
+ elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
+ return 'gpt4chan'
+ else:
+ config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+ # Not a "catch all", but fairly accurate
+ if config.to_dict().get("is_encoder_decoder", False):
+ return 'HF_seq2seq'
+ else:
+ return 'HF_generic'
+
+
+def load_model(model_name):
+ logging.info(f"Loading {model_name}...")
+ t0 = time.time()
+
+ shared.model_type = find_model_type(model_name)
+ if shared.model_type == 'None':
+ logging.error('The path to the model does not exist. Exiting.')
+ return None, None
+
+ if shared.args.autogptq:
+ load_func = AutoGPTQ_loader
+ elif shared.args.wbits > 0:
+ load_func = GPTQ_loader
+ elif shared.model_type == 'llamacpp':
+ load_func = llamacpp_loader
+ elif shared.model_type == 'rwkv':
+ load_func = RWKV_loader
+ elif shared.args.flexgen:
+ load_func = flexgen_loader
+ else:
+ load_func = huggingface_loader
+
+ output = load_func(model_name)
+ if type(output) is tuple:
+ model, tokenizer = output
+ else:
+ model = output
+ tokenizer = load_tokenizer(model_name, model)
+
+ # Hijack attention with xformers
+ if any((shared.args.xformers, shared.args.sdp_attention)):
+ llama_attn_hijack.hijack_llama_attention()
+
+ logging.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.\n")
+ return model, tokenizer
+
+
+def load_tokenizer(model_name, model):
+ tokenizer = None
+ if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+ tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+ elif type(model) is transformers.LlamaForCausalLM:
+ # Try to load an universal LLaMA tokenizer
+ if shared.model_type not in ['llava', 'oasst']:
+ for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
+ if p.exists():
+ logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
+ tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
+ return tokenizer
+
+ # Otherwise, load it from the model folder and hope that these
+ # are not outdated tokenizer files.
+ tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), clean_up_tokenization_spaces=True)
+ try:
+ tokenizer.eos_token_id = 2
+ tokenizer.bos_token_id = 1
+ tokenizer.pad_token_id = 0
+ except:
+ pass
+ else:
+ path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
+ if path_to_model.exists():
+ tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+
+ return tokenizer
+
+
+
+def flexgen_loader(model_name):
+ from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
+
+ # Initialize environment
+ env = ExecutionEnv.create(shared.args.disk_cache_dir)
+
+ # Offloading policy
+ policy = Policy(1, 1,
+ shared.args.percent[0], shared.args.percent[1],
+ shared.args.percent[2], shared.args.percent[3],
+ shared.args.percent[4], shared.args.percent[5],
+ overlap=True, sep_layer=True, pin_weight=shared.args.pin_weight,
+ cpu_cache_compute=False, attn_sparsity=1.0,
+ compress_weight=shared.args.compress_weight,
+ comp_weight_config=CompressionConfig(
+ num_bits=4, group_size=64,
+ group_dim=0, symmetric=False),
+ compress_cache=False,
+ comp_cache_config=CompressionConfig(
+ num_bits=4, group_size=64,
+ group_dim=2, symmetric=False))
+
+ model = OptLM(f"facebook/{model_name}", env, shared.args.model_dir, policy)
+ return model
+
+
+def RWKV_loader(model_name):
+ from modules.RWKV import RWKVModel, RWKVTokenizer
+
+ model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+ tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))
+ return model, tokenizer
+
+
+def llamacpp_loader(model_name):
+ from modules.llamacpp_model import LlamaCppModel
+
+ path = Path(f'{shared.args.model_dir}/{model_name}')
+ if path.is_file():
+ model_file = path
+ else:
+ model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
+
+ logging.info(f"llama.cpp weights detected: {model_file}\n")
+ model, tokenizer = LlamaCppModel.from_pretrained(model_file)
+ return model, tokenizer
+
+
+def GPTQ_loader(model_name):
+
+ # Monkey patch
+ if shared.args.monkey_patch:
+ logging.warning("Applying the monkey patch for using LoRAs in 4-bit mode. It may cause undefined behavior outside its intended scope.")
+ from modules.monkey_patch_gptq_lora import load_model_llama
+
+ model, _ = load_model_llama(model_name)
+
+ # No monkey patch
+ else:
+ import modules.GPTQ_loader
+
+ model = modules.GPTQ_loader.load_quantized(model_name)
+
+ return model
+
+
+def AutoGPTQ_loader(model_name):
+ import modules.AutoGPTQ_loader
+
+ return modules.AutoGPTQ_loader.load_quantized(model_name)
+
+
+def get_max_memory_dict():
+ max_memory = {}
+
+ return max_memory if len(max_memory) > 0 else None
+
+
+def clear_torch_cache():
+ gc.collect()
+ if not shared.args.cpu:
+ torch.cuda.empty_cache()
+
+
+def unload_model():
+ shared.model = shared.tokenizer = None
+ clear_torch_cache()
+
+
+def reload_model():
+ unload_model()
+ shared.model, shared.tokenizer = load_model(shared.model_name)
+
+
+def load_soft_prompt(name):
+ if name == 'None':
+ shared.soft_prompt = False
+ shared.soft_prompt_tensor = None
+ else:
+ with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
+ zf.extract('tensor.npy')
+ zf.extract('meta.json')
+ j = json.loads(open('meta.json', 'r').read())
+ logging.info(f"\nLoading the softprompt \"{name}\".")
+ for field in j:
+ if field != 'name':
+ if type(j[field]) is list:
+ logging.info(f"{field}: {', '.join(j[field])}")
+ else:
+ logging.info(f"{field}: {j[field]}")
+
+ logging.info()
+ tensor = np.load('tensor.npy')
+ Path('tensor.npy').unlink()
+ Path('meta.json').unlink()
+
+ tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
+ tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
+ shared.soft_prompt = True
+ shared.soft_prompt_tensor = tensor
+
+ return name
diff --git a/modules/monkey_patch_gptq_lora.py b/modules/monkey_patch_gptq_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..a37e790671f513b6a5744cc469424a967a75d43b
--- /dev/null
+++ b/modules/monkey_patch_gptq_lora.py
@@ -0,0 +1,39 @@
+# Copied from https://github.com/johnsmith0031/alpaca_lora_4bit
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path("repositories/alpaca_lora_4bit")))
+
+import autograd_4bit
+from amp_wrapper import AMPWrapper
+from autograd_4bit import (Autograd4bitQuantLinear,
+ load_llama_model_4bit_low_ram)
+from monkeypatch.peft_tuners_lora_monkey_patch import (
+ Linear4bitLt, replace_peft_model_with_gptq_lora_model)
+
+from modules import shared
+from modules.GPTQ_loader import find_quantized_model_file
+
+replace_peft_model_with_gptq_lora_model()
+
+
+def load_model_llama(model_name):
+ config_path = str(Path(f'{shared.args.model_dir}/{model_name}'))
+ model_path = str(find_quantized_model_file(model_name))
+ model, tokenizer = load_llama_model_4bit_low_ram(config_path, model_path, groupsize=shared.args.groupsize, is_v1_model=False)
+ for n, m in model.named_modules():
+ if isinstance(m, Autograd4bitQuantLinear) or isinstance(m, Linear4bitLt):
+ if m.is_v1_model:
+ m.zeros = m.zeros.half()
+ m.scales = m.scales.half()
+ m.bias = m.bias.half()
+
+ autograd_4bit.use_new = True
+ autograd_4bit.auto_switch = True
+
+ model.half()
+ wrapper = AMPWrapper(model)
+ wrapper.apply_generate()
+
+ return model, tokenizer
diff --git a/modules/shared.py b/modules/shared.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f945366cf7aca78b8a2a87b749964d038107f21
--- /dev/null
+++ b/modules/shared.py
@@ -0,0 +1,230 @@
+import argparse
+import logging
+from collections import OrderedDict
+from pathlib import Path
+
+import yaml
+
+model = None
+tokenizer = None
+model_name = "None"
+model_type = None
+lora_names = []
+soft_prompt_tensor = None
+soft_prompt = False
+
+# Chat variables
+history = {'internal': [], 'visible': []}
+character = 'None'
+stop_everything = False
+processing_message = '*Is typing...*'
+
+# UI elements (buttons, sliders, HTML, etc)
+gradio = {}
+
+# For keeping the values of UI elements on page reload
+persistent_interface_state = {}
+
+input_params = [] # Generation input parameters
+reload_inputs = [] # Parameters for reloading the chat interface
+
+# For restarting the interface
+need_restart = False
+
+settings = {
+ 'autoload_model': True,
+ 'max_new_tokens': 200,
+ 'max_new_tokens_min': 1,
+ 'max_new_tokens_max': 2000,
+ 'seed': -1,
+ 'character': 'None',
+ 'name1': 'You',
+ 'name2': 'Assistant',
+ 'context': 'This is a conversation with your Assistant. It is a computer program designed to help you with various tasks such as answering questions, providing recommendations, and helping with decision making. You can ask it anything you want and it will do its best to give you accurate and relevant information.',
+ 'greeting': '',
+ 'turn_template': '',
+ 'custom_stopping_strings': '',
+ 'stop_at_newline': False,
+ 'add_bos_token': True,
+ 'ban_eos_token': False,
+ 'skip_special_tokens': True,
+ 'truncation_length': 2048,
+ 'truncation_length_min': 0,
+ 'truncation_length_max': 8192,
+ 'mode': 'chat',
+ 'chat_style': 'cai-chat',
+ 'instruction_template': 'None',
+ 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
+ 'chat_prompt_size': 2048,
+ 'chat_prompt_size_min': 0,
+ 'chat_prompt_size_max': 2048,
+ 'chat_generation_attempts': 1,
+ 'chat_generation_attempts_min': 1,
+ 'chat_generation_attempts_max': 10,
+ 'default_extensions': [],
+ 'chat_default_extensions': ["gallery"],
+ 'presets': {
+ 'default': 'Default',
+ '.*(alpaca|llama|llava)': "LLaMA-Precise",
+ '.*pygmalion': 'NovelAI-Storywriter',
+ '.*RWKV': 'Naive',
+ '.*moss': 'MOSS',
+ },
+ 'prompts': {
+ 'default': 'QA',
+ '.*(gpt4chan|gpt-4chan|4chan)': 'GPT-4chan',
+ }
+}
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
+
+# Basic settings
+parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
+parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
+parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.')
+parser.add_argument('--model', type=str, help='Name of the model to load by default.')
+parser.add_argument('--lora', type=str, nargs="+", help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.')
+parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
+parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
+parser.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.')
+parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
+parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
+parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
+parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+
+# Accelerate/transformers
+parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
+parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
+parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
+parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
+parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
+parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directory to save the disk cache to. Defaults to "cache".')
+parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
+parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
+parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
+parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
+parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_remote_code=True while loading a model. Necessary for ChatGLM.")
+
+# llama.cpp
+parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
+parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
+parser.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
+parser.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
+parser.add_argument('--cache-capacity', type=str, help='Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
+parser.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.')
+
+# GPTQ
+parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
+parser.add_argument('--model_type', type=str, help='Model type of pre-quantized model. Currently LLaMA, OPT, and GPT-J are supported.')
+parser.add_argument('--groupsize', type=int, default=-1, help='Group size.')
+parser.add_argument('--pre_layer', type=int, nargs="+", help='The number of layers to allocate to the GPU. Setting this parameter enables CPU offloading for 4-bit models. For multi-gpu, write the numbers separated by spaces, eg --pre_layer 30 60.')
+parser.add_argument('--checkpoint', type=str, help='The path to the quantized checkpoint file. If not specified, it will be automatically detected.')
+parser.add_argument('--monkey-patch', action='store_true', help='Apply the monkey patch for using LoRAs with quantized models.')
+parser.add_argument('--quant_attn', action='store_true', help='(triton) Enable quant attention.')
+parser.add_argument('--warmup_autotune', action='store_true', help='(triton) Enable warmup autotune.')
+parser.add_argument('--fused_mlp', action='store_true', help='(triton) Enable fused mlp.')
+
+# AutoGPTQ
+parser.add_argument('--autogptq', action='store_true', help='Use AutoGPTQ for loading quantized models instead of the internal GPTQ loader.')
+parser.add_argument('--triton', action='store_true', help='Use triton.')
+
+# FlexGen
+parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
+parser.add_argument('--percent', type=int, nargs="+", default=[0, 100, 100, 0, 100, 0], help='FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0).')
+parser.add_argument("--compress-weight", action="store_true", help="FlexGen: activate weight compression.")
+parser.add_argument("--pin-weight", type=str2bool, nargs="?", const=True, default=True, help="FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%%).")
+
+# DeepSpeed
+parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
+parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
+parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
+
+# RWKV
+parser.add_argument('--rwkv-strategy', type=str, default=None, help='RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8".')
+parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile the CUDA kernel for better performance.')
+
+# Gradio
+parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
+parser.add_argument('--listen-host', type=str, help='The hostname that the server will use.')
+parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
+parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
+parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
+parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
+
+# API
+parser.add_argument('--api', action='store_true', help='Enable the API extension.')
+parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
+parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
+parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
+
+# Multimodal
+parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
+
+args = parser.parse_args()
+args_defaults = parser.parse_args([])
+
+# Deprecation warnings for parameters that have been renamed
+deprecated_dict = {}
+for k in deprecated_dict:
+ if getattr(args, k) != deprecated_dict[k][1]:
+ logging.warning(f"--{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
+ setattr(args, deprecated_dict[k][0], getattr(args, k))
+
+# Security warnings
+if args.trust_remote_code:
+ logging.warning("trust_remote_code is enabled. This is dangerous.")
+if args.share:
+ logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
+
+
+def add_extension(name):
+ if args.extensions is None:
+ args.extensions = [name]
+ elif 'api' not in args.extensions:
+ args.extensions.append(name)
+
+
+# Activating the API extension
+if args.api or args.public_api:
+ add_extension('api')
+
+# Activating the multimodal extension
+if args.multimodal_pipeline is not None:
+ add_extension('multimodal')
+
+
+def is_chat():
+ return args.chat
+
+
+# Loading model-specific settings
+with Path(f'{args.model_dir}/config.yaml') as p:
+ if p.exists():
+ model_config = yaml.safe_load(open(p, 'r').read())
+ else:
+ model_config = {}
+
+# Applying user-defined model settings
+with Path(f'{args.model_dir}/config-user.yaml') as p:
+ if p.exists():
+ user_config = yaml.safe_load(open(p, 'r').read())
+ for k in user_config:
+ if k in model_config:
+ model_config[k].update(user_config[k])
+ else:
+ model_config[k] = user_config[k]
+
+model_config = OrderedDict(model_config)
diff --git a/modules/text_generation.py b/modules/text_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc5d25dc49c96c02a2f5462c8f8f07adb4c4185
--- /dev/null
+++ b/modules/text_generation.py
@@ -0,0 +1,384 @@
+import ast
+import logging
+import random
+import re
+import time
+import traceback
+
+import numpy as np
+import torch
+import transformers
+
+import modules.shared as shared
+from modules.callbacks import (Iteratorize, Stream,
+ _SentinelTokenStoppingCriteria)
+from modules.extensions import apply_extensions
+from modules.html_generator import generate_4chan_html, generate_basic_html
+from modules.models import clear_torch_cache, local_rank
+
+
+def get_max_prompt_length(state):
+ max_length = state['truncation_length'] - state['max_new_tokens']
+ if shared.soft_prompt:
+ max_length -= shared.soft_prompt_tensor.shape[1]
+
+ return max_length
+
+
+def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
+ if shared.model_type in ['rwkv', 'llamacpp']:
+ input_ids = shared.tokenizer.encode(str(prompt))
+ input_ids = np.array(input_ids).reshape(1, len(input_ids))
+ return input_ids
+ else:
+ input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
+
+ # This is a hack for making replies more creative.
+ if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
+ input_ids = input_ids[:, 1:]
+
+ # Llama adds this extra token when the first character is '\n', and this
+ # compromises the stopping criteria, so we just remove it
+ if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
+ input_ids = input_ids[:, 1:]
+
+ # Handling truncation
+ if truncation_length is not None:
+ input_ids = input_ids[:, -truncation_length:]
+
+ if shared.model_type in ['rwkv', 'llamacpp'] or shared.args.cpu:
+ return input_ids
+ elif shared.args.flexgen:
+ return input_ids.numpy()
+ elif shared.args.deepspeed:
+ return input_ids.to(device=local_rank)
+ elif torch.has_mps:
+ device = torch.device('mps')
+ return input_ids.to(device)
+ else:
+ return input_ids.cuda()
+
+
+def get_encoded_length(prompt):
+ length_after_extensions = apply_extensions('tokenized_length', prompt)
+ if length_after_extensions is not None:
+ return length_after_extensions
+
+ return len(encode(prompt)[0])
+
+
+def decode(output_ids, skip_special_tokens=True):
+ return shared.tokenizer.decode(output_ids, skip_special_tokens)
+
+
+def generate_softprompt_input_tensors(input_ids):
+ inputs_embeds = shared.model.transformer.wte(input_ids)
+ inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
+ filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
+ # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
+ return inputs_embeds, filler_input_ids
+
+
+# Removes empty replies from gpt4chan outputs
+def fix_gpt4chan(s):
+ for i in range(10):
+ s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
+ s = re.sub("--- [0-9]*\n *\n---", "---", s)
+ s = re.sub("--- [0-9]*\n\n\n---", "---", s)
+
+ return s
+
+
+# Fix the LaTeX equations in galactica
+def fix_galactica(s):
+ s = s.replace(r'\[', r'$')
+ s = s.replace(r'\]', r'$')
+ s = s.replace(r'\(', r'$')
+ s = s.replace(r'\)', r'$')
+ s = s.replace(r'$$', r'$')
+ s = re.sub(r'\n', r'\n\n', s)
+ s = re.sub(r"\n{3,}", "\n\n", s)
+ return s
+
+
+def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
+ if shared.model_type == 'HF_seq2seq':
+ reply = decode(output_ids, state['skip_special_tokens'])
+ else:
+ new_tokens = len(output_ids) - len(input_ids[0])
+ reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
+
+ # Prevent LlamaTokenizer from skipping a space
+ if type(shared.tokenizer) is transformers.LlamaTokenizer and len(output_ids) > 0:
+ if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'):
+ reply = ' ' + reply
+
+ if not is_chat:
+ reply = apply_extensions('output', reply)
+
+ return reply
+
+
+def formatted_outputs(reply, model_name):
+ if shared.model_type == 'galactica':
+ reply = fix_galactica(reply)
+ return reply, reply, generate_basic_html(reply)
+ elif shared.model_type == 'gpt4chan':
+ reply = fix_gpt4chan(reply)
+ return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+ else:
+ return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+
+
+def set_manual_seed(seed):
+ seed = int(seed)
+ if seed == -1:
+ seed = random.randint(1, 2**31)
+
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+
+ return seed
+
+
+def stop_everything_event():
+ shared.stop_everything = True
+
+
+def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None):
+ for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False):
+ if shared.model_type not in ['HF_seq2seq']:
+ reply = question + reply
+
+ yield formatted_outputs(reply, shared.model_name)
+
+
+def generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
+ state = apply_extensions('state', state)
+ generate_func = apply_extensions('custom_generate_reply')
+ if generate_func is None:
+ if shared.model_name == 'None' or shared.model is None:
+ logging.error("No model is loaded! Select one in the Model tab.")
+ yield question
+ return
+
+ if shared.model_type in ['rwkv', 'llamacpp']:
+ generate_func = generate_reply_custom
+ elif shared.args.flexgen:
+ generate_func = generate_reply_flexgen
+ else:
+ generate_func = generate_reply_HF
+
+ # Preparing the input
+ original_question = question
+ if not is_chat:
+ question = apply_extensions('input', question)
+
+ if shared.args.verbose:
+ print(f'\n\n{question}\n--------------------\n')
+
+ shared.stop_everything = False
+ clear_torch_cache()
+ seed = set_manual_seed(state['seed'])
+ for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings, is_chat=is_chat):
+ yield reply
+
+
+def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
+ generate_params = {}
+ for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
+ generate_params[k] = state[k]
+
+ if state['ban_eos_token']:
+ generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
+
+ if shared.args.no_cache:
+ generate_params.update({'use_cache': False})
+
+ if shared.args.deepspeed:
+ generate_params.update({'synced_gpus': True})
+
+ # Encode the input
+ input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
+ output = input_ids[0]
+ cuda = not any((shared.args.cpu, shared.args.deepspeed))
+
+ # Find the eos tokens
+ eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+ if eos_token is not None:
+ eos_token_ids.append(int(encode(eos_token)[0][-1]))
+
+ # Add the encoded tokens to generate_params
+ if shared.soft_prompt:
+ inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+ question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
+ original_input_ids = input_ids
+ generate_params.update({'inputs_embeds': inputs_embeds})
+ generate_params.update({'inputs': filler_input_ids})
+ else:
+ question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+ original_input_ids = input_ids
+ generate_params.update({'inputs': input_ids})
+ if inputs_embeds is not None:
+ generate_params.update({'inputs_embeds': inputs_embeds})
+
+ # Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
+ stopping_criteria_list = transformers.StoppingCriteriaList()
+ for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
+ if type(st) is list and len(st) > 0:
+ sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
+ stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
+ break
+
+ # Update generate_params with the eos token and the stopping strings
+ generate_params['eos_token_id'] = eos_token_ids
+ generate_params['stopping_criteria'] = stopping_criteria_list
+
+ t0 = time.time()
+ try:
+ if not is_chat and shared.model_type != 'HF_seq2seq':
+ yield ''
+
+ # Generate the entire reply at once.
+ if not state['stream']:
+ with torch.no_grad():
+ output = shared.model.generate(**generate_params)[0]
+ if cuda:
+ output = output.cuda()
+
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
+
+ # Stream the reply 1 token at a time.
+ # This is based on the trick of using 'stopping_criteria' to create an iterator.
+ else:
+
+ def generate_with_callback(callback=None, **kwargs):
+ kwargs['stopping_criteria'].append(Stream(callback_func=callback))
+ clear_torch_cache()
+ with torch.no_grad():
+ shared.model.generate(**kwargs)
+
+ def generate_with_streaming(**kwargs):
+ return Iteratorize(generate_with_callback, kwargs, callback=None)
+
+ with generate_with_streaming(**generate_params) as generator:
+ for output in generator:
+ if shared.soft_prompt:
+ output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
+
+ yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
+ if output[-1] in eos_token_ids:
+ break
+
+ except Exception:
+ traceback.print_exc()
+ finally:
+ t1 = time.time()
+ original_tokens = len(original_input_ids[0])
+ new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+ return
+
+
+def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
+ seed = set_manual_seed(state['seed'])
+ generate_params = {'token_count': state['max_new_tokens']}
+ for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
+ generate_params[k] = state[k]
+
+ t0 = time.time()
+ try:
+ if not is_chat:
+ yield ''
+
+ if not state['stream']:
+ reply = shared.model.generate(context=question, **generate_params)
+ if not is_chat:
+ reply = apply_extensions('output', reply)
+
+ yield reply
+ else:
+ for reply in shared.model.generate_with_streaming(context=question, **generate_params):
+ if not is_chat:
+ reply = apply_extensions('output', reply)
+
+ yield reply
+
+ except Exception:
+ traceback.print_exc()
+ finally:
+ t1 = time.time()
+ original_tokens = len(encode(original_question)[0])
+ new_tokens = len(encode(original_question + reply)[0]) - original_tokens
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+ return
+
+
+def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
+ generate_params = {}
+ for k in ['max_new_tokens', 'do_sample', 'temperature']:
+ generate_params[k] = state[k]
+
+ if state['stream']:
+ generate_params['max_new_tokens'] = 8
+
+ # Encode the input
+ input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
+ output = input_ids[0]
+
+ # Find the eos tokens
+ eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
+ if eos_token is not None:
+ eos_token_ids.append(int(encode(eos_token)[0][-1]))
+
+ # Add the encoded tokens to generate_params
+ question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+ original_input_ids = input_ids
+ generate_params.update({'inputs': input_ids})
+ if inputs_embeds is not None:
+ generate_params.update({'inputs_embeds': inputs_embeds})
+
+ # Update generate_params with the eos token and the stopping strings
+ generate_params['stop'] = eos_token_ids[-1]
+
+ t0 = time.time()
+ try:
+ if not is_chat:
+ yield ''
+
+ # Generate the entire reply at once.
+ if not state['stream']:
+ with torch.no_grad():
+ output = shared.model.generate(**generate_params)[0]
+
+ yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
+
+ # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
+ else:
+ for i in range(state['max_new_tokens'] // 8 + 1):
+ if shared.stop_everything:
+ break
+
+ clear_torch_cache()
+ with torch.no_grad():
+ output = shared.model.generate(**generate_params)[0]
+
+ if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
+ break
+
+ yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
+ input_ids = np.reshape(output, (1, output.shape[0]))
+ generate_params.update({'inputs': input_ids})
+
+ except Exception:
+ traceback.print_exc()
+ finally:
+ t1 = time.time()
+ original_tokens = len(original_input_ids[0])
+ new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
+ print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
+ return
diff --git a/modules/training.py b/modules/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2410edcccc0d66f87508534d9df5bf7bdb18f3f
--- /dev/null
+++ b/modules/training.py
@@ -0,0 +1,491 @@
+import json
+import logging
+import math
+import sys
+import threading
+import time
+import traceback
+from pathlib import Path
+
+import gradio as gr
+import torch
+import transformers
+from datasets import Dataset, load_dataset
+from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
+ set_peft_model_state_dict)
+
+from modules import shared, ui, utils
+from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
+
+
+# This mapping is from a very recent commit, not yet released.
+# If not available, default to a backup map for some common model types.
+try:
+ from peft.utils.other import \
+ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
+ model_to_lora_modules
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+ MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
+except:
+ standard_modules = ["q_proj", "v_proj"]
+ model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}
+ MODEL_CLASSES = {
+ "LlamaForCausalLM": "llama",
+ "OPTForCausalLM": "opt",
+ "GPTJForCausalLM": "gptj",
+ "GPTNeoXForCausalLM": "gpt_neox"
+ }
+
+WANT_INTERRUPT = False
+
+PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"]
+
+
+def create_train_interface():
+ with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
+ gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
+
+ with gr.Row():
+ lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
+ always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
+ save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a checkpoint of the LoRA will be saved every time this many steps pass.')
+
+ with gr.Row():
+ copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=utils.get_available_loras())
+ ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': utils.get_available_loras()}, 'refresh-button')
+
+ with gr.Row():
+ # TODO: Implement multi-device support.
+ micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.')
+ batch_size = gr.Slider(label='Batch Size', value=128, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.')
+
+ with gr.Row():
+ epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.')
+ learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='Learning rate, in scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.')
+ lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.')
+
+ # TODO: What is the actual maximum rank? Likely distinct per model. This might be better to somehow be on a log scale.
+ lora_rank = gr.Slider(label='LoRA Rank', value=32, minimum=0, maximum=1024, step=4, info='LoRA Rank, or dimension count. Higher values produce a larger file with better control over the model\'s content. Smaller values produce a smaller file with less overall control. Small values like 4 or 8 are great for stylistic guidance, higher values like 128 or 256 are good for teaching content upgrades, extremely high values (1024+) are difficult to train but may improve fine-detail learning for large datasets. Higher ranks also require higher VRAM.')
+ lora_alpha = gr.Slider(label='LoRA Alpha', value=64, minimum=0, maximum=2048, step=4, info='LoRA Alpha. This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.')
+
+ cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=2048, value=256, step=32, info='Cutoff length for text input. Essentially, how long of a line of text to feed in at a time. Higher values require drastically more VRAM.')
+
+ with gr.Tab(label='Formatted Dataset'):
+ with gr.Row():
+ dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
+ ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
+ eval_dataset = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
+ ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'json')}, 'refresh-button')
+ format = gr.Dropdown(choices=utils.get_datasets('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
+ ui.create_refresh_button(format, lambda: None, lambda: {'choices': utils.get_datasets('training/formats', 'json')}, 'refresh-button')
+
+ eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
+
+ with gr.Tab(label="Raw text file"):
+ with gr.Row():
+ raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
+ ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
+
+ with gr.Row():
+ overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
+ newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
+
+ with gr.Accordion(label='Advanced Options', open=False):
+ lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.')
+ warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate will be lower than normal. This helps the trainer prepare the model and precompute statistics to improve the quality of training after the start.')
+ optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
+
+ with gr.Row():
+ higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
+
+ with gr.Row():
+ start_button = gr.Button("Start LoRA Training")
+ stop_button = gr.Button("Interrupt")
+
+ output = gr.Markdown(value="Ready")
+
+ with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
+ with gr.Row():
+ with gr.Column():
+ models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True)
+ evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
+ with gr.Row():
+ stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
+ max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
+
+ with gr.Row():
+ start_current_evaluation = gr.Button("Evaluate loaded model")
+ start_evaluation = gr.Button("Evaluate selected models")
+ stop_evaluation = gr.Button("Interrupt")
+
+ with gr.Column():
+ evaluation_log = gr.Markdown(value='')
+
+ evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
+ save_comments = gr.Button('Save comments')
+
+ # Training events
+ all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer]
+ copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
+ start_button.click(do_train, all_params, output)
+ stop_button.click(do_interrupt, None, None, queue=False)
+ higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
+
+ # Evaluation events. For some reason, the interrupt event
+ # doesn't work with the .then() syntax, so I write them one
+ # by one in this ugly but functional way.
+ ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+ start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+ tmp = gr.State('')
+ start_current_evaluation.click(lambda: ['current model'], None, tmp)
+ ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
+ start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
+
+ stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
+ save_comments.click(
+ save_past_evaluations, evaluation_table, None).then(
+ lambda: "Comments saved.", None, evaluation_log, show_progress=False)
+
+
+def do_interrupt():
+ global WANT_INTERRUPT
+ WANT_INTERRUPT = True
+
+
+def do_copy_params(lora_name: str, *args):
+ f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json"
+ if Path(f_name).is_file():
+ with open(f_name, 'r', encoding='utf-8') as format_file:
+ params: dict[str, str] = json.load(format_file)
+ else:
+ params = {}
+
+ result = list()
+ for i in range(0, len(PARAMETERS)):
+ key = PARAMETERS[i]
+ if key in params:
+ result.append(params[key])
+ else:
+ result.append(args[i])
+
+ return result
+
+
+def change_rank_limit(use_higher_ranks: bool):
+ mult = 2 if use_higher_ranks else 1
+ return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"}
+
+
+def clean_path(base_path: str, path: str):
+ """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
+ # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
+ # Or swap it to a strict whitelist of [a-zA-Z_0-9]
+ path = path.replace('\\', '/').replace('..', '_')
+ if base_path is None:
+ return path
+
+ return f'{Path(base_path).absolute()}/{path}'
+
+
+def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str):
+
+ if shared.args.monkey_patch:
+ from monkeypatch.peft_tuners_lora_monkey_patch import \
+ replace_peft_model_with_gptq_lora_model
+ replace_peft_model_with_gptq_lora_model()
+
+ global WANT_INTERRUPT
+ WANT_INTERRUPT = False
+
+ # == Input validation / processing ==
+ yield "Prepping..."
+ lora_file_path = clean_path(None, lora_name)
+ if lora_file_path.strip() == '':
+ yield "Missing or invalid LoRA file name input."
+ return
+
+ lora_file_path = f"{shared.args.lora_dir}/{lora_file_path}"
+ actual_lr = float(learning_rate)
+ model_type = type(shared.model).__name__
+
+ if model_type in MODEL_CLASSES:
+ model_id = MODEL_CLASSES[model_type]
+ else:
+ model_id = "llama"
+ if model_type == "PeftModelForCausalLM":
+ if len(shared.args.lora_names) > 0:
+ yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
+ else:
+ yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
+ else:
+ yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+ logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
+
+ time.sleep(5)
+
+ if shared.args.wbits > 0 and not shared.args.monkey_patch:
+ yield "LoRA training in 4-bit requires loading with `--monkey-patch`"
+ return
+
+ elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
+ yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
+ logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
+ time.sleep(2) # Give it a moment for the message to show in UI before continuing
+
+ if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
+ yield "Cannot input zeroes."
+ return
+
+ gradient_accumulation_steps = batch_size // micro_batch_size
+ shared.tokenizer.pad_token_id = 0
+ shared.tokenizer.padding_side = "left"
+
+ def tokenize(prompt):
+ result = shared.tokenizer(prompt, truncation=True, max_length=cutoff_len + 1, padding="max_length")
+ return {
+ "input_ids": result["input_ids"][:-1],
+ "attention_mask": result["attention_mask"][:-1],
+ }
+
+ # == Prep the dataset, format, etc ==
+ if raw_text_file not in ['None', '']:
+ logging.info("Loading raw text file dataset...")
+ with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
+ raw_text = file.read()
+
+ tokens = shared.tokenizer.encode(raw_text)
+ del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
+ tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
+ for i in range(1, len(tokens)):
+ tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
+
+ text_chunks = [shared.tokenizer.decode(x) for x in tokens]
+ del tokens
+ if newline_favor_len > 0:
+ text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
+
+ train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
+ del text_chunks
+ eval_data = None
+
+ else:
+ if dataset in ['None', '']:
+ yield "**Missing dataset choice input, cannot continue.**"
+ return
+
+ if format in ['None', '']:
+ yield "**Missing format choice input, cannot continue.**"
+ return
+
+ with open(clean_path('training/formats', f'{format}.json'), 'r', encoding='utf-8') as formatFile:
+ format_data: dict[str, str] = json.load(formatFile)
+
+ def generate_prompt(data_point: dict[str, str]):
+ for options, data in format_data.items():
+ if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)):
+ for key, val in data_point.items():
+ if val is not None:
+ data = data.replace(f'%{key}%', val)
+ return data
+ raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')
+
+ def generate_and_tokenize_prompt(data_point):
+ prompt = generate_prompt(data_point)
+ return tokenize(prompt)
+
+ logging.info("Loading JSON datasets...")
+ data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
+ train_data = data['train'].map(generate_and_tokenize_prompt)
+
+ if eval_dataset == 'None':
+ eval_data = None
+ else:
+ eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
+ eval_data = eval_data['train'].map(generate_and_tokenize_prompt)
+
+ # == Start prepping the model itself ==
+ if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
+ logging.info("Getting model ready...")
+ prepare_model_for_int8_training(shared.model)
+
+ logging.info("Prepping for training...")
+ config = LoraConfig(
+ r=lora_rank,
+ lora_alpha=lora_alpha,
+ target_modules=model_to_lora_modules[model_id],
+ lora_dropout=lora_dropout,
+ bias="none",
+ task_type="CAUSAL_LM"
+ )
+
+ try:
+ logging.info("Creating LoRA model...")
+ lora_model = get_peft_model(shared.model, config)
+ if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
+ logging.info("Loading existing LoRA data...")
+ state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
+ set_peft_model_state_dict(lora_model, state_dict_peft)
+ except:
+ yield traceback.format_exc()
+ return
+
+ if shared.args.monkey_patch:
+ for n, m in lora_model.named_modules():
+ if '4bit' in str(type(m)):
+ if m.is_v1_model:
+ m.zeros = m.zeros.half()
+
+ m.scales = m.scales.half()
+
+ class Tracked():
+ def __init__(self):
+ self.current_steps = 0
+ self.max_steps = 0
+ self.did_save = False
+
+ tracked = Tracked()
+ actual_save_steps = math.ceil(save_steps / gradient_accumulation_steps)
+
+ class Callbacks(transformers.TrainerCallback):
+ def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
+ tracked.current_steps = state.global_step * gradient_accumulation_steps
+ tracked.max_steps = state.max_steps * gradient_accumulation_steps
+ if WANT_INTERRUPT:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+ elif state.global_step > 0 and actual_save_steps > 0 and state.global_step % actual_save_steps == 0:
+ lora_model.save_pretrained(f"{lora_file_path}/checkpoint-{tracked.current_steps}/")
+
+ def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
+ tracked.current_steps += 1
+ if WANT_INTERRUPT:
+ control.should_epoch_stop = True
+ control.should_training_stop = True
+
+ trainer = transformers.Trainer(
+ model=lora_model,
+ train_dataset=train_data,
+ eval_dataset=eval_data,
+ args=transformers.TrainingArguments(
+ per_device_train_batch_size=micro_batch_size,
+ gradient_accumulation_steps=gradient_accumulation_steps,
+ warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
+ num_train_epochs=epochs,
+ learning_rate=actual_lr,
+ fp16=False if shared.args.cpu else True,
+ optim=optimizer,
+ logging_steps=5,
+ evaluation_strategy="steps" if eval_data is not None else "no",
+ eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None,
+ save_strategy="steps" if eval_data is not None else "no",
+ output_dir=lora_file_path,
+ lr_scheduler_type=lr_scheduler_type,
+ load_best_model_at_end=eval_data is not None,
+ # TODO: Enable multi-device support
+ ddp_find_unused_parameters=None,
+ no_cuda=shared.args.cpu
+ ),
+ data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
+ callbacks=list([Callbacks()])
+ )
+
+ lora_model.config.use_cache = False
+
+ if torch.__version__ >= "2" and sys.platform != "win32":
+ lora_model = torch.compile(lora_model)
+
+ # == Save parameters for reuse ==
+ with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file:
+ vars = locals()
+ json.dump({x: vars[x] for x in PARAMETERS}, file)
+
+ # == Main run and monitor loop ==
+ logging.info("Starting training...")
+ yield "Starting..."
+ if WANT_INTERRUPT:
+ yield "Interrupted before start."
+ return
+
+ def threaded_run():
+ trainer.train()
+ # Note: save in the thread in case the gradio thread breaks (eg browser closed)
+ lora_model.save_pretrained(lora_file_path)
+ logging.info("LoRA training run is completed and saved.")
+ tracked.did_save = True
+
+ thread = threading.Thread(target=threaded_run)
+ thread.start()
+ last_step = 0
+ start_time = time.perf_counter()
+
+ while thread.is_alive():
+ time.sleep(0.5)
+ if WANT_INTERRUPT:
+ yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
+
+ elif tracked.current_steps != last_step:
+ last_step = tracked.current_steps
+ time_elapsed = time.perf_counter() - start_time
+ if time_elapsed <= 0:
+ timer_info = ""
+ total_time_estimate = 999
+ else:
+ its = tracked.current_steps / time_elapsed
+ if its > 1:
+ timer_info = f"`{its:.2f}` it/s"
+ else:
+ timer_info = f"`{1.0/its:.2f}` s/it"
+
+ total_time_estimate = (1.0 / its) * (tracked.max_steps)
+
+ yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
+
+ # Saving in the train thread might fail if an error occurs, so save here if so.
+ if not tracked.did_save:
+ logging.info("Training complete, saving...")
+ lora_model.save_pretrained(lora_file_path)
+
+ if WANT_INTERRUPT:
+ logging.info("Training interrupted.")
+ yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
+ else:
+ logging.info("Training complete!")
+ yield f"Done! LoRA saved to `{lora_file_path}`"
+
+
+def split_chunks(arr, step):
+ for i in range(0, len(arr), step):
+ yield arr[i:i + step]
+
+
+def cut_chunk_for_newline(chunk: str, max_length: int):
+ if '\n' not in chunk:
+ return chunk
+
+ first_newline = chunk.index('\n')
+ if first_newline < max_length:
+ chunk = chunk[first_newline + 1:]
+
+ if '\n' not in chunk:
+ return chunk
+
+ last_newline = chunk.rindex('\n')
+ if len(chunk) - last_newline < max_length:
+ chunk = chunk[:last_newline]
+
+ return chunk
+
+
+def format_time(seconds: float):
+ if seconds < 120:
+ return f"`{seconds:.0f}` seconds"
+
+ minutes = seconds / 60
+ if minutes < 120:
+ return f"`{minutes:.0f}` minutes"
+
+ hours = minutes / 60
+ return f"`{hours:.0f}` hours"
diff --git a/modules/ui.py b/modules/ui.py
new file mode 100644
index 0000000000000000000000000000000000000000..29cb0a28e9d35b0aeb8bfb679c9d986b9bc04be1
--- /dev/null
+++ b/modules/ui.py
@@ -0,0 +1,89 @@
+from pathlib import Path
+
+import gradio as gr
+import torch
+
+from modules import shared
+
+with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
+ css = f.read()
+with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
+ chat_css = f.read()
+with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
+ main_js = f.read()
+with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
+ chat_js = f.read()
+
+refresh_symbol = '\U0001f504' # 🔄
+theme = gr.themes.Default(
+ font=['Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'],
+ font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
+).set(
+ border_color_primary='#c5c5d2',
+ button_large_padding='6px 12px',
+ body_text_color_subdued='#484848',
+ background_fill_secondary='#eaeaea'
+)
+
+
+def list_model_elements():
+ elements = []
+ return elements
+
+
+def list_interface_input_elements(chat=False):
+ elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
+ if chat:
+ elements += ['name1', 'name2', 'greeting', 'context', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command']
+
+ elements += list_model_elements()
+ return elements
+
+
+def gather_interface_values(*args):
+ output = {}
+ for i, element in enumerate(shared.input_elements):
+ output[element] = args[i]
+
+ shared.persistent_interface_state = output
+ return output
+
+
+def apply_interface_values(state, use_persistent=False):
+ if use_persistent:
+ state = shared.persistent_interface_state
+
+ elements = list_interface_input_elements(chat=shared.is_chat())
+ if len(state) == 0:
+ return [gr.update() for k in elements] # Dummy, do nothing
+ else:
+ return [state[k] if k in state else gr.update() for k in elements]
+
+
+class ToolButton(gr.Button, gr.components.FormComponent):
+ """Small button with single emoji as text, fits inside gradio forms"""
+
+ def __init__(self, **kwargs):
+ super().__init__(variant="tool", **kwargs)
+
+ def get_block_name(self):
+ return "button"
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ def refresh():
+ refresh_method()
+ args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+ for k, v in args.items():
+ setattr(refresh_component, k, v)
+
+ return gr.update(**(args or {}))
+
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button.click(
+ fn=refresh,
+ inputs=[],
+ outputs=[refresh_component]
+ )
+ return refresh_button
diff --git a/modules/utils.py b/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6722022d89003221980ed89cc9e9a0d5e1d7a429
--- /dev/null
+++ b/modules/utils.py
@@ -0,0 +1,76 @@
+import os
+import re
+from pathlib import Path
+
+from modules import shared
+
+
+def atoi(text):
+ return int(text) if text.isdigit() else text.lower()
+
+
+# Replace multiple string pairs in a string
+def replace_all(text, dic):
+ for i, j in dic.items():
+ text = text.replace(i, j)
+
+ return text
+
+
+def natural_keys(text):
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
+
+
+def get_available_models():
+ if shared.args.flexgen:
+ return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=natural_keys)
+ else:
+ return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json', '.yaml'))], key=natural_keys)
+
+
+def get_available_presets():
+ return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=natural_keys)
+
+
+def get_available_prompts():
+ prompts = []
+ files = set((k.stem for k in Path('prompts').glob('*.txt')))
+ prompts += sorted([k for k in files if re.match('^[0-9]', k)], key=natural_keys, reverse=True)
+ prompts += sorted([k for k in files if re.match('^[^0-9]', k)], key=natural_keys)
+ prompts += ['Instruct-' + k for k in get_available_instruction_templates() if k != 'None']
+ prompts += ['None']
+ return prompts
+
+
+def get_available_characters():
+ paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
+ return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=natural_keys)
+
+
+def get_available_instruction_templates():
+ path = "characters/instruction-following"
+ paths = []
+ if os.path.exists(path):
+ paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
+
+ return ['None'] + sorted(set((k.stem for k in paths)), key=natural_keys)
+
+
+def get_available_extensions():
+ return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=natural_keys)
+
+
+def get_available_softprompts():
+ return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=natural_keys)
+
+
+def get_available_loras():
+ return sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys)
+
+
+def get_datasets(path: str, ext: str):
+ return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
+
+
+def get_available_chat_styles():
+ return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
diff --git a/presets/Contrastive Search.txt b/presets/Contrastive Search.txt
new file mode 100644
index 0000000000000000000000000000000000000000..832bc9caf9b744d9d9c728f88d887f012a56ba3e
--- /dev/null
+++ b/presets/Contrastive Search.txt
@@ -0,0 +1,3 @@
+do_sample=False
+penalty_alpha=0.6
+top_k=4
diff --git a/presets/Debug-deterministic.txt b/presets/Debug-deterministic.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6673b71c8164effc401a486055b7f9a021b2acfb
--- /dev/null
+++ b/presets/Debug-deterministic.txt
@@ -0,0 +1 @@
+do_sample=False
diff --git a/presets/Default.txt b/presets/Default.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d28ce62f0e36d1f7824fe40d6e40018c9d78ea21
--- /dev/null
+++ b/presets/Default.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.5
+top_k=40
+temperature=0.7
+repetition_penalty=1.2
+typical_p=1.0
diff --git a/presets/Kobold-Godlike.txt b/presets/Kobold-Godlike.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0ba5b794b6d0130a1fa1d918bda9a276f7d23367
--- /dev/null
+++ b/presets/Kobold-Godlike.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.5
+top_k=0
+temperature=0.7
+repetition_penalty=1.1
+typical_p=0.19
diff --git a/presets/Kobold-Liminal Drift.txt b/presets/Kobold-Liminal Drift.txt
new file mode 100644
index 0000000000000000000000000000000000000000..be4dd3bd7a70af2d4eb6c847bed6bedee5379dce
--- /dev/null
+++ b/presets/Kobold-Liminal Drift.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=0
+temperature=0.66
+repetition_penalty=1.1
+typical_p=0.6
diff --git a/presets/LLaMA-Precise.txt b/presets/LLaMA-Precise.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8098b390a097fc9438a2a82ec2bdd58adb2a771b
--- /dev/null
+++ b/presets/LLaMA-Precise.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.1
+top_k=40
+temperature=0.7
+repetition_penalty=1.18
+typical_p=1.0
diff --git a/presets/MOSS.txt b/presets/MOSS.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e895e88623ef393abbcfd99ae5e53bc43f468763
--- /dev/null
+++ b/presets/MOSS.txt
@@ -0,0 +1,3 @@
+temperature=0.7
+top_p=0.8
+repetition_penalty=1.02
diff --git a/presets/Naive.txt b/presets/Naive.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aa8c058224c533f4084e230f6bbf77b63d5e81ea
--- /dev/null
+++ b/presets/Naive.txt
@@ -0,0 +1,4 @@
+do_sample=True
+temperature=0.7
+top_p=0.85
+top_k=50
diff --git a/presets/NovelAI-Best Guess.txt b/presets/NovelAI-Best Guess.txt
new file mode 100644
index 0000000000000000000000000000000000000000..db3fa75b2a11d7e29b108177f9894e82d1e52126
--- /dev/null
+++ b/presets/NovelAI-Best Guess.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.9
+top_k=100
+temperature=0.8
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/presets/NovelAI-Decadence.txt b/presets/NovelAI-Decadence.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d3109f3e3f3a021810d171a0b98f615766b57e4b
--- /dev/null
+++ b/presets/NovelAI-Decadence.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=100
+temperature=2
+repetition_penalty=1
+typical_p=0.97
diff --git a/presets/NovelAI-Genesis.txt b/presets/NovelAI-Genesis.txt
new file mode 100644
index 0000000000000000000000000000000000000000..cc7376b3b981a260448a65cd3c00c7b3904308e2
--- /dev/null
+++ b/presets/NovelAI-Genesis.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.98
+top_k=0
+temperature=0.63
+repetition_penalty=1.05
+typical_p=1.0
diff --git a/presets/NovelAI-Lycaenidae.txt b/presets/NovelAI-Lycaenidae.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0134569cef76bc0de6b3dc7885d94d9d9afdfd62
--- /dev/null
+++ b/presets/NovelAI-Lycaenidae.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.85
+top_k=12
+temperature=2
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/presets/NovelAI-Ouroboros.txt b/presets/NovelAI-Ouroboros.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1e944b54e78e1f63bd4bb6f56a717e0fec751c6b
--- /dev/null
+++ b/presets/NovelAI-Ouroboros.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=100
+temperature=1.07
+repetition_penalty=1.05
+typical_p=1.0
diff --git a/presets/NovelAI-Pleasing Results.txt b/presets/NovelAI-Pleasing Results.txt
new file mode 100644
index 0000000000000000000000000000000000000000..330114a25db6d194dbc8689bf5476a81f649cf64
--- /dev/null
+++ b/presets/NovelAI-Pleasing Results.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=1.0
+top_k=0
+temperature=0.44
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/presets/NovelAI-Sphinx Moth.txt b/presets/NovelAI-Sphinx Moth.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bace1e24b5dcc64fdde99097930f41a991e91b8e
--- /dev/null
+++ b/presets/NovelAI-Sphinx Moth.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.18
+top_k=30
+temperature=2.0
+repetition_penalty=1.15
+typical_p=1.0
diff --git a/presets/NovelAI-Storywriter.txt b/presets/NovelAI-Storywriter.txt
new file mode 100644
index 0000000000000000000000000000000000000000..2df5f8181458c642ed4691925ade3d542de5391c
--- /dev/null
+++ b/presets/NovelAI-Storywriter.txt
@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.73
+top_k=0
+temperature=0.72
+repetition_penalty=1.1
+typical_p=1.0
diff --git a/presets/Verbose (Beam Search).txt b/presets/Verbose (Beam Search).txt
new file mode 100644
index 0000000000000000000000000000000000000000..464a4a5f0dda62348fda2cbbba4a98036c744d5c
--- /dev/null
+++ b/presets/Verbose (Beam Search).txt
@@ -0,0 +1,9 @@
+num_beams=10
+min_length=200
+length_penalty=1.4
+no_repeat_ngram_size=2
+early_stopping=True
+temperature=0.7
+top_k=150
+top_p=0.92
+repetition_penalty=4.5
diff --git a/prompts/Alpaca-with-Input.txt b/prompts/Alpaca-with-Input.txt
new file mode 100644
index 0000000000000000000000000000000000000000..56df0e285be9689ab1f8ea698ce748e6d1b02af2
--- /dev/null
+++ b/prompts/Alpaca-with-Input.txt
@@ -0,0 +1,10 @@
+Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+Instruction
+
+### Input:
+Input
+
+### Response:
+
diff --git a/prompts/GPT-4chan.txt b/prompts/GPT-4chan.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1bc8c7f4613f982e3dfa367562a764cf5bd4c73b
--- /dev/null
+++ b/prompts/GPT-4chan.txt
@@ -0,0 +1,6 @@
+-----
+--- 865467536
+Hello, AI frens!
+How are you doing on this fine day?
+--- 865467537
+
diff --git a/prompts/QA.txt b/prompts/QA.txt
new file mode 100644
index 0000000000000000000000000000000000000000..32b0e2350f3c0a7f447dcd1aba11d6ae2247e5a8
--- /dev/null
+++ b/prompts/QA.txt
@@ -0,0 +1,4 @@
+Common sense questions and answers
+
+Question:
+Factual answer:
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..85d12b447199fbc06c3643ff0459c99319281d7b
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+accelerate==0.19.0
+colorama
+datasets
+flexgen==0.1.7
+gradio_client==0.2.5
+gradio==3.31.0
+markdown
+numpy
+pandas
+Pillow>=9.5.0
+pyyaml
+requests
+rwkv==0.7.3
+safetensors==0.3.1
+sentencepiece
+tqdm
+git+https://github.com/huggingface/peft
+transformers==4.29.1
+bitsandbytes==0.38.1
+llama-cpp-python==0.1.50
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/server.py b/server.py
new file mode 100644
index 0000000000000000000000000000000000000000..43eb295954318d4f94b2192d0c5127085b787880
--- /dev/null
+++ b/server.py
@@ -0,0 +1,896 @@
+import logging
+import os
+import requests
+import warnings
+import modules.logging_colors
+
+os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+os.environ['BITSANDBYTES_NOWELCOME'] = '1'
+warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.INFO)
+
+# This is a hack to prevent Gradio from phoning home when it gets imported
+def my_get(url, **kwargs):
+ logging.info('Gradio HTTP request redirected to localhost :)')
+ kwargs.setdefault('allow_redirects', True)
+ return requests.api.request('get', 'http://127.0.0.1/', **kwargs)
+
+
+original_get = requests.get
+requests.get = my_get
+import gradio as gr
+requests.get = original_get
+
+import matplotlib
+matplotlib.use('Agg') # This fixes LaTeX rendering on some systems
+
+import importlib
+import io
+import json
+import math
+import os
+import re
+import sys
+import time
+import traceback
+import zipfile
+from datetime import datetime
+from functools import partial
+from pathlib import Path
+
+import psutil
+import torch
+import yaml
+from PIL import Image
+
+import modules.extensions as extensions_module
+from modules import chat, shared, training, ui, utils
+from modules.extensions import apply_extensions
+from modules.html_generator import chat_html_wrapper
+from modules.LoRA import add_lora_to_model
+from modules.models import load_model, load_soft_prompt, unload_model
+from modules.text_generation import generate_reply_wrapper, get_encoded_length, stop_everything_event
+
+
+def load_model_wrapper(selected_model, autoload=False):
+ if not autoload:
+ yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
+ return
+
+ if selected_model == 'None':
+ yield "No model selected"
+ else:
+ try:
+ yield f"Loading {selected_model}..."
+ shared.model_name = selected_model
+ unload_model()
+ if selected_model != '':
+ shared.model, shared.tokenizer = load_model(shared.model_name)
+
+ yield f"Successfully loaded {selected_model}"
+ except:
+ yield traceback.format_exc()
+
+
+def load_lora_wrapper(selected_loras):
+ yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
+ add_lora_to_model(selected_loras)
+ yield ("Successfuly applied the LoRAs")
+
+
+def load_preset_values(preset_menu, state, return_dict=False):
+ generate_params = {
+ 'do_sample': True,
+ 'temperature': 1,
+ 'top_p': 1,
+ 'typical_p': 1,
+ 'repetition_penalty': 1,
+ 'encoder_repetition_penalty': 1,
+ 'top_k': 50,
+ 'num_beams': 1,
+ 'penalty_alpha': 0,
+ 'min_length': 0,
+ 'length_penalty': 1,
+ 'no_repeat_ngram_size': 0,
+ 'early_stopping': False,
+ }
+ with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
+ preset = infile.read()
+ for i in preset.splitlines():
+ i = i.rstrip(',').strip().split('=')
+ if len(i) == 2 and i[0].strip() != 'tokens':
+ generate_params[i[0].strip()] = eval(i[1].strip())
+ generate_params['temperature'] = min(1.99, generate_params['temperature'])
+
+ if return_dict:
+ return generate_params
+ else:
+ state.update(generate_params)
+ return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
+
+
+def upload_soft_prompt(file):
+ with zipfile.ZipFile(io.BytesIO(file)) as zf:
+ zf.extract('meta.json')
+ j = json.loads(open('meta.json', 'r').read())
+ name = j['name']
+ Path('meta.json').unlink()
+
+ with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
+ f.write(file)
+
+ return name
+
+
+def open_save_prompt():
+ fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
+ return gr.update(value=fname, visible=True), gr.update(visible=False), gr.update(visible=True)
+
+
+def save_prompt(text, fname):
+ if fname != "":
+ with open(Path(f'prompts/{fname}.txt'), 'w', encoding='utf-8') as f:
+ f.write(text)
+
+ message = f"Saved to prompts/{fname}.txt"
+ else:
+ message = "Error: No prompt name given."
+
+ return message, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
+
+
+def load_prompt(fname):
+ if fname in ['None', '']:
+ return ''
+ elif fname.startswith('Instruct-'):
+ fname = re.sub('^Instruct-', '', fname)
+ with open(Path(f'characters/instruction-following/{fname}.yaml'), 'r', encoding='utf-8') as f:
+ data = yaml.safe_load(f)
+ output = ''
+ if 'context' in data:
+ output += data['context']
+
+ replacements = {
+ '<|user|>': data['user'],
+ '<|bot|>': data['bot'],
+ '<|user-message|>': 'Input',
+ }
+
+ output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements)
+ return output.rstrip(' ')
+ else:
+ with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
+ text = f.read()
+ if text[-1] == '\n':
+ text = text[:-1]
+
+ return text
+
+
+def count_tokens(text):
+ tokens = get_encoded_length(text)
+ return f'{tokens} tokens in the input.'
+
+
+def download_model_wrapper(repo_id):
+ try:
+ downloader = importlib.import_module("download-model")
+ repo_id_parts = repo_id.split(":")
+ model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id
+ branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
+ check = False
+
+ yield ("Cleaning up the model/branch names")
+ model, branch = downloader.sanitize_model_and_branch_names(model, branch)
+
+ yield ("Getting the download links from Hugging Face")
+ links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
+
+ yield ("Getting the output folder")
+ output_folder = downloader.get_output_folder(model, branch, is_lora)
+
+ if check:
+ yield ("Checking previously downloaded files")
+ downloader.check_model_files(model, branch, links, sha256, output_folder)
+ else:
+ yield (f"Downloading files to {output_folder}")
+ downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
+ yield ("Done!")
+ except:
+ yield traceback.format_exc()
+
+
+# Update the command-line arguments based on the interface values
+def update_model_parameters(state, initial=False):
+ elements = ui.list_model_elements() # the names of the parameters
+ gpu_memories = []
+
+ for i, element in enumerate(elements):
+ if element not in state:
+ continue
+
+ value = state[element]
+ if element.startswith('gpu_memory'):
+ gpu_memories.append(value)
+ continue
+
+ if initial and vars(shared.args)[element] != vars(shared.args_defaults)[element]:
+ continue
+
+ # Setting null defaults
+ if element in ['wbits', 'groupsize', 'model_type'] and value == 'None':
+ value = vars(shared.args_defaults)[element]
+ elif element in ['cpu_memory'] and value == 0:
+ value = vars(shared.args_defaults)[element]
+
+ # Making some simple conversions
+ if element in ['wbits', 'groupsize', 'pre_layer']:
+ value = int(value)
+ elif element == 'cpu_memory' and value is not None:
+ value = f"{value}MiB"
+
+ if element in ['pre_layer']:
+ value = [value] if value > 0 else None
+
+ setattr(shared.args, element, value)
+
+ found_positive = False
+ for i in gpu_memories:
+ if i > 0:
+ found_positive = True
+ break
+
+ if not (initial and vars(shared.args)['gpu_memory'] != vars(shared.args_defaults)['gpu_memory']):
+ if found_positive:
+ shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
+ else:
+ shared.args.gpu_memory = None
+
+
+def get_model_specific_settings(model):
+ settings = shared.model_config
+ model_settings = {}
+
+ for pat in settings:
+ if re.match(pat.lower(), model.lower()):
+ for k in settings[pat]:
+ model_settings[k] = settings[pat][k]
+
+ return model_settings
+
+
+def load_model_specific_settings(model, state, return_dict=False):
+ model_settings = get_model_specific_settings(model)
+ for k in model_settings:
+ if k in state:
+ state[k] = model_settings[k]
+
+ return state
+
+
+def save_model_settings(model, state):
+ if model == 'None':
+ yield ("Not saving the settings because no model is loaded.")
+ return
+
+ with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
+ if p.exists():
+ user_config = yaml.safe_load(open(p, 'r').read())
+ else:
+ user_config = {}
+
+ model_regex = model + '$' # For exact matches
+ if model_regex not in user_config:
+ user_config[model_regex] = {}
+
+ for k in ui.list_model_elements():
+ user_config[model_regex][k] = state[k]
+
+ with open(p, 'w') as f:
+ f.write(yaml.dump(user_config))
+
+ yield (f"Settings for {model} saved to {p}")
+
+
+def create_model_menus():
+ # Finding the default values for the GPU and CPU memories
+ total_mem = []
+ for i in range(torch.cuda.device_count()):
+ total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
+
+ default_gpu_mem = []
+ if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
+ for i in shared.args.gpu_memory:
+ if 'mib' in i.lower():
+ default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
+ else:
+ default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000)
+ while len(default_gpu_mem) < len(total_mem):
+ default_gpu_mem.append(0)
+
+ total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024))
+ if shared.args.cpu_memory is not None:
+ default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
+ else:
+ default_cpu_mem = 0
+
+
+def create_settings_menus(default_preset):
+
+ generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+ ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button')
+ with gr.Column():
+ shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
+
+ with gr.Row():
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('Custom generation parameters ([click here to view technical documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature', info='Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness.')
+ shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p', info='If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results.')
+ shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k', info='Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results.')
+ shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p', info='If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text.')
+ with gr.Column():
+ shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.')
+ shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.')
+ shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.')
+ shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.')
+ shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
+ with gr.Column():
+ with gr.Box():
+ gr.Markdown('Contrastive search')
+ shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
+
+ gr.Markdown('Beam search (uses a lot of VRAM)')
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
+ shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
+ with gr.Column():
+ shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
+
+ with gr.Box():
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=1, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
+ shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas. For instance: "\\nYour Assistant:", "\\nThe assistant:"')
+ with gr.Column():
+ shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.')
+ shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Disabling this can make the replies more creative.')
+
+ shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
+ shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
+
+ with gr.Accordion('Soft prompt', open=False):
+ with gr.Row():
+ shared.gradio['softprompts_menu'] = gr.Dropdown(choices=utils.get_available_softprompts(), value='None', label='Soft prompt')
+ ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': utils.get_available_softprompts()}, 'refresh-button')
+
+ gr.Markdown('Upload a soft prompt (.zip format):')
+ with gr.Row():
+ shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
+
+ shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+ shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
+ shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
+
+
+def set_interface_arguments(interface_mode, extensions, bool_active):
+ modes = ["default", "notebook", "chat", "cai_chat"]
+ cmd_list = vars(shared.args)
+ bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
+
+ shared.args.extensions = extensions
+ for k in modes[1:]:
+ setattr(shared.args, k, False)
+ if interface_mode != "default":
+ setattr(shared.args, interface_mode, True)
+
+ for k in bool_list:
+ setattr(shared.args, k, False)
+ for k in bool_active:
+ setattr(shared.args, k, True)
+
+ shared.need_restart = True
+
+
+def create_interface():
+
+ # Defining some variables
+ gen_events = []
+ default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
+ if len(shared.lora_names) == 1:
+ default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.lora_names[0].lower())), 'default')])
+ else:
+ default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
+ title = 'Text generation web UI'
+
+ # Authentication variables
+ auth = None
+ if shared.args.gradio_auth_path is not None:
+ gradio_auth_creds = []
+ with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
+ for line in file.readlines():
+ gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+ auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
+
+ # Importing the extension files and executing their setup() functions
+ if shared.args.extensions is not None and len(shared.args.extensions) > 0:
+ extensions_module.load_extensions()
+
+ # css/js strings
+ css = ui.css if not shared.is_chat() else ui.css + ui.chat_css
+ js = ui.main_js if not shared.is_chat() else ui.main_js + ui.chat_js
+ css += apply_extensions('css')
+ js += apply_extensions('js')
+
+ with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']:
+
+ # Create chat mode interface
+ if shared.is_chat():
+ shared.input_elements = ui.list_interface_input_elements(chat=True)
+ shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+ shared.gradio['Chat input'] = gr.State()
+ shared.gradio['dummy'] = gr.State()
+
+ with gr.Tab('Text generation', elem_id='main'):
+ shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'chat', 'cai-chat'))
+ shared.gradio['textbox'] = gr.Textbox(label='Input')
+ with gr.Row():
+ shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop')
+ shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate', variant='primary')
+ shared.gradio['Continue'] = gr.Button('Continue')
+
+ with gr.Row():
+ shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
+ shared.gradio['Regenerate'] = gr.Button('Regenerate')
+ shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
+
+ with gr.Row():
+ shared.gradio['Impersonate'] = gr.Button('Impersonate')
+ shared.gradio['Send dummy message'] = gr.Button('Send dummy message')
+ shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply')
+
+ with gr.Row():
+ shared.gradio['Remove last'] = gr.Button('Remove last')
+ shared.gradio['Clear history'] = gr.Button('Clear history')
+ shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant='stop', visible=False)
+ shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
+
+ shared.gradio['mode'] = gr.Radio(choices=['chat', 'chat-instruct', 'instruct'], value=shared.settings['mode'] if shared.settings['mode'] in ['chat', 'instruct', 'chat-instruct'] else 'chat', label='Mode', info='Defines how the chat prompt is generated. In instruct and chat-instruct modes, the instruction template selected under "Chat settings" must match the current model.')
+ shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct')
+
+ with gr.Tab('Chat settings', elem_id='chat-settings'):
+ with gr.Row():
+ shared.gradio['character_menu'] = gr.Dropdown(choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.')
+ ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button')
+
+ with gr.Row():
+ with gr.Column(scale=8):
+ shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
+ shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
+ shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
+ shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
+
+ with gr.Column(scale=1):
+ shared.gradio['character_picture'] = gr.Image(label='Character picture', type='pil')
+ shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)
+
+ shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Instruction template', value='None', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.')
+ shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string')
+ shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string')
+ shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context')
+ shared.gradio['turn_template'] = gr.Textbox(value=shared.settings['turn_template'], lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.')
+ with gr.Row():
+ shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=4, label='Command for chat-instruct mode', info='<|character|> gets replaced by the bot name, and <|prompt|> gets replaced by the regular chat prompt.')
+
+ with gr.Row():
+ with gr.Tab('Chat history'):
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown('## Upload')
+ shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
+
+ with gr.Column():
+ gr.Markdown('## Download')
+ shared.gradio['download'] = gr.File()
+ shared.gradio['download_button'] = gr.Button(value='Click me')
+
+ with gr.Tab('Upload character'):
+ gr.Markdown('## JSON format')
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown('1. Select the JSON file')
+ shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
+
+ with gr.Column():
+ gr.Markdown('2. Select your character\'s profile picture (optional)')
+ shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
+
+ shared.gradio['Upload character'] = gr.Button(value='Submit')
+ gr.Markdown('## TavernAI PNG format')
+ shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
+
+ with gr.Tab("Parameters", elem_id="parameters"):
+ with gr.Box():
+ gr.Markdown("Chat parameters")
+ with gr.Row():
+ with gr.Column():
+ shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
+ shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
+
+ with gr.Column():
+ shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)', info='New generations will be called until either this number is reached or no new content is generated between two iterations')
+ shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character')
+
+ create_settings_menus(default_preset)
+
+ # Create notebook mode interface
+ elif shared.args.notebook:
+ shared.input_elements = ui.list_interface_input_elements(chat=False)
+ shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
+ shared.gradio['last_input'] = gr.State('')
+ with gr.Tab("Text generation", elem_id="main"):
+ with gr.Row():
+ with gr.Column(scale=4):
+ with gr.Tab('Raw'):
+ shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_classes="textbox", lines=27)
+
+ with gr.Tab('Markdown'):
+ shared.gradio['markdown'] = gr.Markdown()
+
+ with gr.Tab('HTML'):
+ shared.gradio['html'] = gr.HTML()
+
+ with gr.Row():
+ shared.gradio['Generate'] = gr.Button('Generate', variant='primary', elem_classes="small-button")
+ shared.gradio['Stop'] = gr.Button('Stop', elem_classes="small-button")
+ shared.gradio['Undo'] = gr.Button('Undo', elem_classes="small-button")
+ shared.gradio['Regenerate'] = gr.Button('Regenerate', elem_classes="small-button")
+
+ with gr.Column(scale=1):
+ gr.HTML('