File size: 947 Bytes
12001a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from dataclasses import asdict
import pytest
import sys


@pytest.mark.skipif(sys.platform == "win32", reason="EmptyInitOnDevice on CPU not working for Windows.")
@pytest.mark.parametrize("model_size", ["7B", "13B", "30B", "65B"])
def test_config_identical(model_size, lit_llama):
    """Check that the adapter LLaMA variant matches the base LLaMA for ``model_size``.

    Two comparisons are made:

    * The ``LLaMAConfig`` built by ``lit_llama.adapter`` must equal the one
      built by ``lit_llama.model``, field for field. The adapter config's
      extra fields ``adapter_prompt_length`` and ``adapter_start_layer``
      are dropped first.
    * Inside ``EmptyInitOnDevice()``, the ``lm_head.weight`` tensors of both
      models must have the same shape.

    ``lit_llama`` is a fixture that is not referenced in the body. It
    presumably exists for its setup side effect, e.g. making the package
    importable (TODO confirm against conftest).
    """
    import lit_llama.adapter as adapter_mod
    import lit_llama.model as base_mod
    from lit_llama.utils import EmptyInitOnDevice

    # Fields that only the adapter config declares; removing them should
    # leave exactly the base config's fields and values.
    adapter_only_fields = ("adapter_prompt_length", "adapter_start_layer")

    base_fields = asdict(base_mod.LLaMAConfig.from_name(model_size))
    adapter_fields = asdict(adapter_mod.LLaMAConfig.from_name(model_size))
    for field_name in adapter_only_fields:
        # `del` (not `pop` with a default) so a missing field fails loudly.
        del adapter_fields[field_name]
    assert adapter_fields == base_fields

    # NOTE(review): EmptyInitOnDevice presumably avoids materialising real
    # weights for the large sizes; only shapes are inspected here.
    with EmptyInitOnDevice():
        base_model = base_mod.LLaMA.from_name(model_size)
        adapter_model = adapter_mod.LLaMA.from_name(model_size)
        base_head_shape = base_model.lm_head.weight.shape
        adapter_head_shape = adapter_model.lm_head.weight.shape
        assert base_head_shape == adapter_head_shape