File size: 1,296 Bytes
32fe622 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import sys
from pathlib import Path
from typing import Optional

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.model import Transformer
def run_chat(
    model_path: str,
    prompt: str,
    max_tokens: int = 256,
    temperature: float = 1.0,
    instruct: bool = True,
    lora_path: Optional[str] = None,
) -> str:
    """Generate a completion for *prompt* with a local Mistral model and print it.

    Args:
        model_path: Directory containing the model weights and the
            ``tokenizer.model.v3`` tokenizer file.
        prompt: Text fed to the model as-is (encoded raw, BOS only).
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature forwarded to ``generate``.
        instruct: Currently unused; kept for backward compatibility with
            existing callers. NOTE(review): the prompt is encoded raw, so no
            instruct/chat template is ever applied regardless of this flag.
        lora_path: Optional path to LoRA adapter weights to load on top of
            the base model.

    Returns:
        The decoded completion (also printed to stdout).

    Raises:
        FileNotFoundError: If the tokenizer file is missing from *model_path*.
    """
    model_dir = Path(model_path)
    tokenizer_file = model_dir / "tokenizer.model.v3"
    if not tokenizer_file.is_file():
        raise FileNotFoundError(f"Tokenizer model file not found at {tokenizer_file}")

    mistral_tokenizer = MistralTokenizer.from_file(str(tokenizer_file))
    tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer

    transformer = Transformer.from_folder(
        model_dir, max_batch_size=3, num_pipeline_ranks=1
    )
    # Optionally overlay LoRA adapter weights on the base model.
    if lora_path is not None:
        transformer.load_lora(Path(lora_path))

    # Encode with BOS but no EOS so the model continues the prompt.
    tokens = tokenizer.encode(prompt, bos=True, eos=False)
    generated_tokens, _ = generate(
        [tokens],
        transformer,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_id,
    )
    answer = tokenizer.decode(generated_tokens[0])
    print(answer)
    return answer
if __name__ == "__main__":
    # Expose run_chat's signature as a command-line interface via python-fire.
    import fire

    fire.Fire(component=run_chat)
|