import os

import gradio as gr
import torch
from tokenizers import Tokenizer

from HROM_Trainer import HROM, CONFIG, SafetyManager
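# Gradio chat front-end for the HROM model: restores the most recent training
# checkpoint, generates replies greedily token by token, and screens them with
# SafetyManager before they are shown in the chat window.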
def load_latest_checkpoint(model, device):
    """Load the most recently modified checkpoint from the checkpoints directory."""
    checkpoint_dir = "checkpoints"
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    if not checkpoints:
        raise FileNotFoundError("No checkpoints found.")
    # Sort by modification time, newest first.
    checkpoints = sorted(checkpoints, key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[0])
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    return model
def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
    """Greedily generate up to max_length tokens, stopping at </s> or when the safety filter trips."""
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    eos_id = tokenizer.token_to_id("</s>")
    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)
        # Greedy decoding: take the most likely token at the last position.
        next_token = logits.argmax(-1)[:, -1].item()
        if next_token == eos_id:
            # Keep the end-of-sequence token so the caller can locate the end of the turn.
            generated_ids.append(next_token)
            break
        # Stop early if appending this token would produce disallowed content.
        current_text = tokenizer.decode(generated_ids + [next_token])
        if not safety_manager.content_filter(current_text):
            break
        generated_ids.append(next_token)
    # Return only the newly generated tokens.
    return generated_ids[len(input_ids):]
# Initialize components once
tokenizer = Tokenizer.from_file("tokenizer/hrom_tokenizer.json")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HROM().to(device)
model = load_latest_checkpoint(model, device)
model.eval()
safety = SafetyManager(model, tokenizer)
max_response_length = 200
def process_message(user_input, chat_history, token_history):
    """Append the user turn, generate an assistant reply, and update both histories."""
    # Encode the user turn in the expected chat format: <user> ... </s>
    user_turn = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn).ids
    token_history.extend(user_tokens)

    # Prepend the start-of-sequence token to the running token history.
    input_sequence = [tokenizer.token_to_id("<s>")] + token_history

    # Truncate the context so there is room for the response within max_seq_len.
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    if len(input_sequence) > max_input_len:
        input_sequence = input_sequence[-max_input_len:]
        token_history = input_sequence[1:]

    # Generate the assistant response.
    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)

    # Decode the response, stripping the role and end-of-sequence markers.
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
            try:
                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
                assistant_text = tokenizer.decode(response_ids[1:end_idx])
                token_history.extend(response_ids[:end_idx + 1])
            except ValueError:
                # No </s> was generated (hit max length or the safety filter).
                assistant_text = tokenizer.decode(response_ids[1:])
                token_history.extend(response_ids)
        else:
            assistant_text = tokenizer.decode(response_ids)
            token_history.extend(response_ids)

    chat_history.append((user_input, assistant_text))
    return chat_history, token_history
def clear_history():
    return [], []
with gr.Blocks() as demo:
    gr.Markdown("# HROM Chatbot")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    # Persist the token-level conversation history across turns.
    token_state = gr.State([])

    msg.submit(
        process_message,
        [msg, chatbot, token_state],
        [chatbot, token_state],
        queue=False
    ).then(
        lambda: "", None, msg  # Clear the textbox after sending.
    )

    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state],
        queue=False
    )

demo.launch()
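# Assumed project layout (inferred from the imports and paths above, not stated
# elsewhere in the source): HROM_Trainer.py provides HROM, CONFIG and SafetyManager,
# and checkpoints/*.pt plus tokenizer/hrom_tokenizer.json must exist next to this script.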