import sys
from typing import List
import traceback
import os
import base64
import json
import pprint
from huggingface_hub import Repository
from text_generation import Client
import logging
logging.basicConfig(level=logging.INFO)
import modules.cloud_logging
# from flask import Flask, request, render_template
# from flask_cors import CORS
# app = Flask(__name__, static_folder='static')
# app.config['TEMPLATES_AUTO_RELOAD'] = True
# origins = [f"http://localhost:{PORT}", "https://huggingface.co", "https://hf.space"]
# CORS(app, resources={
#     r"/generate": {"origins": origins},
#     r"/infill": {"origins": origins},
# })
PORT = 7860
VERBOSE = False
# note: both branches currently use the same limit, so the 'unlock' file has no effect
if os.path.exists('unlock'):
    MAX_LENGTH = 8192
else:
    MAX_LENGTH = 8192
TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse

app = FastAPI(docs_url=None, redoc_url=None)
app.mount("/static", StaticFiles(directory="static"), name="static")
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = os.environ.get("API_URL")

with open("./HHH_prompt.txt", "r") as f:
    HHH_PROMPT = f.read() + "\n\n"

FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
END_OF_TEXT = "<|endoftext|>"
FIM_INDICATOR = "<infill>"
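# Fill-in-the-middle (FIM) prompt layout, as assembled in generate() below:
#   <fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>
# after which the model generates the missing middle, terminated by <|endoftext|>.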
client = Client(
    API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="static/index.html", media_type="text/html")

def generate(prefix, suffix=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    # infill (fill-in-the-middle) mode is signaled by passing a suffix
    fim_mode = suffix is not None
    if fim_mode:
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
        if VERBOSE:
            print("----prompt----")
            print(prompt)
    else:
        prompt = prefix
    output = client.generate(prompt, **generate_kwargs)
    generated_text = output.generated_text
    # TODO: detect documents that exceed MAX_LENGTH and set truncated accordingly
    truncated = False
    # strip any trailing end-of-text sentinels the model emitted
    while generated_text.endswith(END_OF_TEXT):
        generated_text = generated_text[:-len(END_OF_TEXT)]
    generation = {
        'truncated': truncated,
    }
    if fim_mode:
        generation['text'] = prefix + generated_text + suffix
        generation['parts'] = [prefix, suffix]
        generation['infills'] = [generated_text]
        generation['type'] = 'infill'
    else:
        generation['text'] = prompt + generated_text
        generation['parts'] = [prompt]
        generation['type'] = 'generate'
    return generation
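
# Illustrative usage sketch (hypothetical snippets): without a suffix, generate()
# does plain left-to-right completion; with one, it infills between the parts.
#   generate("def add(a, b):\n    return ")                    # type == 'generate'
#   generate("def add(a, b):\n", suffix="\n    return c\n")    # type == 'infill'
# Both calls return a dict with 'text', 'parts', 'type', and 'truncated'.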

# async def generate_maybe(request: Request):
#     form = await request.json()
@app.get("/generate")
async def generate_maybe(info: str):
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    # print(form)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    logging.info(json.dumps({
        'length': length_limit,
        'temperature': temperature,
        'prompt': prompt,
    }))
    try:
        generation = generate(prompt, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        if generation['truncated']:
            message = TRUNCATION_MESSAGE
        else:
            message = ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation['text'], 'message': message}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        logging.error(e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
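
# Illustrative sketch (hypothetical helper, not part of the original app): how a
# client could build the base64url-encoded `info` parameter parsed above. The
# field names mirror generate_maybe's parsing; the helper name is an assumption.
def example_generate_info(prompt: str, length: int = 128, temperature: float = 0.7) -> str:
    payload = json.dumps({'prompt': prompt, 'length': length, 'temperature': temperature})
    # strip '=' padding for URL friendliness; the server re-adds it before decoding
    return base64.urlsafe_b64encode(payload.encode('utf-8')).decode('ascii').rstrip('=')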

# async def infill_maybe(request: Request):
#     form = await request.json()
@app.get("/infill")
async def infill_maybe(info: str):
    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
    # fix padding, following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    max_retries = 1
    extra_sentinel = True
    logging.info(json.dumps({
        'length': length_limit,
        'temperature': temperature,
        'parts_joined': '<infill>'.join(form['parts']),
    }))
    try:
        if len(form['parts']) > 2:
            return {'result': 'error', 'text': ''.join(form['parts']), 'type': 'infill', 'message': "error: Only a single infill is supported!"}
        prefix, suffix = form['parts']
        generation = generate(prefix, suffix=suffix, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        generation['result'] = 'success'
        if generation['truncated']:
            generation['message'] = TRUNCATION_MESSAGE
        else:
            generation['message'] = ''
        return generation
        # return {'result': 'success', 'prefix': prefix, 'suffix': suffix, 'text': generation['text']}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        logging.error(e)
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}
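
# Illustrative sketch (hypothetical helper): the matching /infill request, where
# 'parts' is the document split at the cursor into [prefix, suffix] and the
# server fills in the middle.
def example_infill_info(prefix: str, suffix: str, length: int = 128, temperature: float = 0.7) -> str:
    payload = json.dumps({'parts': [prefix, suffix], 'length': length, 'temperature': temperature})
    return base64.urlsafe_b64encode(payload.encode('utf-8')).decode('ascii').rstrip('=')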

if __name__ == "__main__":
    # FastAPI apps are served by an ASGI server such as uvicorn, not Flask's app.run()
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=PORT)