# Discord bot entry point: routes DMs and bot mentions to an LLM backend
# and replies with generated responses (optionally augmented by web search).
import os
import copy
import json
import types
import asyncio
import argparse
from urlextract import URLExtract
from urllib.request import urlopen
from concurrent.futures import ThreadPoolExecutor

import discord
from discord.errors import HTTPException

import global_vars
from pingpong.context import InternetSearchStrategy, SimilaritySearcher
from discordbot.req import (
    sync_task, build_prompt, build_ppm
)
from discordbot.flags import parse_req
from discordbot import helps, post
from dumb_utils import URLSearchStrategy

# Static metadata for every supported model, keyed by model name.
# Fix: close the file deterministically instead of leaking the handle
# returned by a bare open() until garbage collection.
with open("model_cards.json") as _model_cards_file:
    model_info = json.load(_model_cards_file)

intents = discord.Intents.default()
intents.members = True
client = discord.Client(intents=intents)

# Incoming Discord messages awaiting a reply; filled by on_message and
# drained by the background worker task.
queue = asyncio.Queue()

# Reserved commands that are answered with help text, not sent to the model.
special_words = [
    "help",
    "model-info",
    "default-params",
]

# Discord rejects messages longer than 2000 characters.
max_response_length = 2000
async def build_prompt_and_reply(executor, user_name, user_id):
    """Take one message off the queue, build a prompt, and reply to it.

    Pops a single Discord message from the module-level ``queue``, strips the
    bot mention from its text, and dispatches on the parsed command:

    * ``help`` / ``model-info`` / ``default-params`` — send static help text.
    * anything else — build a ping-pong prompt, optionally augment it with
      internet-search or in-message-URL context, run generation on
      ``executor``, and reply with the length-capped model response.

    Args:
        executor: ThreadPoolExecutor running the blocking ``sync_task``
            generation call off the event loop.
        user_name: the bot's display name (used to strip ``@name`` mentions).
        user_id: the bot's user id (used to strip ``<@id>`` mentions).

    Relies on module globals set in ``discord_main``: ``model_name`` and
    ``serper_api_key``.
    """
    other_job_on_progress = False  # True once a "Progress" status reply exists
    loop = asyncio.get_running_loop()
    print(queue.qsize())
    # Blocks until a message has been queued by on_message.
    msg = await queue.get()
    # Strip the bot mention (either "@name " or "<@id> " form) before parsing
    # the remaining text into the user message and generation flags.
    user_msg, user_args = parse_req(
        msg.content.replace(f"@{user_name} ", "").replace(f"<@{user_id}> ", ""), global_vars.gen_config
    )
    if user_msg == "help":
        await msg.channel.send(helps.get_help())
    elif user_msg == "model-info":
        await msg.channel.send(helps.get_model_info(model_name, model_info))
    elif user_msg == "default-params":
        await msg.channel.send(helps.get_default_params(global_vars.gen_config, user_args["max-windows"]))
    else:
        try:
            ppm = await build_ppm(msg, user_msg, user_name, user_id)
            if user_args["internet"] and serper_api_key is not None:
                # Internet-search path: have the LLM rewrite the question into
                # a Google-style query, then run the search strategy with it.
                other_job_on_progress = True
                progress_msg = await msg.reply("Progress 🚧", mention_author=False)
                internet_search_ppm = copy.deepcopy(ppm)
                internet_search_prompt = f"My question is '{user_msg}'. Based on the conversation history, give me an appropriate query to answer my question for google search. You should not say more than query. You should not say any words except the query."
                internet_search_ppm.pingpongs[-1].ping = internet_search_prompt
                internet_search_prompt = await build_prompt(
                    internet_search_ppm,
                    ctx_include=False,
                    win_size=user_args["max-windows"]
                )
                # Blocking generation runs in the thread pool, not the loop.
                internet_search_prompt_response = await loop.run_in_executor(
                    executor, sync_task, internet_search_prompt, user_args
                )
                internet_search_prompt_response = post.clean(internet_search_prompt_response)
                ppm.pingpongs[-1].ping = internet_search_prompt_response
                await progress_msg.edit(
                    content=f"• Search query re-organized by LLM: {internet_search_prompt_response}",
                    suppress=True
                )
                searcher = SimilaritySearcher.from_pretrained(device=global_vars.device)
                logs = ""
                # Stream each strategy step's status into the progress message.
                for step_ppm, step_msg in InternetSearchStrategy(
                    searcher, serper_api_key=serper_api_key
                )(ppm, search_query=internet_search_prompt_response, top_k=8):
                    ppm = step_ppm
                    logs = logs + step_msg + "\n"
                    await progress_msg.edit(content=logs, suppress=True)
            else:
                # No internet search: if the message itself contains URLs,
                # use their contents as retrieval context instead.
                url_extractor = URLExtract()
                urls = url_extractor.find_urls(user_msg)
                print(f"urls = {urls}")
                if len(urls) > 0:
                    progress_msg = await msg.reply("Progress 🚧", mention_author=False)
                    other_job_on_progress = True
                    searcher = SimilaritySearcher.from_pretrained(device=global_vars.device)
                    logs = ""
                    for step_result, step_ppm, step_msg in URLSearchStrategy(searcher)(ppm, urls, top_k=8):
                        if step_result is True:
                            ppm = step_ppm
                            logs = logs + step_msg + "\n"
                            await progress_msg.edit(content=logs, suppress=True)
                        else:
                            # Step failed: show the message, pause briefly so
                            # the user can read it, then stop searching.
                            ppm = step_ppm
                            logs = logs + step_msg + "\n"
                            await progress_msg.edit(content=logs, suppress=True)
                            await asyncio.sleep(2)
                            break
            prompt = await build_prompt(ppm, win_size=user_args["max-windows"])
            response = await loop.run_in_executor(
                executor, sync_task,
                prompt, user_args
            )
            response = post.clean(response)
            response = f"**{model_name}** 💬\n{response.strip()}"
            # Truncate to Discord's 2000-character message limit.
            if len(response) >= max_response_length:
                response = response[:max_response_length]
            if other_job_on_progress is True:
                await progress_msg.delete()
            await msg.reply(response, mention_author=False)
        except IndexError:
            await msg.channel.send("Index error")
        except HTTPException:
            # Reply/edit rejected by Discord — drop this message silently.
            pass
