File size: 1,949 Bytes
8bc8944 239290b 826db2e 8bc8944 239290b 7bbb8c5 239290b 826db2e 239290b 8bc8944 7bbb8c5 239290b 8bc8944 826db2e 1821ee5 7bbb8c5 826db2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import os
import secrets
import time
from typing import Union

from fastapi import FastAPI, Request, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel

from QuoteGenerator import QuoteGenerator
# Shared secret that clients must present as their bearer token. Read from the
# API_KEY environment variable; it is None when the variable is unset, in
# which case api_key_auth rejects every request.
API_KEY = os.environ.get('API_KEY')

# Extracts the bearer token from incoming requests; consumed by api_key_auth
# through Depends(oauth2_scheme).
# NOTE(review): no "/token" route is visible in this file; tokenUrl presumably
# only feeds the generated OpenAPI docs — confirm.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Dependency that rejects any request whose bearer token does not match API_KEY.
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Validate the request's bearer token against the configured ``API_KEY``.

    Args:
        api_key: Bearer token as extracted from the request by
            ``oauth2_scheme`` (an ``OAuth2PasswordBearer``).

    Raises:
        HTTPException: 401 when ``API_KEY`` is not configured or when the
            supplied token does not match it.
    """
    # Fail closed when no key is configured (API_KEY is None if the env var is
    # unset). The original `!=` check also rejected everything in that case;
    # it is made explicit here because compare_digest cannot take None.
    # Compare as UTF-8 bytes with secrets.compare_digest: the comparison time
    # does not depend on how many leading characters match (no timing side
    # channel), and non-ASCII tokens cannot trigger a TypeError.
    if API_KEY is None or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Forbidden Access",
            # RFC 6750: a 401 for bearer-token auth must advertise the scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )
class QuoteRequest(BaseModel):
    """JSON request body accepted by the quote endpoints.

    Every field has a default, so an empty JSON object is a valid request.
    ``generate_quote`` forwards each field by name to
    ``QuoteGenerator.generate_quote``.

    NOTE(review): the field names appear to mirror Hugging Face ``generate()``
    keyword arguments — confirm against QuoteGenerator before relying on
    those semantics.
    """

    # Optional tag string for the generator; None means "no tags".
    tags: Union[str, None] = None
    # Sampling switch; off by default.
    do_sample: bool = False
    # Token budget for the generated quote.
    max_new_tokens: int = 16
    # Beam count; 1 by default.
    num_beams: int = 1
    # Top-k cutoff forwarded to the generator.
    top_k: int = 50
    # Nucleus (top-p) cutoff forwarded to the generator.
    top_p: float = 1.0
    # Temperature forwarded to the generator.
    temperature: float = 1.0
app = FastAPI()


# Middleware that reports how long each HTTP request took to handle.
@app.middleware("http")
async def note_response_time(request: Request, call_next):
    """Print the time spent handling every HTTP request.

    Args:
        request: The incoming request; passed through unchanged.
        call_next: Callable provided by the middleware machinery that runs
            the rest of the stack for ``request`` and yields the response.

    Returns:
        The response produced by ``call_next``.
    """
    # perf_counter is monotonic, unlike time.time(), so the measured duration
    # cannot go negative or jump when the system clock is adjusted.
    start_time = time.perf_counter()
    try:
        response = await call_next(request)
    finally:
        # Report the duration even if the downstream handler raised.
        elapsed = time.perf_counter() - start_time
        print(f'Time taken = {elapsed:.1f}s')
    return response
# Module-level generator instance used by the generate_quote endpoint.
# load_generator() runs once at import time, before any request is served;
# presumably this loads the underlying model — TODO confirm in QuoteGenerator.
quote_generator = QuoteGenerator()
quote_generator.load_generator()
@app.post("/", dependencies=[Depends(api_key_auth)])
def root(request: QuoteRequest):
    """Placeholder endpoint that logs the request and returns a canned quote.

    Guarded by ``api_key_auth``. The generator is never invoked here; the
    response is always the same fixed sample quote.
    """
    payload = vars(request)
    print("Incoming request\n", payload)
    placeholder = "<bot>:A beautiful quote generated by bot"
    return {"quote": placeholder}
@app.post("/generate_quote", dependencies=[Depends(api_key_auth)])
def generate_quote(req: QuoteRequest):
    """Generate a quote using the parameters supplied in the request body.

    Guarded by ``api_key_auth``. Every ``QuoteRequest`` field is forwarded by
    keyword to ``quote_generator.generate_quote``.

    Args:
        req: Parsed request body with the tag string and generation settings.

    Returns:
        A dict ``{'quote': <value returned by the generator>}``.
    """
    print("\nIncoming request \n", req.__dict__, end='\n\n')
    generated_quote = quote_generator.generate_quote(
        tags=req.tags,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
        top_k=req.top_k,
        top_p=req.top_p,
        do_sample=req.do_sample,
    )
    return {'quote': generated_quote}