File size: 1,737 Bytes
b247dc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Manifest as an app service."""

from typing import Any, Dict, cast

from fastapi import APIRouter, FastAPI, HTTPException

from manifest import Manifest
from manifest.response import Response as ManifestResponse
from web_app import schemas

# FastAPI application instance; the "/" route below is registered on it directly.
app = FastAPI()
# Router that collects the API endpoints (e.g. POST /prompt/); it is attached
# to `app` at the bottom of this module.
api_router = APIRouter()


@app.get("/")
async def root() -> Dict:
    """Return the static greeting served at the application root."""
    greeting = "Hello to the Manifest App"
    return dict(message=greeting)


@api_router.post("/prompt/", status_code=201, response_model=schemas.ManifestResponse)
def prompt_manifest(*, manifest_in: schemas.ManifestCreate) -> Dict:
    """Build a Manifest session from the request and run its prompt.

    Args:
        manifest_in: request payload carrying the client/cache settings used
            to construct the ``Manifest`` and the prompt together with its
            sampling parameters.

    Returns:
        A dict with the keys ``response``, ``cached`` and ``request_params``,
        each read from the Manifest response object.

    Raises:
        HTTPException: with status 500 and the error text as ``detail`` when
            running the prompt raises.
    """
    manifest = Manifest(
        client_name=manifest_in.client_name,
        client_connection=manifest_in.client_connection,
        engine=manifest_in.engine,
        cache_name=manifest_in.cache_name,
        cache_connection=manifest_in.cache_connection,
    )
    manifest_prompt_args: Dict[str, Any] = {
        "n": manifest_in.n,
        "max_tokens": manifest_in.max_tokens,
    }
    # Optional sampling parameters are forwarded only when the caller supplied
    # them. Compare against None rather than relying on truthiness so that an
    # explicit zero (e.g. temperature=0) is not silently dropped.
    if manifest_in.temperature is not None:
        manifest_prompt_args["temperature"] = manifest_in.temperature
    if manifest_in.top_k is not None:
        manifest_prompt_args["top_k"] = manifest_in.top_k
    if manifest_in.top_p is not None:
        manifest_prompt_args["top_p"] = manifest_in.top_p

    try:
        raw_response = manifest.run(
            prompt=manifest_in.prompt, return_response=True, **manifest_prompt_args
        )
    except Exception as e:
        # Surface any failure from the run as a 500 while keeping the original
        # exception attached as the cause for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # return_response=True is passed above; the cast only narrows the type for
    # the type checker and has no runtime effect.
    response = cast(ManifestResponse, raw_response)
    return {
        "response": response.get_response(),
        "cached": response.is_cached(),
        "request_params": response.get_request_obj(),
    }


# Attach the routes registered on api_router (e.g. POST /prompt/) to the app.
app.include_router(api_router)