async def background_task(user_name, user_id, max_workers):
    """Run the reply worker forever, one queued message at a time.

    Creates a thread pool of ``max_workers`` threads for blocking generation
    calls and processes queued Discord messages sequentially.
    """
    pool = ThreadPoolExecutor(max_workers=max_workers)
    print("Task Started. Waiting for inputs.")
    while True:
        await build_prompt_and_reply(pool, user_name, user_id)
@client.event
async def on_ready():
    """Gateway-ready handler: announce login and start the reply worker.

    Fix: without the ``@client.event`` decorator this coroutine is never
    registered with the Discord client, so the background task that answers
    queued messages would never be started.
    """
    print(f"Logged in as {client.user}")
    asyncio.get_running_loop().create_task(
        background_task(
            client.user.name,
            client.user.id,
            max_workers,  # module global set in discord_main
        )
    )
@client.event
async def on_message(message):
    """Queue DMs and bot-mention messages for the background worker.

    Fix: without the ``@client.event`` decorator this coroutine is never
    registered with the Discord client, so incoming messages would never be
    queued.
    """
    if message.author == client.user:
        return  # ignore the bot's own messages
    if isinstance(message.channel, discord.channel.DMChannel) or\
        (client.user and client.user.mentioned_in(message)):
        await queue.put(message)
def off_modes(args):
    """Clear every model-loading mode flag on ``args`` and return it."""
    for flag in ("mode_cpu", "mode_mps", "mode_8bit", "mode_4bit", "mode_full_gpu"):
        setattr(args, flag, False)
    return args
