Sample inference script:

```python
import random
import re
from argparse import ArgumentParser

import torch
import torchaudio
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Cache,
    ExLlamaV2Config,
    ExLlamaV2Tokenizer,
    Timer,
)
from exllamav2.generator import (
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
    ExLlamaV2Sampler,
)
from rich import print
from torchaudio import functional as F
from xcodec2.modeling_xcodec2 import XCodec2Model

parser = ArgumentParser()
parser.add_argument("-m", "--model", required=True)
parser.add_argument("-v", "--vocoder", required=True)
parser.add_argument("-a", "--audio", default="")
parser.add_argument("-t", "--transcript", default="")
parser.add_argument("-i", "--input", required=True)
parser.add_argument("-o", "--output", default="output.wav")
parser.add_argument("-d", "--debug", action="store_true")
parser.add_argument("--max_seq_len", type=int, default=2048)
parser.add_argument("--sample_rate", type=int, default=16000)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_p", type=float, default=1.0)
args = parser.parse_args()

# Load the exl2-quantized model, allocate the cache, and set up the generator
with Timer() as timer:
    config = ExLlamaV2Config(args.model)
    config.max_seq_len = args.max_seq_len
    model = ExLlamaV2(config, lazy_load=True)
    cache = ExLlamaV2Cache(model, lazy=True)
    model.load_autosplit(cache, progress=True)
    tokenizer = ExLlamaV2Tokenizer(config)
    generator = ExLlamaV2DynamicGenerator(model, cache, tokenizer)

print(f"Loaded model in {timer.interval:.2f} seconds.")

# Load the XCodec2 vocoder, which converts between waveforms and speech tokens
with Timer() as timer:
    vocoder = XCodec2Model.from_pretrained(args.vocoder)
    vocoder = vocoder.cuda().eval()

print(f"Loaded vocoder in {timer.interval:.2f} seconds.")

# Speech-token string for the optional voice-cloning prompt; must be defined
# even when no reference audio is supplied, since it is interpolated into the
# prompt below
audio = ""

if args.audio:
    with Timer() as timer:
        audio, sample_rate = torchaudio.load(args.audio)
        audio = audio.cuda()

        # Downmix multi-channel reference audio to mono
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Resample to the rate the vocoder expects (16 kHz by default)
        if sample_rate != args.sample_rate:
            audio = F.resample(audio, sample_rate, args.sample_rate)

    print(f"Loaded audio in {timer.interval:.2f} seconds.")

    # Encode the reference audio into XCodec2 speech tokens, rendered as the
    # <|s_N|> strings the model was trained on
    with Timer() as timer:
        audio = vocoder.encode_code(audio)
        audio = audio[0, 0, :]
        audio = [f"<|s_{a}|>" for a in audio]
        audio = "".join(audio)

    print(f"Encoded audio in {timer.interval:.2f} seconds.")

# Build the chat-style prompt; when cloning a voice, the reference transcript
# is prepended to the input text and the reference speech tokens prime the
# assistant's response
with Timer() as timer:
    if args.audio and args.transcript:
        args.transcript += " "
    else:
        args.transcript = ""

    prompt = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Convert the text to speech:"
        "<|TEXT_UNDERSTANDING_START|>"
        f"{args.transcript}{args.input}"
        "<|TEXT_UNDERSTANDING_END|>"
        "<|eot_id|>\n"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "<|SPEECH_GENERATION_START|>"
        f"{audio}"
    )

    input_ids = tokenizer.encode(prompt, add_bos=True, encode_special_tokens=True)

print(f"Encoded input in {timer.interval:.2f} seconds.")

# Stream speech tokens from the model until <|SPEECH_GENERATION_END|>
with Timer() as timer:
    # Allow generation to use whatever context remains after the prompt
    max_new_tokens = config.max_seq_len - input_ids.shape[-1]
    gen_settings = ExLlamaV2Sampler.Settings()
    gen_settings.temperature = args.temperature
    gen_settings.top_p = args.top_p
    seed = args.seed if args.seed is not None else random.randint(0, 2**64 - 1)
    stop_conditions = ["<|SPEECH_GENERATION_END|>"]

    job = ExLlamaV2DynamicJob(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        gen_settings=gen_settings,
        seed=seed,
        stop_conditions=stop_conditions,
        decode_special_tokens=True,
    )

    generator.enqueue(job)
    output = []

    # Collect streamed text chunks as the job runs
    while generator.num_remaining_jobs():
        for result in generator.iterate():
            if result.get("stage", "") == "streaming":
                text = result.get("text", "")
                output.append(text)

                if args.debug:
                    print(text, end="", flush=True)

            if result.get("eos"):
                generator.clear_queue()

                if args.debug:
                    print()

print(
    f"Generated {len(output)} tokens with seed {seed} in {timer.interval:.2f} seconds."
)

# Extract the generated <|s_N|> ids and decode them back into a waveform
with Timer() as timer:
    output = "".join(output)
    output = [int(o) for o in re.findall(r"<\|s_(\d+)\|>", output)]
    output = torch.tensor([[output]]).cuda()
    output = vocoder.decode_code(output)
    output = output[0, 0, :]
    output = output.unsqueeze(0).cpu()
    torchaudio.save(args.output, output, args.sample_rate)

print(f"Decoded audio in {timer.interval:.2f} seconds.")