# Aston-xMAD's picture
# init commit
# 9382e3f verified
import html
import json
import os
import random
import time

import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Environment variables: disable tokenizer parallelism and let the PyTorch
# CUDA caching allocator use expandable segments.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "0",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
# Global variables to store the model and tokenizer (filled lazily by the loader).
model = tokenizer = None
# Load model and tokenizer
# (model_name, dtype, kv_bits) that the cached global model/tokenizer were loaded with.
_loaded_settings = None

# Short precision names accepted by load_model_and_tokenizer -> torch dtypes.
_DTYPES = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def load_model_and_tokenizer(model_name, dtype, kv_bits):
    """Load, or reuse, the global model and tokenizer for the requested settings.

    The loaded objects are cached in the module globals ``model`` and
    ``tokenizer``. They are reused only when ``(model_name, dtype, kv_bits)``
    matches the settings of the cached copy. Otherwise the old copy is dropped
    and a new one is loaded. Without this check, changing the KV-bits
    selection after the first load would silently have no effect.

    Args:
        model_name: Hugging Face model id, e.g.
            ``"NousResearch/Meta-Llama-3-8B-Instruct"``.
        dtype: ``"bf16"``, ``"fp16"`` or ``"fp32"``. Any other value is passed
            to ``from_pretrained`` as ``torch_dtype`` unchanged.
        kv_bits: ``"unquantized"``, or a bit width. For a bit width, the
            attribute ``quantizer_path`` is set on the model config to
            ``codebooks/<model>_<kv_bits>bit.xmad``.

    Returns:
        ``(model, tokenizer)``.
    """
    global model, tokenizer, _loaded_settings
    settings = (model_name, dtype, kv_bits)
    if model is not None and tokenizer is not None and _loaded_settings == settings:
        return model, tokenizer

    if model is not None:
        # Release the previous model before loading another, so two copies do
        # not have to fit in memory at the same time.
        import gc

        model = None
        tokenizer = None
        _loaded_settings = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("Loading model and tokenizer...")
    new_tokenizer = AutoTokenizer.from_pretrained(model_name)
    new_tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    config = AutoConfig.from_pretrained(model_name)
    if kv_bits != "unquantized":
        config.quantizer_path = f"codebooks/{model_name.split('/')[-1]}_{kv_bits}bit.xmad"

    new_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype=_DTYPES.get(dtype, dtype),
        device_map="auto",
    )
    # The added <PAD> token may lie outside the original vocabulary.
    if len(new_tokenizer) > new_model.get_input_embeddings().weight.shape[0]:
        new_model.resize_token_embeddings(len(new_tokenizer))
    # Left padding keeps each prompt adjacent to its generated continuation in a batch.
    new_tokenizer.padding_side = "left"
    new_model.config.pad_token_id = new_tokenizer.pad_token_id

    # Publish to the globals only once everything loaded successfully.
    model, tokenizer, _loaded_settings = new_model, new_tokenizer, settings
    return model, tokenizer
# Format response
def format_response(dialog, response):
    """Split a decoded generation into the user's question and the model's answer.

    Args:
        dialog: List of ``{"role": ..., "content": ...}`` turns used to build
            the prompt.
        response: Decoded text that echoes the prompt turns. Presumably it is
            followed by an ``assistant`` header and then the generated answer.

    Returns:
        ``{"question": ..., "answer": ...}``.
        ``question`` is the content of the first user turn, or
        ``'No question found'``.
        ``answer`` is the stripped text after the first ``"assistant"`` that
        follows the echoed prompt turns. If the turns cannot be located in
        ``response``, it is the text after the last ``"assistant"``.
    """
    question = next(
        (turn["content"] for turn in dialog if turn["role"] == "user"),
        "No question found",
    )

    # Walk the prompt turns in order, so that the search for the header starts
    # *after* the echoed prompt. Otherwise an "assistant" inside the generated
    # answer would be mistaken for the header and truncate the answer.
    cursor = 0
    for turn in dialog:
        content = turn.get("content") or ""
        found = response.find(content, cursor) if isinstance(content, str) else -1
        if found == -1:
            cursor = -1
            break
        cursor = found + len(content)

    header = "assistant"
    header_pos = response.find(header, cursor) if cursor >= 0 else -1
    if header_pos == -1:
        # Prompt not located verbatim: fall back to the text after the last header.
        answer = response.split(header)[-1].strip()
    else:
        answer = response[header_pos + len(header):].strip()
    return {"question": question, "answer": answer}
