from typing import Iterator, List, Tuple

import streamlit as st
import torch
from optimum.bettertransformer import BetterTransformer
from torch import nn, qint8
from torch.quantization import quantize_dynamic
from transformers import T5ForConditionalGeneration, T5Tokenizer


@st.cache_resource(show_spinner=False)
def get_resources(quantize: bool = True, no_cuda: bool = False) -> Tuple[T5ForConditionalGeneration, T5Tokenizer]:
    """Load a T5 model and its (slow) tokenizer"""
    tokenizer = T5Tokenizer.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023", use_fast=False)
    model = T5ForConditionalGeneration.from_pretrained("BramVanroy/ul2-base-dutch-simplification-mai-2023")
    # Use optimum's BetterTransformer fastpath for faster inference
    model = BetterTransformer.transform(model, keep_original_model=False)
    model.resize_token_embeddings(len(tokenizer))

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization not supported on CUDA
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    model.eval()

    return model, tokenizer
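

# Illustrative call (hypothetical, not from the original app): on a CPU-only machine this
# returns a dynamically int8-quantized model; with CUDA available and no_cuda=False,
# the model is moved to the GPU and left unquantized.
#   model, tokenizer = get_resources(quantize=True, no_cuda=True)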


def batchify(iterable, batch_size=16):
    """Turn an iterable into a batch generator
    :param iterable: iterable to batchify
    :param batch_size: batch size
    """
    num_items = len(iterable)
    for idx in range(0, num_items, batch_size):
        yield iterable[idx : min(idx + batch_size, num_items)]
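

# Illustrative behaviour of batchify (hypothetical values):
#   list(batchify(["a", "b", "c", "d", "e"], batch_size=2)) -> [["a", "b"], ["c", "d"], ["e"]]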


def simplify(
    texts: List[str], model: T5ForConditionalGeneration, tokenizer: T5Tokenizer, batch_size: int = 16
) -> Iterator[Tuple[List[str], List[str]]]:
    """Simplify a given set of texts with a given model and tokenizer. Yields (input texts, simplified texts)
    tuples per batch of 'batch_size'
    :param texts: texts to simplify
    :param model: model to use for simplification
    :param tokenizer: tokenizer to use for simplification
    :param batch_size: batch size to yield results in
    """
    for batch_texts in batchify(texts, batch_size=batch_size):
        # Prepend the [NLG] task prefix that the model was fine-tuned with
        nlg_batch_texts = ["[NLG] " + text for text in batch_texts]
        encoded = tokenizer(nlg_batch_texts, return_tensors="pt", padding=True, truncation=True)
        # Move the inputs to the model's device
        encoded = {k: v.to(model.device) for k, v in encoded.items()}
        gen_kwargs = {
            "max_new_tokens": 128,
            "num_beams": 3,
        }
        with torch.no_grad():
            generated = model.generate(**encoded, **gen_kwargs).cpu()

        yield batch_texts, tokenizer.batch_decode(generated, skip_special_tokens=True)
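

# Usage sketch (illustrative only; not part of the original Space and never called by it).
# It shows how `get_resources` and `simplify` fit together; the function name and example
# sentence below are made up for demonstration.
def _demo_simplify() -> None:
    model, tokenizer = get_resources(quantize=True, no_cuda=True)
    example_texts = ["Dit is een voorbeeldzin die vereenvoudigd moet worden."]
    for source_batch, simplified_batch in simplify(example_texts, model, tokenizer, batch_size=4):
        for source, simplified in zip(source_batch, simplified_batch):
            print(f"{source} -> {simplified}")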