from pathlib import Path
from typing import Optional

from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_inference.generate import generate
from mistral_inference.model import Transformer


def run_chat(
    model_path: str,
    prompt: str,
    max_tokens: int = 256,
    temperature: float = 1.0,
    instruct: bool = True,
    lora_path: Optional[str] = None,
):
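    """Load a Mistral model from model_path and print a completion for prompt.

    When instruct is True, the prompt is wrapped in a single-turn chat request
    so the tokenizer adds the instruct control tokens around it; otherwise the
    raw prompt is encoded directly.
    """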
    model_path = Path(model_path)
    tokenizer_file = model_path / "tokenizer.model.v3"

    if not tokenizer_file.is_file():
        raise FileNotFoundError(f"Tokenizer model file not found at {tokenizer_file}")

    mistral_tokenizer = MistralTokenizer.from_file(str(tokenizer_file))
    tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
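
    # Load the model weights; max_batch_size bounds how many prompts a single
    # generate() call can decode at once.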
    transformer = Transformer.from_folder(
        model_path, max_batch_size=3, num_pipeline_ranks=1
    )
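
    # Load LoRA adapter weights on top of the base model, if an adapter was given.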
    if lora_path is not None:
        transformer.load_lora(Path(lora_path))

    if instruct:
        # Wrap the prompt as a single-turn chat request so the instruct
        # control tokens are inserted around the user message.
        request = ChatCompletionRequest(messages=[UserMessage(content=prompt)])
        tokens = mistral_tokenizer.encode_chat_completion(request).tokens
    else:
        tokens = tokenizer.encode(prompt, bos=True, eos=False)

    # generate() returns the newly generated tokens (one list per prompt),
    # excluding the prompt tokens themselves.
    generated_tokens, _ = generate(
        [tokens],
        transformer,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_id,
    )
    answer = tokenizer.decode(generated_tokens[0])
    print(answer)


if __name__ == "__main__":
    # Expose run_chat as a command-line interface; fire maps each keyword
    # argument to a --flag.
    import fire

    fire.Fire(run_chat)
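
# Example invocation (script name, model path, and prompt are illustrative):
#   python run_chat.py --model_path ./mistral-7B-v0.3 \
#       --prompt "Explain LoRA in one sentence." --max_tokens 128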