Spaces:
Sleeping
Sleeping
import asyncio
import json
import logging
import os
import random

import aiohttp
import gradio as gr
import numpy as np
# NOTE(review): `spaces` is kept ahead of torch/transformers, matching the
# original order (ZeroGPU is understood to expect it imported early — confirm).
import spaces
import torch
from huggingface_hub import InferenceClient
from huggingface_hub import AsyncInferenceClient
from transformers import LlamaTokenizer, LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM
async def query_llm(payload, model_name):
    """POST ``payload`` to the Hugging Face Inference API endpoint for ``model_name``.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": "..."}``.
        model_name: Hub repo id, interpolated into the endpoint URL.

    Returns:
        The decoded JSON body of the response. The HTTP status is not checked,
        so an API error is returned as whatever JSON the service sends.
    """
    token = os.getenv("HF_TOKEN")
    # Only attach a bearer token when one is configured. With HF_TOKEN unset the
    # old f-string sent the literal "Bearer None", which the API would presumably
    # treat as invalid credentials rather than as an anonymous request.
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    url = f"https://api-inference.huggingface.co/models/{model_name}"
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as response:
            return await response.json()
async def generate_mistral_7bvo1(system_input, user_input):
    """Stream a chat reply from mistralai/Mistral-7B-Instruct-v0.1.

    A new ``AsyncInferenceClient`` is created per call. The conversation is a
    single system message followed by a single user message, and generation is
    capped at 256 new tokens.

    Yields:
        Successive text fragments of the assistant reply. Stream chunks that
        carry no text (empty ``choices`` or ``delta.content is None``) are
        skipped so callers can concatenate the fragments directly.
    """
    client = AsyncInferenceClient(
        "mistralai/Mistral-7B-Instruct-v0.1",
        token=os.getenv('HF_TOKEN'),
    )
    # NOTE(review): some Mistral chat templates reject a leading "system" role —
    # confirm this model's template accepts it.
    stream = await client.chat_completion(
        messages=[
            {"role": "system", "content": system_input},
            {"role": "user", "content": user_input},
        ],
        max_tokens=256,
        stream=True,
    )
    async for chunk in stream:
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content is not None:
            yield content
async def generate_t5(system_input, user_input):
    """Produce a single completion from google/flan-t5-large.

    The system and user inputs are joined with a newline into one ``inputs``
    prompt string.

    Yields:
        Exactly one value: the ``generated_text`` of the first result, or
        ``str(output)`` of the raw response when it lacks the expected
        ``[{"generated_text": ...}]`` shape (presumably an error payload).
    """
    output = await query_llm(
        {"inputs": f"{system_input}\n{user_input}"},
        "google/flan-t5-large",
    )
    try:
        text = output[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        # TypeError covers responses that are neither list nor dict,
        # e.g. None or a bare string.
        text = str(output)
    yield text
async def generate_gpt2(system_input, user_input):
    """Produce a single completion from openai-community/gpt2.

    The system and user inputs are joined with a newline into one ``inputs``
    prompt string.

    Yields:
        Exactly one value: the first 532 characters of the first result's
        ``generated_text``, or ``str(output)`` of the raw response when it lacks
        the expected ``[{"generated_text": ...}]`` shape (presumably an error
        payload such as a model-loading message). Previously such a response
        raised KeyError/IndexError.
    """
    max_chars = 532  # length cap applied to the generated text
    output = await query_llm(
        {"inputs": f"{system_input}\n{user_input}"},
        "openai-community/gpt2",
    )
    try:
        # NOTE(review): text-generation output may echo the prompt at the start
        # (`return_full_text`); confirm whether the cap is meant to include it.
        text = output[0]["generated_text"][:max_chars]
    except (IndexError, KeyError, TypeError):
        text = str(output)
    yield text
async def generate_llama2(system_input, user_input):
    """Stream a chat reply from meta-llama/Llama-2-7b-chat-hf.

    A new ``AsyncInferenceClient`` is created per call. The conversation is a
    single system message followed by a single user message, and generation is
    capped at 256 new tokens.

    Yields:
        Successive text fragments of the assistant reply. Stream chunks that
        carry no text (empty ``choices`` or ``delta.content is None``) are
        skipped so callers can concatenate the fragments directly.
    """
    client = AsyncInferenceClient(
        "meta-llama/Llama-2-7b-chat-hf",
        token=os.getenv('HF_TOKEN')
    )
    stream = await client.chat_completion(
        messages=[
            {"role": "system", "content": system_input},
            {"role": "user", "content": user_input},
        ],
        max_tokens=256,
        stream=True,
    )
    async for chunk in stream:
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content is not None:
            yield content
async def generate_llama3(system_input, user_input):
    """Stream a chat reply from meta-llama/Meta-Llama-3.1-8B-Instruct.

    A new ``AsyncInferenceClient`` is created per call. The conversation is a
    single system message followed by a single user message, and generation is
    capped at 256 new tokens.

    Yields:
        Successive text fragments of the assistant reply. Stream chunks that
        carry no text (empty ``choices`` or ``delta.content is None``) are
        skipped so callers can concatenate the fragments directly.

    A ``json.JSONDecodeError`` during the request or stream ends the reply early
    instead of propagating (best-effort). It is logged as a warning so the
    truncation is not silent.
    """
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    client = AsyncInferenceClient(
        model_id,
        token=os.getenv('HF_TOKEN')
    )
    try:
        stream = await client.chat_completion(
            messages=[
                {"role": "system", "content": system_input},
                {"role": "user", "content": user_input},
            ],
            max_tokens=256,
            stream=True,
        )
        async for chunk in stream:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content
    except json.JSONDecodeError:
        # Presumably an undecodable stream chunk; keep what was already yielded.
        logging.getLogger(__name__).warning(
            "JSON decode error while streaming from %s; reply may be truncated",
            model_id,
            exc_info=True,
        )
async def generate_mixtral(system_input, user_input):
    """Stream a chat reply from mistralai/Mixtral-8x7B-Instruct-v0.1.

    A new ``AsyncInferenceClient`` is created per call. The conversation is a
    single system message followed by a single user message, and generation is
    capped at 256 new tokens.

    Yields:
        Successive text fragments of the assistant reply. Stream chunks that
        carry no text (empty ``choices`` or ``delta.content is None``) are
        skipped so callers can concatenate the fragments directly.

    A ``json.JSONDecodeError`` during the request or stream ends the reply early
    instead of propagating (best-effort). It is logged as a warning so the
    truncation is not silent.
    """
    model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    client = AsyncInferenceClient(
        model_id,
        token=os.getenv('HF_TOKEN')
    )
    # NOTE(review): some Mistral-family chat templates reject a leading "system"
    # role — confirm this model's template accepts it.
    try:
        stream = await client.chat_completion(
            messages=[
                {"role": "system", "content": system_input},
                {"role": "user", "content": user_input},
            ],
            max_tokens=256,
            stream=True,
        )
        async for chunk in stream:
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content
    except json.JSONDecodeError:
        # Presumably an undecodable stream chunk; keep what was already yielded.
        logging.getLogger(__name__).warning(
            "JSON decode error while streaming from %s; reply may be truncated",
            model_id,
            exc_info=True,
        )