def discord_main(args):
    """Resolve configuration, initialize model globals, and run the bot.

    CLI values take precedence; a missing ``--token`` / ``--model-name``
    falls back to the DISCORD_BOT_TOKEN / DISCORD_BOT_MODEL_NAME environment
    variables. DISCORD_BOT_MAX_WORKERS and DISCORD_BOT_LOAD_MODE, when set,
    override the corresponding CLI flags.

    Exits the process when required configuration is missing or model
    initialization fails; otherwise blocks inside ``client.run``.
    """
    if args.token is None:
        args.token = os.getenv('DISCORD_BOT_TOKEN')
    if args.model_name is None:
        args.model_name = os.getenv('DISCORD_BOT_MODEL_NAME')
    if args.token is None or args.model_name is None:
        print('Either or both of token and model-name is not provided')
        print('Set them through CLI or environment variables(DISCORD_BOT_TOKEN, DISCORD_BOT_MODEL_NAME)')
        quit()

    env_max_workers = os.getenv('DISCORD_BOT_MAX_WORKERS')
    if env_max_workers:
        args.max_workers = int(env_max_workers)

    # Map each DISCORD_BOT_LOAD_MODE value to the single flag it enables;
    # all load-mode flags are cleared first. Unknown values are ignored,
    # matching the original elif chain (which is replaced by this table to
    # remove the five duplicated off_modes branches).
    load_mode_flags = {
        "CPU": "mode_cpu",
        "MPS": "mode_mps",
        "8BIT": "mode_8bit",
        "4BIT": "mode_4bit",
        "HALF": "mode_full_gpu",
    }
    load_mode = os.getenv('DISCORD_BOT_LOAD_MODE')
    if load_mode in load_mode_flags:
        off_modes(args)
        setattr(args, load_mode_flags[load_mode], True)

    # Shared with the event handlers and the background worker task.
    global max_workers
    global model_name
    global serper_api_key
    max_workers = args.max_workers
    model_name = args.model_name
    serper_api_key = args.serper_api_key

    # Translate the selected model card into the namespace expected by
    # global_vars.initialize_globals.
    selected_model_info = model_info[model_name]
    tmp_args = types.SimpleNamespace()
    tmp_args.base_url = selected_model_info['hub(base)']
    tmp_args.ft_ckpt_url = selected_model_info['hub(ckpt)']
    tmp_args.gen_config_path = selected_model_info['default_gen_config']
    tmp_args.gen_config_summarization_path = selected_model_info['default_gen_config']
    tmp_args.force_download_ckpt = False
    tmp_args.thumbnail_tiny = selected_model_info['thumb-tiny']
    tmp_args.mode_cpu = args.mode_cpu
    tmp_args.mode_mps = args.mode_mps
    tmp_args.mode_8bit = args.mode_8bit
    tmp_args.mode_4bit = args.mode_4bit
    tmp_args.mode_full_gpu = args.mode_full_gpu
    tmp_args.local_files_only = args.local_files_only

    try:
        global_vars.initialize_globals(tmp_args)
    except RuntimeError as e:
        # Fix: surface the underlying error instead of discarding it — the
        # caught exception was previously bound but never used.
        print("GPU memory is not enough to load this model.")
        print(e)
        quit()
    client.run(args.token)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # --token and --model-name may instead be supplied via the environment
    # (DISCORD_BOT_TOKEN / DISCORD_BOT_MODEL_NAME); see discord_main.
    parser.add_argument('--token', default=None, type=str)
    parser.add_argument('--model-name', default=None, type=str)
    parser.add_argument('--max-workers', default=1, type=int)
    # Model-loading mode toggles: each gets a --flag / --no-flag pair.
    for flag_name, flag_default in (
        ('--mode-cpu', False),
        ('--mode-mps', False),
        ('--mode-8bit', False),
        ('--mode-4bit', False),
        ('--mode-full-gpu', True),
        ('--local-files-only', False),
    ):
        parser.add_argument(flag_name, default=flag_default, action=argparse.BooleanOptionalAction)
    parser.add_argument('--serper-api-key', default=None, type=str)
    discord_main(parser.parse_args())