import json
import os
import random
import time

import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

model = None
tokenizer = None


def load_model_and_tokenizer(model_name, dtype, kv_bits):
    global model, tokenizer
    if model is None or tokenizer is None:
        print("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        special_tokens = {"pad_token": "<PAD>"}
        tokenizer.add_special_tokens(special_tokens)

        config = AutoConfig.from_pretrained(model_name)
        if kv_bits != "unquantized":
            # Point the config at the KV-cache quantizer codebook for the requested bit width.
            quantizer_path = f"codebooks/{model_name.split('/')[-1]}_{kv_bits}bit.xmad"
            setattr(config, "quantizer_path", quantizer_path)

        if dtype == "bf16":
            dtype = torch.bfloat16
        elif dtype == "fp16":
            dtype = torch.float16
        elif dtype == "fp32":
            dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=dtype, device_map="auto")

        # Grow the embedding matrix if the added <PAD> token pushed the vocab past the checkpoint size.
        if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
            model.resize_token_embeddings(len(tokenizer))

        tokenizer.padding_side = "left"
        model.config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer


def format_response(dialog, response):
    question = next((turn['content'] for turn in dialog if turn['role'] == 'user'), 'No question found')
    # The chat template labels the model turn with "assistant"; keep only the text after the last occurrence.
    answer = response.split("assistant")[-1].strip()
    return {"question": question, "answer": answer}


def load_questions(prompts_path, custom_questions):
    with open(prompts_path, "r") as file:
        dialogs = json.load(file)

    selected_dialogs = []

    if custom_questions:
        for question in custom_questions:
            if question.strip():
                custom_dialog = [{"role": "user", "content": question}]
                selected_dialogs.append(custom_dialog)

    # Top up with randomly chosen default questions so the batch holds at most 60 dialogs.
    num_questions = max(0, 60 - len(selected_dialogs))
    random.shuffle(dialogs)
    selected_dialogs.extend(dialogs[:num_questions])

    return selected_dialogs[:60]


def infer(model_name, dialogs, num_new_tokens, temperature, dtype, kv_bits, progress=gr.Progress()):
    print("Starting inference...")
    model, tokenizer = load_model_and_tokenizer(model_name, dtype, kv_bits)
    batch_inputs = [
        tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True)
        for dialog in dialogs
    ]

    responses = []
    start_time = time.time()

    batch_size = 60
    num_dialogs = len(dialogs)
    total_time = 0
    total_tokens = 0
    num_batches = (num_dialogs + batch_size - 1) // batch_size

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = min(start_idx + batch_size, num_dialogs)
        batch = batch_inputs[start_idx:end_idx]

        encoded_inputs = tokenizer(batch, padding=True, truncation=False, return_tensors="pt")
        input_ids = encoded_inputs["input_ids"].to(model.device)
        attention_mask = encoded_inputs["attention_mask"].to(model.device)

        with torch.no_grad():
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            batch_start_time = time.perf_counter()

            output_tokens = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=num_new_tokens,
                do_sample=temperature > 0,  # fall back to greedy decoding when the slider is at 0.0
                temperature=temperature if temperature > 0 else 1.0,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

            if torch.cuda.is_available():
                torch.cuda.synchronize()
            batch_end_time = time.perf_counter()

        batch_time = batch_end_time - batch_start_time
        total_time += batch_time
        # Count only newly generated tokens, not the prompt tokens echoed back by generate().
        total_tokens += output_tokens.numel() - input_ids.numel()

        if batch_idx == 0:
            # Rough proxy for time-to-first-token: first-batch latency per sequence.
            ttft = batch_time / input_ids.size(0)

        decoded_outputs = tokenizer.batch_decode(output_tokens, skip_special_tokens=True)

        for i, response in enumerate(decoded_outputs):
            original_dialog = dialogs[start_idx + i]
            formatted_response = format_response(original_dialog, response)
            responses.append(formatted_response)

        formatted_responses = "\n\n---\n\n".join(
            [f"**Question**: {res['question']}\n\n**Answer**: {res['answer']}" for res in responses]
        )
        yield formatted_responses
        progress((batch_idx + 1) / num_batches, desc="Processing batches")

    elapsed_time = time.time() - start_time
    tokens_per_second = total_tokens / total_time if total_time > 0 else 0
    print(f"Inference completed in {elapsed_time:.2f} seconds.")

    yield {
        "Time Taken (seconds)": elapsed_time,
        "Tokens per Second": tokens_per_second,
        "Time to First Token (TTFT, seconds)": ttft,
        "Formatted Responses": formatted_responses
    }


