import os

import spacy
from accelerate import PartialState
from accelerate.utils import set_seed
from flask import Flask, request, jsonify

from gpt2_generation import Translator
from gpt2_generation import generate_prompt, MODEL_CLASSES

# Route outbound HTTP(S) traffic through a local proxy (e.g., for model downloads).
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

app = Flask(__name__)

path_for_model = "./output/gpt2_openprompt/checkpoint-4500"
# Decoding settings: top-k/top-p sampling with mild length and repetition penalties.
args = {
    "model_type": "gpt2",
    "model_name_or_path": path_for_model,
    "length": 80,
    "stop_token": None,
    "temperature": 1.0,
    "length_penalty": 1.2,
    "repetition_penalty": 1.2,
    "k": 3,
    "p": 0.9,
    "prefix": "",
    "padding_text": "",
    "xlm_language": "",
    "seed": 42,
    "use_cpu": False,
    "num_return_sequences": 1,
    "fp16": False,
    "jit": False,
}
# Accelerate handles device placement; honors the CPU-only flag.
distributed_state = PartialState(cpu=args["use_cpu"])

if args["seed"] is not None:
    set_seed(args["seed"])

# Populated lazily by load_model_and_components().
tokenizer = None
model = None
zh_en_translator = None
nlp = None
def load_model_and_components():
    global tokenizer, model, zh_en_translator, nlp

    # Initialize the model and tokenizer
    try:
        args["model_type"] = args["model_type"].lower()
        model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
    except KeyError:
        raise KeyError(
            "the model {} you specified is not supported. "
            "You are welcome to add it and open a PR :)".format(args["model_type"])
        )

    tokenizer = tokenizer_class.from_pretrained(args["model_name_or_path"], padding_side='left')
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.mask_token = tokenizer.eos_token
    model = model_class.from_pretrained(args["model_name_or_path"])
    print("Model loaded!")

    # Chinese-to-English translator used to preprocess incoming phrases
    zh_en_translator = Translator("Helsinki-NLP/opus-mt-zh-en")
    print("Translator loaded!")

    # spaCy pipeline used to filter the generated text
    nlp = spacy.load('en_core_web_sm')
    print("Filter loaded!")

    # Set the model to the right device
    model.to(distributed_state.device)

    if args["fp16"]:
        model.half()
# NOTE: the "/chat" route path is an assumption; the endpoint was otherwise
# unregistered and unreachable.
@app.route('/chat', methods=['POST'])
def chat():
    phrase = request.json.get('phrase')

    # Lazily load everything on the first request if startup loading was skipped.
    if tokenizer is None or model is None or zh_en_translator is None or nlp is None:
        load_model_and_components()

    messages = generate_prompt(
        prompt_text=phrase,
        args=args,
        zh_en_translator=zh_en_translator,
        nlp=nlp,
        model=model,
        tokenizer=tokenizer,
        distributed_state=distributed_state,
    )
    return jsonify(messages)
if __name__ == '__main__':
    load_model_and_components()
    app.run(host='0.0.0.0', port=10008, debug=False)
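
# --- Example client (illustrative sketch, not part of the app) ---
# Assumes the "/chat" route above and the default port 10008. The service
# translates Chinese input to English before generation, so a Chinese
# phrase is a natural test input.
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:10008/chat",
#       json={"phrase": "今天天气怎么样"},  # "How is the weather today?"
#   )
#   print(resp.json())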