MrD05 commited on
Commit
f2899a8
·
1 Parent(s): 04e80ab

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +44 -50
handler.py CHANGED
@@ -14,55 +14,49 @@ template = """{char_name}'s Persona: {char_persona}
14
  class EndpointHandler():
15
 
16
  def __init__(self, path=""):
17
- self.tokenizer = AutoTokenizer.from_pretrained(path)
18
- self.model = AutoModelForCausalLM.from_pretrained(path, load_in_8bit = True, device_map = "auto")
19
- self.local_llm = HuggingFacePipeline(
20
- pipeline = pipeline(
21
- "text-generation",
22
- model = self.model,
23
- tokenizer = self.tokenizer,
24
- max_length = 2048,
25
- temperature = 0.5,
26
- top_p = 0.9,
27
- top_k = 0,
28
- repetition_penalty = 1.1,
29
- pad_token_id = 50256,
30
- num_return_sequences = 1
31
- )
32
- )
33
- self.prompt_template = PromptTemplate(
34
- template = template,
35
- input_variables = [
36
- "user_input",
37
- "user_name",
38
- "char_name",
39
- "char_persona",
40
- "char_greeting",
41
- "chat_history"
42
- ],
43
- validate_template = True
44
- )
45
- self.llm_engine = LLMChain(
46
- llm = self.local_llm,
47
- prompt = self.prompt_template,
48
- verbose = True
49
- )
50
 
51
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
52
- """
53
- data args:
54
- inputs (:obj: `str`)
55
- date (:obj: `str`)
56
- Return:
57
- A :obj:`list` | `dict`: will be serialized and returned
58
- """
59
- inputs = data.pop("inputs", data)
60
 
61
- return self.llm_engine.predict(
62
- user_input = inputs["user_input"],
63
- user_name = inputs["user_name"],
64
- char_name = inputs["char_name"],
65
- char_persona = inputs["char_persona"],
66
- char_greeting = inputs["char_greeting"],
67
- chat_history = inputs["chat_history"]
68
- ).split("\n",1)[0]
 
14
  class EndpointHandler():
15
 
16
  def __init__(self, path=""):
17
+ pass
18
+ # tokenizer = AutoTokenizer.from_pretrained(path)
19
+ # model = AutoModelForCausalLM.from_pretrained(path, load_in_8bit = True, device_map = "auto")
20
+ # local_llm = HuggingFacePipeline(
21
+ # pipeline = pipeline(
22
+ # "text-generation",
23
+ # model = model,
24
+ # tokenizer = tokenizer,
25
+ # max_length = 2048,
26
+ # temperature = 0.5,
27
+ # top_p = 0.9,
28
+ # top_k = 0,
29
+ # repetition_penalty = 1.1,
30
+ # pad_token_id = 50256,
31
+ # num_return_sequences = 1
32
+ # )
33
+ # )
34
+ # prompt_template = PromptTemplate(
35
+ # template = template,
36
+ # input_variables = [
37
+ # "user_input",
38
+ # "user_name",
39
+ # "char_name",
40
+ # "char_persona",
41
+ # "char_greeting",
42
+ # "chat_history"
43
+ # ],
44
+ # validate_template = True
45
+ # )
46
+ # self.llm_engine = LLMChain(
47
+ # llm = local_llm,
48
+ # prompt = prompt_template
49
+ # )
50
 
51
+ def __call__(self, data: Any) -> Any:
52
+ return data, type(data)
53
+ # inputs = data.pop("inputs", data)
 
 
 
 
 
 
54
 
55
+ # return self.llm_engine.predict(
56
+ # user_input = inputs["user_input"],
57
+ # user_name = inputs["user_name"],
58
+ # char_name = inputs["char_name"],
59
+ # char_persona = inputs["char_persona"],
60
+ # char_greeting = inputs["char_greeting"],
61
+ # chat_history = inputs["chat_history"]
62
+ # ).split("\n",1)[0]