File size: 3,009 Bytes
4df8249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import re
import copy
import global_vars

from discordbot.utils import (
    get_chat_manager,
    get_global_context
)
from discordbot.flags import (
    parse_known_flags,
    known_flags_def
)

from pingpong import PingPong
from pingpong.context import CtxLastWindowStrategy

from discord import NotFound

from transformers import GenerationConfig

def sync_task(prompt, args):
    """Generate a completion for ``prompt`` with the globally loaded model.

    Generation starts from a deep copy of ``global_vars.gen_config``, so the
    shared default config is never mutated. Each recognised flag in ``args``
    whose value is not ``None`` overrides the matching config attribute.

    Args:
        prompt: Text handed to ``global_vars.tokenizer``.
        args: Mapping that must contain the keys ``"max-new-tokens"``,
            ``"temperature"``, ``"do-sample"`` and ``"top-p"`` (a missing key
            raises ``KeyError``). A ``None`` value keeps the default.

    Returns:
        The decoded text of the newly generated tokens only. The prompt
        tokens are sliced off before decoding.
    """
    encoded = global_vars.tokenizer(prompt, return_tensors="pt").input_ids.to(
        global_vars.device
    )

    config = copy.deepcopy(global_vars.gen_config)
    # (flag name in args, GenerationConfig attribute it overrides)
    overrides = (
        ("max-new-tokens", "max_new_tokens"),
        ("temperature", "temperature"),
        ("do-sample", "do_sample"),
        ("top-p", "top_p"),
    )
    for flag, attr in overrides:
        value = args[flag]
        if value is not None:
            setattr(config, attr, value)

    output = global_vars.model.generate(
        input_ids=encoded,
        generation_config=config,
    )
    # generate() echoes the prompt, so skip its tokens in the first sequence.
    prompt_len = encoded.shape[-1]
    return global_vars.tokenizer.decode(output[0][prompt_len:])

async def build_prompt(ppmanager, ctx_include=True, win_size=3):
    """Render ``ppmanager`` through a last-window strategy without mutating it.

    The manager is deep-copied first, so the caller's object (including its
    ``ctx``) is left untouched.

    Args:
        ppmanager: Chat manager to render (only a copy is modified).
        ctx_include: When true, the copy's ``ctx`` is set to the model's
            global context from ``get_global_context``; otherwise it is
            set to the empty string.
        win_size: Window size passed to ``CtxLastWindowStrategy``.

    Returns:
        Whatever ``CtxLastWindowStrategy(win_size)`` returns for the copy
        (presumably the prompt text built from the last ``win_size``
        turns; confirm against the pingpong library).
    """
    scratch = copy.deepcopy(ppmanager)
    scratch.ctx = get_global_context(global_vars.model_type) if ctx_include else ""

    window = CtxLastWindowStrategy(win_size)
    return window(scratch)

def _clean_reply_content(content, username, user_id):
    """Return the prompt text of an earlier message in a reply chain.

    Removes ``@username `` and ``<@user_id> `` mentions (each including the
    trailing space). If the text contains the ``💬`` marker, only the part
    after its first occurrence is kept. The result is whitespace-stripped.
    """
    content = content.replace(f"@{username} ", "").replace(f"<@{user_id}> ", "")
    _, marker, after = content.partition("💬")
    return (after if marker else content).strip()


async def build_ppm(msg, msg_content, username, user_id):
    """Build a chat manager holding the reply-chain history of ``msg``.

    The function walks the Discord reply chain backwards from ``msg``. It
    follows ``msg.reference`` until a message has no reference, or until a
    referenced message can no longer be fetched (``NotFound``). Each earlier
    message is cleaned and run through ``parse_known_flags``. The results
    are paired oldest-first into ``PingPong(prompt, response)`` turns.
    Finally, ``msg_content`` is appended as a turn with an empty response.

    Args:
        msg: The Discord message being answered.
        msg_content: Text of ``msg`` as prepared by the caller. It is used
            verbatim as the final prompt.
        username: Name whose ``@username`` mention is stripped from history.
        user_id: ID whose ``<@user_id>`` mention is stripped from history.

    Returns:
        The manager from ``get_chat_manager(global_vars.model_type)``,
        populated with the history turns and the new prompt.
    """
    ppm = get_chat_manager(global_vars.model_type)
    channel = msg.channel
    user_msg = msg_content

    # Collected newest-first; reversed once below (cheaper than insert(0)).
    history = []
    while msg.reference is not None:
        try:
            msg = await channel.fetch_message(msg.reference.message_id)
        except NotFound:
            # A deleted or inaccessible message ends the chain here.
            break

        text = _clean_reply_content(msg.content, username, user_id)
        # parse_known_flags returns (text, parsed_flags); only the text is kept.
        text, _ = parse_known_flags(text, known_flags_def, global_vars.gen_config)
        print(text)
        history.append(text)
    history.reverse()

    # Pair turns anchored at the newest end. The message the user replied to
    # is presumably a bot response, so history[-1] is treated as a response.
    # With an odd count, the oldest message has no partner and is skipped;
    # the previous code raised IndexError in that case.
    start = len(history) % 2
    for i in range(start, len(history), 2):
        ppm.add_pingpong(PingPong(history[i], history[i + 1]))

    ppm.add_pingpong(PingPong(user_msg, ""))
    print(ppm.pingpongs)

    return ppm