# Page residue from the hosted file view, kept as a comment so the module parses:
#   author: trminhnam20082002
#   commit 171c344 - "feat: add model device before generating"
#   size: 6.26 kB
# -*- coding: utf-8 -*-
import os
import re
import torch
from transformers import (
AutoTokenizer,
AutoModel,
T5ForConditionalGeneration,
MBartForConditionalGeneration,
AutoModelForSeq2SeqLM,
)
from tqdm.auto import tqdm
import streamlit as st
from typing import Dict, List
def get_model(args):
    """Build a seq2seq model from ``args`` and optionally restore saved weights.

    Args:
        args: Namespace-like object. The attributes read are:
            ``model_name`` (passed to ``AutoModelForSeq2SeqLM.from_pretrained``),
            ``device`` (where the model is moved; also used as ``map_location``
            for the checkpoint) and ``load_model_path`` (checkpoint path; a
            falsy value skips loading).

    Returns:
        The model, already moved to ``args.device``.
    """
    print(f"Using model {args.model_name}")
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    seq2seq.to(args.device)

    checkpoint_path = args.load_model_path
    if not checkpoint_path:
        return seq2seq

    print(f"Loading model from {checkpoint_path}")
    # Load the tensors directly onto the device the model was just moved to.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(args.device))
    seq2seq.load_state_dict(state_dict)
    return seq2seq
@st.cache(allow_output_mutation=True)
def load_model(model_name, device):
    """Load a pretrained seq2seq model and overlay the local best-loss checkpoint.

    Wrapped in Streamlit's ``st.cache(allow_output_mutation=True)``. Pretrained
    files are downloaded into ``./cache`` (created if missing).

    Args:
        model_name: Model id or path given to ``from_pretrained``. Its last
            ``/``-separated segment names the checkpoint file.
        device: Device the model is moved to. It is also used as
            ``map_location`` when loading the checkpoint.

    Returns:
        The model with weights loaded from ``models/<short name>-best_loss.bin``.
    """
    print(f"Using model {model_name}")
    cache_dir = "cache"
    os.makedirs(cache_dir, exist_ok=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=cache_dir)
    seq2seq.to(device)

    # The checkpoint file is keyed by the model name without its org prefix.
    short_name = model_name.split("/")[-1]
    checkpoint_path = os.path.join("models", f"{short_name}-best_loss.bin")
    print(f"Loading model from {checkpoint_path}")
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    seq2seq.load_state_dict(state_dict)
    return seq2seq
@st.cache(allow_output_mutation=True)
def load_tokenizer(model_name):
    """Load the tokenizer for ``model_name``.

    Wrapped in Streamlit's ``st.cache(allow_output_mutation=True)``. When the
    lower-cased name contains ``"mbart"``, both ``src_lang`` and ``tgt_lang``
    are set to ``"vi_VN"``. Any other model gets the default tokenizer
    configuration.
    """
    print(f"Loading tokenizer {model_name}")
    is_mbart = "mbart" in model_name.lower()
    if not is_mbart:
        return AutoTokenizer.from_pretrained(model_name)
    return AutoTokenizer.from_pretrained(
        model_name, src_lang="vi_VN", tgt_lang="vi_VN"
    )
