import json
import os
import time
import torch
import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import random
from PIL import Image
# Environment variables
os.environ["TOKENIZERS_PARALLELISM"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Global variables to store the model and tokenizer
model = None
tokenizer = None
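
# Helper for the startup image check in __main__. This is a minimal sketch of the
# assumed behavior (confirm that the referenced image exists and is readable);
# the exact original behavior of this helper is an assumption.
def check_image_path(image_path):
    if not os.path.exists(image_path):
        print(f"Warning: image not found at {image_path}")
        return False
    try:
        Image.open(image_path).verify()  # raises if the file is not a readable image
        print(f"Image check passed: {image_path}")
        return True
    except Exception as exc:
        print(f"Image exists but could not be read: {exc}")
        return False
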
# Load model and tokenizer
def load_model_and_tokenizer(model_name, dtype, kv_bits):
    global model, tokenizer
    if model is None or tokenizer is None:
        print("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        special_tokens = {"pad_token": "<PAD>"}
        tokenizer.add_special_tokens(special_tokens)

        config = AutoConfig.from_pretrained(model_name)
        if kv_bits != "unquantized":
            # Point the config at the codebook used for KV-cache quantization at the selected bit width
            quantizer_path = f"codebooks/{model_name.split('/')[-1]}_{kv_bits}bit.xmad"
            setattr(config, "quantizer_path", quantizer_path)

        if dtype == "bf16":
            dtype = torch.bfloat16
        elif dtype == "fp16":
            dtype = torch.float16
        elif dtype == "fp32":
            dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype, device_map="auto")

        # Grow the embedding matrix if the added <PAD> token is not yet covered
        if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
            model.resize_token_embeddings(len(tokenizer))

        tokenizer.padding_side = "left"
        model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer

# Format response
def format_response(dialog, response):
    question = next((turn['content'] for turn in dialog if turn['role'] == 'user'), 'No question found')
    # The decoded text contains the whole chat; keep only what follows the final "assistant" header
    answer = response.split("assistant")[-1].strip()
    return {"question": question, "answer": answer}

# Load questions
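# Combines any custom questions with randomly sampled prompts from the JSON file, capped at 60 dialogs total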
def load_questions(prompts_path, custom_questions):
    with open(prompts_path, "r") as file:
        dialogs = json.load(file)

    selected_dialogs = []
    if custom_questions:
        for question in custom_questions:
            if question.strip():
                custom_dialog = [{"role": "user", "content": question}]
                selected_dialogs.append(custom_dialog)

    # Top up with random prompts; max(0, ...) guards against more than 60 custom questions
    num_questions = max(0, 60 - len(selected_dialogs))
    random.shuffle(dialogs)
    selected_dialogs.extend(dialogs[:num_questions])
    return selected_dialogs[:60]

# Inference
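# infer() is a generator: after each batch it yields the accumulated markdown so far,
# and as its final value it yields a dict of timing metrics (demo() tells the two apart by type)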
def infer(model_name, dialogs, num_new_tokens, temperature, dtype, kv_bits, progress=gr.Progress()):
    print("Starting inference...")
    model, tokenizer = load_model_and_tokenizer(model_name, dtype, kv_bits)
    batch_inputs = [
        tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True)
        for dialog in dialogs
    ]

    responses = []
    start_time = time.time()

    batch_size = 30  # Batch size for processing; adjust to fit available GPU memory
    num_dialogs = len(dialogs)
    total_time = 0
    total_tokens = 0
    ttft = 0
    formatted_responses = ""
    num_batches = (num_dialogs + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_dialogs)
        batch = batch_inputs[start_idx:end_idx]

        encoded_inputs = tokenizer(batch, padding=True, truncation=False, return_tensors="pt")
        input_ids = encoded_inputs["input_ids"].to(model.device)
        attention_mask = encoded_inputs["attention_mask"].to(model.device)

        with torch.no_grad():
            torch.cuda.synchronize()
            batch_start_time = time.perf_counter()

            # Generate responses for the batch
            output_tokens = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_new_tokens,
                do_sample=True,
                temperature=temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

            torch.cuda.synchronize()
            batch_end_time = time.perf_counter()

        batch_time = batch_end_time - batch_start_time
        total_time += batch_time
        # Count only newly generated tokens (generate() returns the prompt followed by the completion)
        total_tokens += output_tokens[:, input_ids.shape[1]:].numel()

        # Approximate TTFT as the average per-sequence generation time of the first batch
        if batch_idx == 0:
            ttft = batch_time / input_ids.size(0)

        decoded_outputs = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)
        for i, response in enumerate(decoded_outputs):
            original_dialog = dialogs[start_idx + i]
            formatted_response = format_response(original_dialog, response)
            responses.append(formatted_response)

        formatted_responses = "\n\n---\n\n".join(
            [f"**Question**: {res['question']}\n\n**Answer**: {res['answer']}" for res in responses]
        )
        yield formatted_responses

        progress((batch_idx + 1) / num_batches, desc="Processing batches")

    elapsed_time = time.time() - start_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0
    print(f"Inference completed in {elapsed_time:.2f} seconds.")

    yield {
        "Time Taken (seconds)": elapsed_time,
        "Tokens per Second": tokens_per_second,
        "Time to First Token (TTFT, seconds)": ttft,
        "Formatted Responses": formatted_responses
    }

# Demo function
def demo(num_new_tokens, temperature, custom_questions_text, kv_bits, progress=gr.Progress()):
    custom_questions = custom_questions_text.split("\n")
    print("Loading questions...")
    dialogs = load_questions("chats_sys_none.json", custom_questions)
    print(f"{len(dialogs)} questions loaded. Starting inference...")

    result_gen = infer("NousResearch/Meta-Llama-3-8B-Instruct", dialogs, num_new_tokens, temperature, "fp16", kv_bits, progress=progress)

    formatted_responses = ""
    for result in result_gen:
        if isinstance(result, str):
            formatted_responses = result
            yield None, None, None, formatted_responses
        else:
            time_taken = result["Time Taken (seconds)"]
            tokens_per_second = result["Tokens per Second"]
            ttft = result["Time to First Token (TTFT, seconds)"]
            formatted_responses = result["Formatted Responses"]
            yield time_taken, tokens_per_second, ttft, formatted_responses

# Load JSON data
with open("chats_sys_none.json", "r") as file:
json_data = json.load(file)
# Load 50 random questions into the input area by default
def load_default_questions():
    random.shuffle(json_data)
    default_questions = [dialog[0]['content'] for dialog in json_data[:50] if 'content' in dialog[0]]
    return "\n".join(default_questions)

# Gradio interface
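# Note: this Interface mirrors the Blocks app defined below; only the Blocks app is launched in __main__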
demo_interface = gr.Interface(
    fn=demo,
    inputs=[
        gr.Slider(label="Number of New Tokens", minimum=128, maximum=1024, step=128, value=512),
        gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.4),
        gr.Textbox(label="Custom Questions", placeholder="Type your custom questions here, one per line...", lines=5),
        gr.Dropdown(label="KV Bits", choices=["1", "2", "4", "unquantized"], value="1")
    ],
    outputs=[
        gr.Number(label="Time Taken (seconds)", interactive=False),
        gr.Number(label="Tokens per Second", interactive=False),
        gr.Number(label="Time to First Token (TTFT, seconds)", interactive=False),
        gr.Markdown(label="Formatted Responses", elem_id="scrollable-output")
    ],
    live=False
)

# Gradio Blocks for additional controls
with gr.Blocks(css="#scrollable-output {height: 400px; overflow-y: auto; padding: 10px; border: 1px solid #ccc;}") as app:
    with gr.Column():
        gr.Markdown("### LLM Inference Demo")
        with gr.Row():
            num_new_tokens = gr.Slider(label="Number of New Tokens", minimum=128, maximum=1024, step=128, value=512)
            temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.4)
            kv_bits = gr.Dropdown(label="KV Bits", choices=["1", "2", "4", "unquantized"], value="1")

        custom_questions_text = gr.Textbox(label="Custom Questions", placeholder="Type your custom questions here, one per line...", lines=5)
        load_questions_btn = gr.Button("Load Default Questions")

        with gr.Row():
            time_taken = gr.Number(label="Time Taken (seconds)", interactive=False)
            tokens_per_second = gr.Number(label="Tokens per Second", interactive=False)
            ttft = gr.Number(label="Time to First Token (TTFT, seconds)", interactive=False)

        # The CSS above targets this component's elem_id so long responses scroll instead of stretching the page
        formatted_responses = gr.Markdown(label="Formatted Responses", elem_id="scrollable-output")
        demo_btn = gr.Button("Run Inference")

    load_questions_btn.click(fn=load_default_questions, inputs=[], outputs=custom_questions_text)
    demo_btn.click(demo, inputs=[num_new_tokens, temperature, custom_questions_text, kv_bits], outputs=[time_taken, tokens_per_second, ttft, formatted_responses])

if __name__ == "__main__":
    print("Checking if the image path is correct...")
    check_image_path("memory_usage.png")  # Check image path on startup

    print("Loading model and tokenizer on startup...")
    load_model_and_tokenizer("NousResearch/Meta-Llama-3-8B-Instruct", "fp16", "1")
    print("Model and tokenizer loaded. Starting Gradio interface...")
    app.launch()