# Load questions
def load_questions(prompts_path, custom_questions, max_questions=30):
    """Build the dialogs to run: custom questions first, topped up with random ones.

    Args:
        prompts_path: Path to a UTF-8 JSON file containing a list of dialogs.
        custom_questions: Iterable of user-typed questions, or ``None``. Blank
            entries are skipped, and surrounding whitespace is stripped.
        max_questions: Maximum total number of dialogs returned (default 30).

    Returns:
        A list of at most ``max_questions`` dialogs. It starts with one
        single-turn user dialog per non-blank custom question. The rest is
        filled with dialogs from ``prompts_path``, in shuffled order.
    """
    max_questions = max(0, int(max_questions))
    with open(prompts_path, "r", encoding="utf-8") as file:
        dialogs = json.load(file)

    selected_dialogs = [
        [{"role": "user", "content": question.strip()}]
        for question in (custom_questions or [])
        if question.strip()
    ]

    # Clamp at zero: a negative count would slice off the *end* of the pool.
    num_fill = max(0, max_questions - len(selected_dialogs))
    random.shuffle(dialogs)
    selected_dialogs.extend(dialogs[:num_fill])
    return selected_dialogs[:max_questions]
# Inference
def infer(model_name, dialogs, num_new_tokens, temperature, dtype, kv_bits, batch_size=30):
    """Generate answers for ``dialogs`` in batches and collect timing metrics.

    Args:
        model_name: Model id forwarded to ``load_model_and_tokenizer``.
        dialogs: List of chat dialogs (lists of ``{"role", "content"}`` dicts).
        num_new_tokens: Maximum number of new tokens generated per dialog.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding, because sampling requires a strictly positive temperature.
        dtype: Precision name forwarded to ``load_model_and_tokenizer``.
        kv_bits: KV-cache setting forwarded to ``load_model_and_tokenizer``.
        batch_size: Number of dialogs per ``model.generate`` call (minimum 1).

    Returns:
        dict with the keys:

        - ``"Time Taken (seconds)"``: wall-clock time of the whole batch loop.
        - ``"Tokens per Second"``: newly generated, non-padding tokens divided
          by the time spent inside ``model.generate``.
        - ``"Time to First Token (seconds)"``: mean over batches of the time
          from calling ``generate`` until the first decoding step's logits
          reach the logits processors (roughly the prompt-prefill latency).
        - ``"Responses"``: list of ``format_response`` dicts, one per dialog.
        - ``"Formatted Responses"``: the responses rendered as Markdown.
    """
    from transformers import LogitsProcessor, LogitsProcessorList

    print("Starting inference...")
    model, tokenizer = load_model_and_tokenizer(model_name, dtype, kv_bits)
    use_cuda = torch.cuda.is_available()

    def _sync():
        # CUDA work is asynchronous; wait for it so wall-clock timings are real.
        if use_cuda:
            torch.cuda.synchronize()

    class _FirstTokenTimer(LogitsProcessor):
        """Pass-through logits processor that records when it is first called."""

        def __init__(self):
            self.first_token_time = None

        def __call__(self, input_ids, scores):
            if self.first_token_time is None:
                _sync()
                self.first_token_time = time.perf_counter()
            return scores

    # Sampling with temperature 0 is rejected by transformers; use greedy decoding instead.
    if temperature is not None and temperature <= 0:
        sampling_kwargs = {"do_sample": False}
    else:
        sampling_kwargs = {"do_sample": True, "temperature": temperature}

    batch_size = max(1, int(batch_size))
    batch_inputs = [
        tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True)
        for dialog in dialogs
    ]
    responses = []
    total_time = 0.0
    total_tokens = 0
    batch_ttfts = []
    start_time = time.time()

    for start_idx in range(0, len(batch_inputs), batch_size):
        batch = batch_inputs[start_idx:start_idx + batch_size]
        encoded_inputs = tokenizer(batch, padding=True, truncation=False, return_tensors="pt")
        input_ids = encoded_inputs["input_ids"].to(model.device)
        attention_mask = encoded_inputs["attention_mask"].to(model.device)
        timer = _FirstTokenTimer()
        with torch.no_grad():
            _sync()
            batch_start_time = time.perf_counter()
            output_tokens = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                logits_processor=LogitsProcessorList([timer]),
                **sampling_kwargs,
            )
            _sync()
            batch_end_time = time.perf_counter()
        total_time += batch_end_time - batch_start_time
        if timer.first_token_time is not None:
            batch_ttfts.append(timer.first_token_time - batch_start_time)

        # generate() returns prompt + continuation. Count only generated tokens,
        # excluding the padding added after sequences that stopped early.
        new_tokens = output_tokens[:, input_ids.shape[1]:]
        total_tokens += int((new_tokens != tokenizer.pad_token_id).sum().item())

        decoded_outputs = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
        for dialog, response in zip(dialogs[start_idx:start_idx + batch_size], decoded_outputs):
            responses.append(format_response(dialog, response))

    elapsed_time = time.time() - start_time
    ttft = sum(batch_ttfts) / len(batch_ttfts) if batch_ttfts else 0
    print(f"Inference completed in {elapsed_time:.2f} seconds.")
    formatted_responses = "\n\n---\n\n".join(
        [f"**Question**: {res['question']}\n\n**Answer**: {res['answer']}" for res in responses]
    )
    return {
        "Time Taken (seconds)": elapsed_time,
        "Tokens per Second": total_tokens / total_time if total_time > 0 else 0,
        "Time to First Token (seconds)": ttft,
        "Responses": responses,
        "Formatted Responses": formatted_responses,
    }
# Demo function
def demo(num_new_tokens, temperature, custom_questions_text, kv_bits):
    """Gradio callback: run the benchmark and return the four output values.

    Args:
        num_new_tokens: Max new tokens per answer (from the slider; coerced to int).
        temperature: Sampling temperature (from the slider).
        custom_questions_text: Textbox contents, one question per line. May be
            ``None``, which is treated as empty.
        kv_bits: KV-bits dropdown value.

    Returns:
        ``(time taken, tokens per second, time to first token, markdown responses)``.
    """
    # splitlines() also handles "\r\n" line endings.
    custom_questions = (custom_questions_text or "").splitlines()
    print("Loading questions...")
    dialogs = load_questions("chats_sys_none.json", custom_questions)
    print(f"{len(dialogs)} questions loaded. Starting inference...")
    results = infer(
        "NousResearch/Meta-Llama-3-8B-Instruct",
        dialogs,
        int(num_new_tokens),  # the slider value may arrive as a float
        temperature,
        "fp16",
        kv_bits,
    )
    return (
        results["Time Taken (seconds)"],
        results["Tokens per Second"],
        results["Time to First Token (seconds)"],
        results["Formatted Responses"],
    )
# Load JSON data shown in the "Show JSON" tab.
# JSON is UTF-8 by specification; don't rely on the platform default encoding.
with open("chats_sys_none.json", "r", encoding="utf-8") as file:
    json_data = json.load(file)
# Pretty-printed copy for display; keep non-ASCII characters readable instead of \u-escaped.
json_data_str = json.dumps(json_data, indent=2, ensure_ascii=False)
# Show JSON function
def show_json():
    """Return the prompts JSON as an HTML ``<pre>`` block for the "Show JSON" tab.

    The text is HTML-escaped, so characters such as ``<`` and ``&`` in the chat
    data are displayed literally instead of being parsed as markup. The
    ``<pre>`` wrapper preserves the indentation, which HTML rendering would
    otherwise collapse.
    """
    return "<pre>{}</pre>".format(html.escape(json_data_str))
# Gradio interface for the inference benchmark.
# Inputs map positionally onto demo(num_new_tokens, temperature, custom_questions_text, kv_bits).
_demo_inputs = [
    gr.Slider(label="Number of New Tokens", minimum=128, maximum=1024, step=128, value=512),
    gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.4),
    gr.Textbox(label="Custom Questions", placeholder="Type your custom questions here, one per line...", lines=5),
    gr.Dropdown(label="KV Bits", choices=["1", "2", "4", "unquantized"], value="1"),
]
# Outputs map positionally onto the 4-tuple returned by demo().
_demo_outputs = [
    gr.Number(label="Time Taken (seconds)", value=0),
    gr.Number(label="Tokens per Second", value=0),
    gr.Number(label="Time to First Token (seconds)", value=0),
    gr.Markdown(label="Formatted Responses", value="No responses yet."),
]
interface = gr.Interface(
    fn=demo,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="LLM Inference Demo",
    description="A demo for running LLM inference using Gradio and Hugging Face.",
    live=False,  # Set to False to have a submit button
)
# Second interface: displays the prompts JSON.
json_interface = gr.Interface(
    fn=show_json,
    inputs=[],
    outputs=[
        # Escape the JSON so "<" and "&" in the chat data render literally
        # instead of being interpreted as HTML.
        gr.HTML("<pre>{}</pre>".format(html.escape(json_data_str)))
    ],
    live=False  # Set to False to have a submit button
)
# Combine both interfaces into a single app, one tab each.
with gr.Blocks() as app:
    with gr.Tab("LLM Inference Demo"):
        interface.render()
    with gr.Tab("Show JSON"):
        json_interface.render()
if __name__ == "__main__":
    # Pre-load the model with the same settings the UI uses by default
    # (fp16, KV bits "1") before starting the server.
    startup_model = "NousResearch/Meta-Llama-3-8B-Instruct"
    startup_dtype, startup_kv_bits = "fp16", "1"
    print("Loading model and tokenizer on startup...")
    load_model_and_tokenizer(startup_model, startup_dtype, startup_kv_bits)
    print("Model and tokenizer loaded. Starting Gradio interface...")
    app.launch()