def prepare_batch_model_inputs(batch, tokenizer, max_len, is_train=False, device="cpu"):
    """Tokenize a batch of source texts, plus target texts when training.

    Args:
        batch: Mapping with a ``"src"`` entry. ``"tgt"`` is read only when
            ``is_train`` is true.
        tokenizer: Called as ``tokenizer(src, text_target=..., padding="longest",
            max_length=max_len, truncation=True, return_tensors="pt")``.
        max_len: Maximum sequence length used for truncation.
        is_train: When true, ``batch["tgt"]`` is passed as ``text_target``.
            Otherwise ``text_target`` is ``None``.
        device: Every value in the tokenizer output is moved here.

    Returns:
        The tokenizer's output object, with each value replaced in place by its
        ``.to(device)`` copy.
    """
    sources = batch["src"]
    targets = batch["tgt"] if is_train else None
    encoded = tokenizer(
        sources,
        text_target=targets,
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    # Snapshot the keys so reassigning values never interferes with iteration.
    for key in list(encoded.keys()):
        encoded[key] = encoded[key].to(device)
    return encoded
def prepare_single_model_inputs(src, tokenizer, max_len, device="cpu"):
    """Tokenize ``src`` for inference (no targets) and move the result to ``device``.

    Args:
        src: Passed unchanged as the tokenizer's first argument.
        tokenizer: Called with ``padding="longest"``, ``max_length=max_len``,
            ``truncation=True`` and ``return_tensors="pt"``.
        max_len: Maximum sequence length used for truncation.
        device: Every value in the tokenizer output is moved here.

    Returns:
        The tokenizer's output object, with each value replaced in place by its
        ``.to(device)`` copy.
    """
    encoded = tokenizer(
        src, padding="longest", max_length=max_len, truncation=True, return_tensors="pt"
    )
    for key in list(encoded.keys()):
        encoded[key] = encoded[key].to(device)
    return encoded
def make_input_sentence_from_strings(data):
    """Render one KPI record as the flat ``"key": value`` text fed to the model.

    Every entry is read with ``data[...]``, so a missing key raises
    ``KeyError``. Required entries:

    * ``"CHỈ TIÊU"`` (objective name), ``"ĐƠN VỊ"`` (unit), ``"ĐIỀU KIỆN"``
      (condition), ``"KPI mục tiêu tháng"`` (monthly KPI target) and
      ``"Đánh giá"`` (evaluation).
    * ``"Thời gian báo cáo"`` (report time), ``"Previous month"`` and
      ``"Previous year"``: indexable periods rendered as ``"<p[1]>.<p[0]>"``.
      The example below implies ``(year, month)``, e.g. ``(2022, 9)`` becomes
      ``"9.2022"``.
    * ``"T<p[1]>.<p[0]> thực tế"`` for the report period (actual value).
    * ``"Previous month value key"``, ``"Previous year value key"``,
      ``"Previous month compare key"`` and ``"Previous year compare key"``:
      names of further ``data`` entries holding the comparison values.

    Values are inserted with their default string formatting, without quoting
    or escaping. The result has no surrounding braces, unlike this example
    record:

        {"CHỈ TIÊU": "...", "ĐƠN VỊ": "%", "ĐIỀU KIỆN": ">=", "KPI mục tiêu tháng": 95.0,
         "Tháng 9.2022": 97.5, "Đánh giá": "Đạt", "T8.2022": 96.6,
         "So sánh T8.2022 Tăng giảm": 1.0, "T9.2021": 96.8,
         "So sánh T9.2021 Tăng giảm": 0.8}
    """

    def period(p):
        # Element 1 is printed before element 0 (e.g. month.year).
        return f"{p[1]}.{p[0]}"

    month_value_key = data["Previous month value key"]
    year_value_key = data["Previous year value key"]
    objective_name = data["CHỈ TIÊU"]
    unit = data["ĐƠN VỊ"]
    condition = data["ĐIỀU KIỆN"]
    kpi_target = data["KPI mục tiêu tháng"]
    report_time = data["Thời gian báo cáo"]
    actual_value = data[f"T{period(report_time)} thực tế"]
    evaluation = data["Đánh giá"]
    month_value = data[month_value_key]
    year_value = data[year_value_key]
    month_compare_key = data["Previous month compare key"]
    year_compare_key = data["Previous year compare key"]
    month_compare = data[month_compare_key]
    year_compare = data[year_compare_key]
    prev_month = data["Previous month"]
    prev_year = data["Previous year"]

    fields = [
        f'"CHỈ TIÊU": "{objective_name}"',
        f'"ĐƠN VỊ": "{unit}"',
        f'"ĐIỀU KIỆN": "{condition}"',
        f'"KPI mục tiêu tháng": {kpi_target}',
        f'"Tháng {period(report_time)}": {actual_value}',
        f'"Đánh giá": "{evaluation}"',
        f'"T{period(prev_month)}": {month_value}',
        f'"So sánh T{period(prev_month)} Tăng giảm": {month_compare}',
        f'"T{period(prev_year)}": {year_value}',
        f'"So sánh T{period(prev_year)} Tăng giảm": {year_compare}',
    ]
    return ", ".join(fields)
@torch.no_grad()
def generate_description(
    input_string, model, tokenizer, device, max_len, model_name, beam_size
):
    """Generate text for ``input_string`` with beam search.

    Runs under ``torch.no_grad()`` and puts ``model`` in eval mode on
    ``device`` before generating.

    Args:
        input_string: Text passed to ``prepare_single_model_inputs``.
        model: Model exposing ``eval``, ``to`` and ``generate``.
        tokenizer: Used both to encode the input and to decode the output.
        device: Device for the model and the encoded inputs.
        max_len: Truncation length for the input and ``max_length`` for generation.
        model_name: When it contains ``"mbart"`` (case-insensitive),
            ``forced_bos_token_id`` is set to the ``"vi_VN"`` language code.
        beam_size: Passed to ``generate`` as ``num_beams``.

    Returns:
        The list produced by ``tokenizer.batch_decode`` with special tokens
        skipped and tokenization spaces cleaned up.
    """
    model.eval()
    model = model.to(device)
    encoded = prepare_single_model_inputs(
        input_string, tokenizer, max_len=max_len, device=device
    )

    extra_kwargs = {}
    if "mbart" in model_name.lower():
        # Make generation start from the Vietnamese language-code token.
        extra_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]

    generated_ids = model.generate(
        **encoded,
        **extra_kwargs,
        max_length=max_len,
        num_beams=beam_size,
    )
    return tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )