# HROM-V1 / app.py
import gradio as gr
import torch
from tokenizers import Tokenizer
import os
from HROM_Trainer import HROM, CONFIG, SafetyManager
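
# Gradio demo for the HROM-V1 chat model. It assumes HROM_Trainer exposes the
# HROM model class, the CONFIG dict (including CONFIG["max_seq_len"]), and the
# SafetyManager content filter, and that a trained checkpoint and tokenizer
# are bundled with this Space.
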
def load_latest_checkpoint(model, device):
    """Load model weights from the most recently modified checkpoint in checkpoints/."""
    checkpoint_dir = "checkpoints"
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    if not checkpoints:
        raise FileNotFoundError("No checkpoints found.")
    # Sort by modification time, newest first.
    checkpoints = sorted(checkpoints, key=lambda x: os.path.getmtime(os.path.join(checkpoint_dir, x)), reverse=True)
    latest_checkpoint = os.path.join(checkpoint_dir, checkpoints[0])
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    return model


def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
    """Greedily decode up to max_length tokens, stopping at </s> or when the safety filter trips."""
    device = next(model.parameters()).device
    generated_ids = input_ids.copy()
    for _ in range(max_length):
        input_tensor = torch.tensor([generated_ids], device=device)
        with torch.no_grad():
            logits = model(input_tensor)
        # Greedy decoding: take the highest-scoring token at the last position.
        next_token = logits.argmax(-1)[:, -1].item()
        if next_token == tokenizer.token_to_id("</s>"):
            break
        # Stop early if the safety filter rejects the text generated so far.
        current_text = tokenizer.decode(generated_ids + [next_token])
        if not safety_manager.content_filter(current_text):
            break
        generated_ids.append(next_token)
    # Return only the newly generated tokens, excluding the prompt.
    return generated_ids[len(input_ids):]

# Initialize components once
tokenizer = Tokenizer.from_file("tokenizer/hrom_tokenizer.json")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HROM().to(device)
model = load_latest_checkpoint(model, device)
model.eval()
safety = SafetyManager(model, tokenizer)
max_response_length = 200
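
# Conversation state is kept as a flat list of token ids, alternating
# "<user> ... </s>" and "<assistant> ... </s>" turns; a leading <s> token is
# re-added on every request before the sequence is fed to the model.
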
def process_message(user_input, chat_history, token_history):
    # Encode the user turn and append it to the running token history.
    user_turn = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn).ids
    token_history.extend(user_tokens)

    # Prepend the start-of-sequence token.
    input_sequence = [tokenizer.token_to_id("<s>")] + token_history

    # Left-truncate so the input plus the response fits in the model's context window.
    max_input_len = CONFIG["max_seq_len"] - max_response_length
    if len(input_sequence) > max_input_len:
        input_sequence = input_sequence[-max_input_len:]
        token_history = input_sequence[1:]

    # Generate the assistant response.
    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)

    # Decode the response, stripping the <assistant> marker and trailing </s> when present.
    assistant_text = "I couldn't generate a proper response."
    if response_ids:
        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
            try:
                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
                assistant_text = tokenizer.decode(response_ids[1:end_idx])
                token_history.extend(response_ids[:end_idx + 1])
            except ValueError:
                assistant_text = tokenizer.decode(response_ids[1:])
                token_history.extend(response_ids)
        else:
            assistant_text = tokenizer.decode(response_ids)
            token_history.extend(response_ids)

    chat_history.append((user_input, assistant_text))
    return chat_history, token_history


def clear_history():
    return [], []
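
# Gradio UI: a chatbot pane and message box wired to process_message, plus a
# button that clears both the visible history and the token-level state.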

with gr.Blocks() as demo:
    gr.Markdown("# HROM Chatbot")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message")
    token_state = gr.State([])

    msg.submit(
        process_message,
        [msg, chatbot, token_state],
        [chatbot, token_state],
        queue=False
    ).then(
        lambda: "", None, msg
    )

    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state],
        queue=False
    )

demo.launch()
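
# To run locally, `python app.py` should start the Gradio server, assuming the
# checkpoints/ and tokenizer/ directories referenced above are present.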