Spaces:
Build error
Build error
File size: 2,865 Bytes
699163a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import time
import traceback
from threading import Thread
from typing import Callable, Optional
from modules.text_generation import get_encoded_length
def build_parameters(body):
prompt = body['prompt']
prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048)
while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)
generate_params = {
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
'num_beams': int(body.get('num_beams', 1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)),
'add_bos_token': bool(body.get('add_bos_token', True)),
'truncation_length': int(body.get('truncation_length', 2048)),
'ban_eos_token': bool(body.get('ban_eos_token', False)),
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
'custom_stopping_strings': '', # leave this blank
'stopping_strings': body.get('stopping_strings', []),
}
return generate_params
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
Thread(target=_start_cloudflared, args=[
port, max_attempts, on_start], daemon=True).start()
def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
try:
from flask_cloudflared import _run_cloudflared
except ImportError:
print('You should install flask_cloudflared manually')
raise Exception(
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
for _ in range(max_attempts):
try:
public_url = _run_cloudflared(port, port + 1)
if on_start:
on_start(public_url)
return
except Exception:
traceback.print_exc()
time.sleep(3)
raise Exception('Could not start cloudflared.')
|