def demo(num_new_tokens, temperature, custom_questions_text, kv_bits=1, progress=gr.Progress()):
    custom_questions = custom_questions_text.split("\n")
    print("Loading questions...")
    dialogs = load_questions("chats_sys_none.json", custom_questions)
    print(f"{len(dialogs)} questions loaded. Starting inference...")

    result_gen = infer("NousResearch/Meta-Llama-3-8B-Instruct", dialogs, num_new_tokens, temperature, "fp16", kv_bits, progress=progress)

    formatted_responses = ""
    for result in result_gen:
        if isinstance(result, str):
            # Intermediate yields are partial transcripts; metrics are not available yet.
            formatted_responses = result
            yield None, None, None, formatted_responses
        else:
            # The final yield is a dict carrying the timing metrics and the full transcript.
            time_taken = result["Time Taken (seconds)"]
            tokens_per_second = result["Tokens per Second"]
            ttft = result["Time to First Token (TTFT, seconds)"]
            formatted_responses = result["Formatted Responses"]
            yield time_taken, tokens_per_second, ttft, formatted_responses
with open("chats_sys_none.json", "r") as file: |
|
json_data = json.load(file) |
|
|
|
|
|
def load_default_questions(): |
|
random.shuffle(json_data) |
|
default_questions = [dialog[0]['content'] for dialog in json_data[:50] if 'content' in dialog[0]] |
|
return "\n".join(default_questions) |
|
|
|
|
|
def load_questions_action(): |
|
return load_default_questions() |
|
|
|
|
|
css = """ |
|
body, html { |
|
height: 100vh; |
|
margin: 0; |
|
} |
|
|
|
.gradio-container { |
|
height: 100vh; |
|
} |
|
|
|
#main-row { |
|
height: 100%; |
|
} |
|
|
|
#control-panel, #formatted-responses-container { |
|
height: 100%; |
|
box-sizing: border-box; |
|
} |
|
|
|
#custom-questions-text, #formatted-responses { |
|
flex-grow: 1; |
|
overflow-y: auto; |
|
border: 1px solid #ccc; |
|
} |
|
""" |
|
|
|

with gr.Blocks(css=css) as app:
    with gr.Row(elem_id="main-row", equal_height=True):
        with gr.Column(elem_id="control-panel", scale=1):
            num_new_tokens = gr.Slider(label="Number of New Tokens", minimum=128, maximum=1024, step=128, value=512)
            temperature = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.4)
            custom_questions_text = gr.Textbox(label="Custom Questions", placeholder="Type your custom questions here, one per line...", lines=22, elem_id="custom-questions-text")
            with gr.Row(elem_id="metrics-panel"):
                time_taken = gr.Number(label="Time Taken (seconds)", interactive=False, elem_classes=["metric"])
                tokens_per_second = gr.Number(label="Tokens per Second", interactive=False, elem_classes=["metric"])
                ttft = gr.Number(label="Time to First Token (TTFT, seconds)", interactive=False, elem_classes=["metric"])
            with gr.Row(elem_id="buttons-container"):
                load_questions_btn = gr.Button("Load Default Questions")
                demo_btn = gr.Button("Run Inference", elem_id="run-inference-btn")

        formatted_responses = gr.Textbox(label="Formatted Responses", elem_id="formatted-responses", value="No responses yet. Run the inference to see results.", lines=35, autoscroll=False, show_copy_button=True)

    load_questions_btn.click(fn=load_questions_action, inputs=[], outputs=custom_questions_text)
    demo_btn.click(demo, inputs=[num_new_tokens, temperature, custom_questions_text], outputs=[time_taken, tokens_per_second, ttft, formatted_responses])
if __name__ == "__main__": |
|
print("Loading model and tokenizer on startup...") |
|
|
|
print("Model and tokenizer loaded. Starting Gradio interface...") |
|
app.launch() |
|
|