File size: 1,296 Bytes
32fe622
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import sys
from pathlib import Path
from typing import Optional

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.model import Transformer

def run_chat(model_path: str, prompt: str, max_tokens: int = 256, temperature: float = 1.0, instruct: bool = True, lora_path: Optional[str] = None):
    """Generate a completion for *prompt* with a Mistral model and print it.

    Args:
        model_path: Directory containing the model weights and the
            ``tokenizer.model.v3`` tokenizer file.
        prompt: Raw text prompt to encode and complete.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to ``generate``.
        instruct: Accepted for CLI backward compatibility but not used
            anywhere in this function.
        lora_path: Optional path to a LoRA adapter loaded on top of the
            base weights.

    Returns:
        The decoded completion string (also printed to stdout).

    Raises:
        FileNotFoundError: If ``tokenizer.model.v3`` is missing from
            ``model_path``.
    """
    # Resolve the v3 tokenizer file inside the model directory.
    model_dir = Path(model_path)
    tokenizer_file = model_dir / "tokenizer.model.v3"

    if not tokenizer_file.is_file():
        raise FileNotFoundError(f"Tokenizer model file not found at {tokenizer_file}")

    mistral_tokenizer = MistralTokenizer.from_file(str(tokenizer_file))
    tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer

    # NOTE(review): max_batch_size=3 although only a single prompt is
    # batched below — presumably headroom; confirm before lowering.
    transformer = Transformer.from_folder(
        model_dir, max_batch_size=3, num_pipeline_ranks=1
    )

    if lora_path is not None:
        transformer.load_lora(Path(lora_path))

    tokens = tokenizer.encode(prompt, bos=True, eos=False)
    generated_tokens, _ = generate(
        [tokens],
        transformer,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_id,
    )
    answer = tokenizer.decode(generated_tokens[0])
    print(answer)
    return answer

if __name__ == "__main__":
    # Deferred import: fire is only needed when run as a CLI script.
    import fire

    # Expose run_chat's parameters as command-line flags.
    fire.Fire(run_chat)