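"""Gradio UI for inspecting CLIP tokenization.

Encodes a prompt with openai/clip-vit-large-patch14 (or accepts raw,
comma-separated token ids) and renders each decoded token as a
color-coded HTML span, alongside the token count and raw id list.
"""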
import html
import logging
from pathlib import Path

import gradio as gr
from gradio.themes.utils import colors
from transformers import CLIPTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
gr_logger = logging.getLogger("gradio")
gr_logger.setLevel(logging.INFO)
class ClipUtil:
    def __init__(self):
        logger.info("Loading ClipUtil")
        # violet/indigo theme with Fira Sans for text and Fira Code for monospace
        self.theme = gr.themes.Base(
            primary_hue=colors.violet,
            secondary_hue=colors.indigo,
            neutral_hue=colors.slate,
            font=[gr.themes.GoogleFont("Fira Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
            font_mono=[gr.themes.GoogleFont("Fira Code"), "ui-monospace", "Consolas", "monospace"],
        ).set(
            slider_color_dark="*primary_500",
        )
        # optional stylesheet: a .css file with the same stem as this script
        try:
            self.css = Path(__file__).with_suffix(".css").read_text()
        except Exception:
            logger.exception("Failed to load CSS file")
            self.css = ""
        self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        # reverse vocab: token id -> token string
        self.vocab = {v: k for k, v in self.tokenizer.get_vocab().items()}
        self.blocks = gr.Blocks(
            title="ClipTokenizerUtil", analytics_enabled=False, theme=self.theme, css=self.css
        )
    def tokenize(self, text: str, input_ids: bool = False):
        """Tokenize text (or parse comma-separated token ids) and build HTML output."""
        if input_ids:
            tokens = [int(x.strip()) for x in text.split(",")]
        else:
            tokens = self.tokenizer(text, return_tensors="np").input_ids.squeeze().tolist()

        code = ""
        ids = []
        current_ids = []
        class_index = 0
        byte_decoder = self.tokenizer.byte_decoder

        def dump(last=False):
            """Try to decode the accumulated ids into a word and emit its HTML span."""
            nonlocal code, ids, current_ids

            words = [self.vocab.get(x, "") for x in current_ids]

            def wordscode(ids, word):
                nonlocal class_index
                word_title = html.escape(", ".join([str(x) for x in ids]))
                res = f"""
<span class='tokenizer-token tokenizer-token-{class_index % 4}' title='{word_title}'>
{html.escape(word)}
</span>
"""
                class_index += 1
                return res

            try:
                # CLIP's byte-level BPE maps raw bytes to printable characters;
                # reverse that mapping and decode as UTF-8. A multi-byte character
                # can span several tokens, so this fails until all of them arrive.
                word = bytearray([byte_decoder[x] for x in "".join(words)]).decode("utf-8")
            except UnicodeDecodeError:
                if last:
                    word = "❌" * len(current_ids)
                elif len(current_ids) > 4:
                    # too many undecodable ids pending: flush the first one as
                    # invalid, then retry the remainder one token at a time
                    token_id = current_ids[0]
                    ids += [token_id]
                    local_ids = current_ids[1:]
                    code += wordscode([token_id], "❌")

                    current_ids = []
                    for token_id in local_ids:
                        current_ids.append(token_id)
                        dump()

                    return
                else:
                    # keep accumulating; later tokens may complete the character
                    return

            # word = word.replace("</w>", " ")

            code += wordscode(current_ids, word)
            ids += current_ids
            current_ids = []

        for token in tokens:
            token = int(token)
            current_ids.append(token)
            dump()

        dump(last=True)

        ids_html = f"""
<p>
Token count: {len(ids)}
<br>
{", ".join([str(x) for x in ids])}
</p>"""

        return code, ids_html

    def tokenize_ids(self, text: str):
        return self.tokenize(text, input_ids=True)
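    # tokenize_ids() expects comma-separated ids, including CLIP's BOS/EOS
    # framing (49406/49407), e.g.:
    #   "49406, 320, 1125, 539, 320, 2368, 49407"
    # (the middle ids are illustrative, not tied to any particular prompt)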
    def create_components(self):
        with self.blocks:
            # title bar
            with gr.Row().style(equal_height=True):
                with gr.Column(scale=12, elem_id="header_col"):
                    self.header_title = gr.Markdown(
                        "## CLIP Tokenizer Util",
                        elem_id="header_title",
                    )
                with gr.Column(scale=1, min_width=90, elem_id="button_col"):
                    with gr.Row(elem_id="button_row"):
                        self.reload_btn = gr.Button(
                            label="refresh",
                            elem_id="refresh_btn",
                            type="button",
                            value="🔄",
                            variant="primary",
                        )

            # input tabs: free text, or raw comma-separated token ids
            with gr.Tabs() as in_tabs:
                with gr.Tab(label="Text Input", id="text_input_tab"):
                    with gr.Row().style(equal_height=True):
                        with gr.Column(scale=12, elem_id="text_input_col"):
                            self.text_input = gr.Textbox(
                                label="Text Input",
                                elem_id="tokenizer_prompt",
                                show_label=False,
                                lines=8,
                                placeholder="Prompt for tokenization",
                            )
                            self.text_button = gr.Button(
                                label="Tokenize",
                                elem_id="go_button",
                                value="Go",
                                variant="primary",
                            )
                with gr.Tab(label="Token Input", id="token_input_tab"):
                    with gr.Row().style(equal_height=True):
                        with gr.Column(scale=12, elem_id="text_input_col"):
                            self.token_input = gr.Textbox(
                                lines=5,
                                label="Token Input",
                                elem_id="text_input",
                                placeholder="Comma-separated token IDs",
                            )
                            self.token_button = gr.Button(
                                label="Tokenize",
                                elem_id="go_button",
                                type="button",
                                value="Go",
                                variant="primary",
                            )

            # output tabs: rendered token spans, and the raw id list
            with gr.Tabs():
                with gr.Tab("Text"):
                    tokenized_text = gr.HTML(elem_id="tokenized_text")
                with gr.Tab("Tokens"):
                    tokenized_ids = gr.HTML(elem_id="tokenized_ids")

            self.text_button.click(
                fn=self.tokenize,
                inputs=[self.text_input],
                outputs=[tokenized_text, tokenized_ids],
            )
            self.token_button.click(
                fn=self.tokenize_ids,
                inputs=[self.token_input],
                outputs=[tokenized_text, tokenized_ids],
            )
    def launch(self, **kwargs):
        # listen on all interfaces; extra kwargs pass through to Blocks.launch()
        return self.blocks.launch(
            server_name="0.0.0.0",
            show_error=True,
            enable_queue=True,
            **kwargs,
        )
if __name__ == "__main__":
    clip_util = ClipUtil()
    clip_util.create_components()
    clip_util.launch()
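# A minimal programmatic usage sketch (assumes the tokenizer weights download
# successfully; the prompt and port below are illustrative):
#
#   util = ClipUtil()
#   spans_html, ids_html = util.tokenize("a photo of a cat")
#   util.create_components()
#   util.launch(server_port=7860)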