Jamba 8xMoE (SLERP Merge)

This model was created from Jamba, a 52B-parameter mixture-of-experts (MoE) model with 16 experts. Its experts were merged from 16 down to 8 using accumulative SLERP (spherical linear interpolation).
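SLERP interpolates between two weight tensors along the arc between them instead of along a straight line. The card does not say how the 16 experts were grouped or in what order they were accumulated. The sketch below is therefore only an illustration under stated assumptions. Experts are merged in consecutive pairs through a hypothetical `merge_experts` grouping, and "accumulative" is read as folding each group into a running result with t = 1/(k+1), mirroring a running mean. It uses NumPy on flattened toy tensors and is not the script that produced this model. A real expert has several weight matrices, and each of them would be merged the same way.

```python
from typing import List

import numpy as np


def slerp(a: np.ndarray, b: np.ndarray, t: float, eps: float = 1e-6) -> np.ndarray:
    """Spherical linear interpolation between two same-shaped weight tensors.

    The tensors are flattened and treated as vectors; t=0 returns a, t=1 returns b.
    """
    a_flat = a.ravel().astype(np.float64)
    b_flat = b.ravel().astype(np.float64)
    # Angle between the two tensors, computed from their normalized copies.
    a_unit = a_flat / (np.linalg.norm(a_flat) + 1e-12)
    b_unit = b_flat / (np.linalg.norm(b_flat) + 1e-12)
    theta = np.arccos(np.clip(np.dot(a_unit, b_unit), -1.0, 1.0))
    sin_theta = np.sin(theta)
    if sin_theta < eps:
        # (Anti)parallel tensors: the arc is degenerate, so fall back to linear interpolation.
        out = (1.0 - t) * a_flat + t * b_flat
    else:
        # Interpolation coefficients are applied to the original (unnormalized) tensors.
        out = (np.sin((1.0 - t) * theta) * a_flat + np.sin(t * theta) * b_flat) / sin_theta
    return out.reshape(a.shape).astype(a.dtype)


def accumulative_slerp(tensors: List[np.ndarray]) -> np.ndarray:
    """Fold tensors into one running merge (assumed schedule).

    After k tensors have been merged, the running result is SLERPed toward the
    next tensor with t = 1/(k+1), analogous to updating a running mean.
    """
    merged = tensors[0]
    for k, w in enumerate(tensors[1:], start=1):
        merged = slerp(merged, w, 1.0 / (k + 1))
    return merged


def merge_experts(experts: List[np.ndarray], n_out: int) -> List[np.ndarray]:
    """Hypothetical grouping: merge consecutive experts, e.g. 16 -> 8 merges (0,1), (2,3), ..."""
    if len(experts) % n_out != 0:
        raise ValueError("number of experts must be divisible by n_out")
    size = len(experts) // n_out
    return [accumulative_slerp(experts[i * size:(i + 1) * size]) for i in range(n_out)]


# Toy usage: 16 random "expert" weight matrices reduced to 8.
rng = np.random.default_rng(0)
experts = [rng.standard_normal((4, 3)) for _ in range(16)]
merged = merge_experts(experts, n_out=8)
print(len(merged), merged[0].shape)  # 8 (4, 3)
```

With groups of two, the accumulative schedule reduces to a single SLERP at t = 0.5. The running-fold form only matters if larger groups were merged.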

4-bit Inference Code

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

model_id = "isemmanuelolowe/Jamba-8xMoE_slerp"

tokenizer = AutoTokenizer.from_pretrained(model_id)
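# 4-bit NF4 quantization with bfloat16 compute; modules matching "mamba" are left unquantized.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # load_in_8bit=True,  # for 8-bit loading, use this instead of load_in_4bit and the bnb_4bit_* options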
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    llm_int8_skip_modules=["mamba"],
)

# attn_implementation="flash_attention_2" requires the flash-attn package and a compatible GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config
)

# Tokenize the prompt and move it to the device the model was loaded on.
input_ids = tokenizer("Here is how to do bubble sort\n```python\n", return_tensors="pt")["input_ids"].to(model.device)

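# Greedy decoding; repetition_penalty=1.0 applies no penalty.
out = model.generate(input_ids, max_new_tokens=256, do_sample=False, repetition_penalty=1.0)
# batch_decode returns a list with one decoded string per sequence.
print(tokenizer.batch_decode(out, skip_special_tokens=True))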

Output:

['Here is how to do bubble sort\n```python\ndef bubble_sort(array):\n    for i in 0, len(array):\n    for j in 0, len(array):\n    if a[i] < a[j]\n    a[i], a[j]\n\n```\n\n\n\n\n\n\n']
Model size: 29B params (Safetensors, tensor type F32)
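As a rough sizing aid, the listed parameter count can be converted into weights-only memory estimates for each storage format. This back-of-the-envelope sketch assumes exactly 29e9 parameters and 32/16/8/4 bits per parameter. It ignores activations, the attention KV cache and Mamba state, quantization constants, and the layers left unquantized via `llm_int8_skip_modules`. Real 4-bit usage will therefore be somewhat higher than the NF4 figure.

```python
# Rough weights-only memory estimate; 29e9 is the rounded parameter count listed for this model.
N_PARAMS = 29e9

# Bits stored per parameter for each format (ignores per-block quantization constants).
BITS_PER_PARAM = {"F32": 32, "BF16": 16, "INT8": 8, "NF4": 4}


def weight_memory_gb(n_params: float, bits_per_param: int) -> float:
    """Return the memory needed to store n_params weights, in decimal gigabytes."""
    return n_params * bits_per_param / 8 / 1e9


for fmt, bits in BITS_PER_PARAM.items():
    print(f"{fmt}: ~{weight_memory_gb(N_PARAMS, bits):.1f} GB")
```

This prints about 116 GB for the F32 checkpoint, 58 GB for bfloat16, and 14.5 GB for NF4 weights. The NF4 figure is why 4-bit loading is the practical option on a single GPU.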