import argparse
import markdown2
import os
import sys
import uvicorn

from pathlib import Path
from typing import Union

from fastapi import FastAPI, Depends
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from tclogger import logger

from constants.models import AVAILABLE_MODELS_DICTS
from constants.envs import CONFIG

from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock
from networks.huggingface_streamer import HuggingfaceStreamer
from networks.openai_streamer import OpenaiStreamer


class ChatAPIApp:
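    """OpenAI-compatible chat completions API, served with FastAPI and backed by
    the HuggingFace/OpenAI streamer backends imported above."""
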
    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title=CONFIG["app_name"],
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version=CONFIG["version"],
        )
        self.setup_routes()

    def get_available_models(self):
        return {"object": "list", "data": AVAILABLE_MODELS_DICTS}

    # Declared without `self`: FastAPI inspects this function's signature directly
    # when it is wired in below via Depends(extract_api_key).
    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        api_key = None
        if credentials:
            api_key = credentials.credentials
        else:
            api_key = os.getenv("HF_TOKEN")

        if api_key:
            if api_key.startswith("hf_"):
                return api_key
            else:
                logger.warn("Invalid HF Token!")
        else:
            logger.warn("No HF Token provided!")
        return None

    class ChatCompletionsPostItem(BaseModel):
        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: Union[float, None] = Field(
            default=0.5,
            description="(float) Temperature",
        )
        top_p: Union[float, None] = Field(
            default=0.95,
            description="(float) Top p",
        )
        max_tokens: Union[int, None] = Field(
            default=-1,
            description="(int) Max tokens",
        )
        use_cache: bool = Field(
            default=False,
            description="(bool) Use cache",
        )
        stream: bool = Field(
            default=True,
            description="(bool) Stream",
        )

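    # Dispatch: `gpt-3.5-turbo` is served by OpenaiStreamer; any other model id
    # goes to HuggingfaceStreamer, with MessageComposer merging the message list
    # into a single prompt string.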
    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        if item.model == "gpt-3.5-turbo":
            streamer = OpenaiStreamer()
            stream_response = streamer.chat_response(messages=item.messages)
        else:
            streamer = HuggingfaceStreamer(model=item.model)
            composer = MessageComposer(model=item.model)
            composer.merge(messages=item.messages)
            stream_response = streamer.chat_response(
                prompt=composer.merged_str,
                temperature=item.temperature,
                top_p=item.top_p,
                max_new_tokens=item.max_tokens,
                api_key=api_key,
                use_cache=item.use_cache,
            )

        if item.stream:
            event_source_response = EventSourceResponse(
                streamer.chat_return_generator(stream_response),
                media_type="text/event-stream",
                ping=2000,
                ping_message_factory=lambda: ServerSentEvent(**{"comment": ""}),
            )
            return event_source_response
        else:
            data_response = streamer.chat_return_dict(stream_response)
            return data_response

    def get_readme(self):
        readme_path = Path(__file__).parents[1] / "README.md"
        with open(readme_path, "r", encoding="utf-8") as rf:
            readme_str = rf.read()
        readme_html = markdown2.markdown(
            readme_str, extras=["table", "fenced-code-blocks", "highlightjs-lang"]
        )
        return readme_html

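    # Register the same endpoints under multiple prefixes ("", /v1, /api, /api/v1);
    # only the /api/v1 variants are listed in the OpenAPI schema.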
    def setup_routes(self):
        for prefix in ["", "/v1", "/api", "/api/v1"]:
            if prefix in ["/api/v1"]:
                include_in_schema = True
            else:
                include_in_schema = False

            self.app.get(
                prefix + "/models",
                summary="Get available models",
                include_in_schema=include_in_schema,
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
                include_in_schema=include_in_schema,
            )(self.chat_completions)

        # The README route uses no prefix, so register it once, outside the loop.
        self.app.get(
            "/readme",
            summary="README of HF LLM API",
            response_class=HTMLResponse,
            include_in_schema=False,
        )(self.get_readme)


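# CLI arguments: host/port default to values from CONFIG; --dev enables auto-reload.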
class ArgParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.add_argument(
            "-s",
            "--host",
            type=str,
            default=CONFIG["host"],
            help=f"Host for {CONFIG['app_name']}",
        )
        self.add_argument(
            "-p",
            "--port",
            type=int,
            default=CONFIG["port"],
            help=f"Port for {CONFIG['app_name']}",
        )
        self.add_argument(
            "-d",
            "--dev",
            default=False,
            action="store_true",
            help="Run in dev mode",
        )

        self.args = self.parse_args(sys.argv[1:])


app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # Auto-reload only when --dev is passed; behavior is otherwise identical.
    uvicorn.run("__main__:app", host=args.host, port=args.port, reload=args.dev)
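
# Usage sketches (the module path is an assumption based on this file reading
# README.md from its parent directory; adjust `apis.chat_api` to the real location):
#   python -m apis.chat_api       # production mode
#   python -m apis.chat_api -d    # dev mode (auto-reload)
#
# Example request against a running server (host/port come from CONFIG, and
# `hf_xxx` is a placeholder HuggingFace token):
#   curl -X POST http://<host>:<port>/api/v1/chat/completions \
#     -H "Authorization: Bearer hf_xxx" -H "Content-Type: application/json" \
#     -d '{"model": "mixtral-8x7b", "messages": [{"role": "user", "content": "Hi!"}], "stream": false}'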