Spaces:
Runtime error
Runtime error
File size: 3,009 Bytes
4df8249 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import re
import copy
import global_vars
from discordbot.utils import (
get_chat_manager,
get_global_context
)
from discordbot.flags import (
parse_known_flags,
known_flags_def
)
from pingpong import PingPong
from pingpong.context import CtxLastWindowStrategy
from discord import NotFound
from transformers import GenerationConfig
def sync_task(prompt, args):
    """Generate a completion for ``prompt`` with the globally loaded model.

    A deep copy of ``global_vars.gen_config`` is made for this call. Each of the
    ``"max-new-tokens"``, ``"temperature"``, ``"do-sample"`` and ``"top-p"``
    entries of ``args`` that is not ``None`` overrides the matching field
    (``max_new_tokens``, ``temperature``, ``do_sample``, ``top_p``) on that copy.

    Args:
        prompt: The full prompt text to tokenize and feed to the model.
        args: Subscriptable mapping of parsed flags; all four keys above are read.

    Returns:
        The decoded text of only the tokens generated after the prompt.
    """
    encoded = global_vars.tokenizer(prompt, return_tensors="pt")
    prompt_ids = encoded.input_ids.to(global_vars.device)

    # Per-call copy so request-specific overrides never mutate the shared default.
    config = copy.deepcopy(global_vars.gen_config)
    flag_to_field = (
        ("max-new-tokens", "max_new_tokens"),
        ("temperature", "temperature"),
        ("do-sample", "do_sample"),
        ("top-p", "top_p"),
    )
    for flag_name, field_name in flag_to_field:
        override = args[flag_name]
        if override is not None:
            setattr(config, field_name, override)

    output_ids = global_vars.model.generate(
        input_ids=prompt_ids,
        generation_config=config,
    )
    # Skip the echoed prompt tokens; decode only the continuation of sequence 0.
    prompt_len = prompt_ids.shape[-1]
    return global_vars.tokenizer.decode(output_ids[0][prompt_len:])
async def build_prompt(ppmanager, ctx_include=True, win_size=3):
    """Render a prompt from a private copy of ``ppmanager``.

    The manager is deep-copied so the caller's object is left untouched. The
    copy's ``ctx`` is set to ``get_global_context(global_vars.model_type)``
    when ``ctx_include`` is truthy and to ``""`` otherwise. The copy is then
    passed to ``CtxLastWindowStrategy(win_size)`` and that call's result is
    returned. (Presumably this keeps only the last ``win_size`` exchanges —
    confirm against pingpong's docs.)
    """
    working_copy = copy.deepcopy(ppmanager)
    working_copy.ctx = (
        get_global_context(global_vars.model_type) if ctx_include else ""
    )
    window_strategy = CtxLastWindowStrategy(win_size)
    return window_strategy(working_copy)
def _clean_history_text(content, username, user_id):
    """Normalize the text of one earlier message from a reply chain.

    Removes "@{username} " and "<@{user_id}> " mentions. Keeps only the text
    after the first "💬" when that marker is present (otherwise the whole
    text), then strips surrounding whitespace. Finally, known flags are
    removed with ``parse_known_flags``.
    """
    text = content.replace(f"@{username} ", "").replace(f"<@{user_id}> ", "")
    marker_pos = text.find("💬")
    if marker_pos != -1:
        text = text[marker_pos + 1:]
    text = text.strip()
    text, _ = parse_known_flags(
        text,
        known_flags_def,
        global_vars.gen_config
    )
    return text


async def build_ppm(msg, msg_content, username, user_id):
    """Build a chat manager from the Discord reply chain that ends at ``msg``.

    Starting at ``msg``, ``msg.reference`` links are followed backwards. Each
    referenced message is fetched from ``msg.channel`` and its text is cleaned
    with ``_clean_history_text``. The walk stops when a message has no
    reference, the reference carries no message id, or the referenced message
    no longer exists (``discord.NotFound``). The collected texts, oldest first,
    are paired into ``PingPong`` turns. A final ``PingPong(msg_content, "")``
    is then appended for the message currently being answered.

    Args:
        msg: The Discord message that triggered this request.
        msg_content: Cleaned text of ``msg``; used as the pending prompt.
        username: Name whose "@{username} " mention is stripped from history.
        user_id: ID whose "<@{user_id}> " mention is stripped from history.

    Returns:
        The manager from ``get_chat_manager(global_vars.model_type)``, populated
        with the paired history plus the pending turn.
    """
    ppm = get_chat_manager(global_vars.model_type)
    channel = msg.channel

    # Collected newest-first (cheap appends), reversed once after the walk.
    history = []
    while msg.reference is not None:
        ref_id = msg.reference.message_id
        if ref_id is None:
            # A reference without a message id cannot be fetched; stop here.
            break
        try:
            msg = await channel.fetch_message(ref_id)
        except NotFound:
            # Referenced message was deleted; keep what has been gathered.
            break
        text = _clean_history_text(msg.content, username, user_id)
        print(text)
        history.append(text)
    history.reverse()

    # Pair from the newest end so the most recent exchange stays aligned.
    # With an odd count, the oldest message has no partner and is skipped
    # rather than causing an IndexError.
    start = len(history) % 2
    for i in range(start, len(history), 2):
        ppm.add_pingpong(
            PingPong(history[i], history[i + 1])
        )

    ppm.add_pingpong(
        PingPong(msg_content, "")
    )
    print(ppm.pingpongs)
    return ppm
|