# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from numpy.typing import NDArray

    from .base_engine import BaseEngine, Response


def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
    """Install ``loop`` as the current thread's event loop and run it forever.

    Used as the target of the background thread started by ``ChatModel``.
    """
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
    """Synchronous and asynchronous chat front-end over an inference engine.

    The engine's coroutines are executed on a private event loop that runs in
    a background daemon thread. The plain methods (``chat``, ``stream_chat``,
    ``get_scores``) submit the matching ``a*`` coroutine to that loop and block
    until the result is available; the ``a*`` methods simply await the engine.
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        """Build the engine selected by ``infer_backend`` and start it.

        Args:
            args: Optional argument mapping forwarded to ``get_infer_args``.

        Raises:
            NotImplementedError: If ``infer_backend`` is neither
                ``"huggingface"`` nor ``"vllm"``.
        """
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        backend = model_args.infer_backend
        if backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif backend == "vllm":
            self.engine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError(f"Unknown backend: {backend}")

        # A dedicated loop in a daemon thread lets the blocking methods below
        # drive the engine's coroutines without requiring the caller to own a loop.
        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

        # Block until the engine's start() coroutine has completed on the loop.
        task = asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
        task.result()

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        """Blocking variant of ``achat``: run it on the background loop and wait."""
        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        """Await the engine's ``chat`` with the given arguments and return its result."""
        return await self.engine.chat(messages, system, tools, image, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        """Blocking variant of ``astream_chat``: yield each streamed item in order.

        Every item is fetched by scheduling one ``__anext__`` call of the async
        generator on the background loop and waiting for it to finish.
        """
        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                # The async generator is exhausted; end the sync generator too.
                break

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        """Re-yield every item produced by the engine's ``stream_chat``."""
        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        """Blocking variant of ``aget_scores``: run it on the background loop and wait."""
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        """Await the engine's ``get_scores`` for ``batch_input`` and return its result."""
        return await self.engine.get_scores(batch_input, **input_kwargs)


def run_chat() -> None:
    """Run an interactive command-line chat session until the user types ``exit``.

    Typing ``clear`` resets the conversation history and calls ``torch_gc()``.
    Any other input is appended to the history as a user message, the reply is
    streamed to stdout as it arrives, and the full reply is then appended to
    the history as an assistant message.
    """
    try:
        import platform

        if platform.system() != "Windows":
            import readline  # noqa: F401
    except ImportError:
        print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        # Any other exception from input() propagates to the caller unchanged.

        command = query.strip()
        if command == "exit":
            break

        if command == "clear":
            messages = []
            torch_gc()
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        # Echo each chunk immediately, and keep the chunks so the complete
        # reply can be stored in the history once streaming ends.
        chunks = []
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            chunks.append(new_text)

        print()
        messages.append({"role": "assistant", "content": "".join(chunks)})