Spaces:
Running
on
L4
Running
on
L4
import os | |
import re | |
import time | |
import json | |
from itertools import cycle | |
import torch | |
import gradio as gr | |
from urllib.parse import unquote | |
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList | |
from data import extract_leaves, split_document, handle_broken_output, clean_json_text, sync_empty_fields | |
from examples import examples as input_examples | |
from nuextract_logging import log_event | |
# Size limits used throughout this file (all counted in tokenizer tokens).
# MAX_INPUT_SIZE: truncation length for a single prompt in predict_chunk and
# the cut-off above which gradio_interface_function rejects the input text.
MAX_INPUT_SIZE = 10_000
# MAX_NEW_TOKENS: generation budget handed to model.generate for each chunk.
MAX_NEW_TOKENS = 4_000
# MAX_WINDOW_SIZE: window_size passed to sliding_window_prediction by the UI.
MAX_WINDOW_SIZE = 4_000
# HTML (despite the variable name) rendered as the description of the Gradio
# interface built at the bottom of this file.
# NOTE(review): the emoji after "contact us" was mis-encoded in the source
# (only a stray "π" survived); restored as 🙂 -- confirm against the original.
markdown_description = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>NuExtract-v1.5 Playground</title>
</head>
<body>
<img src="https://cdn.prod.website-files.com/638364a4e52e440048a9529c/64188f405afcf42d0b85b926_logo_numind_final.png" alt="NuMind Logo" style="vertical-align: middle;width: 200px; height: 50px;">
<p>NuMind is a startup developing custom information extraction solutions. NuExtract is a zero-shot model.</p>
<p>If you want the best performance on your problem, please contact us 🙂.</p>
<ul>
<li><strong>Website</strong>: <a href="https://www.numind.ai/">https://www.numind.ai/</a></li>
</ul>
<br>
<h1>NuExtract-v1.5</h1>
<p>NuExtract-v1.5 is a fine-tuning of Phi-3.5-mini-instruct, trained on a private high-quality dataset for structured information extraction.
It supports long documents and several languages (English, French, Spanish, German, Portuguese, and Italian).
To use the model, provide an input text and a JSON template describing the information you need to extract.</p>
<ul>
<li><strong>Model</strong>: <a href="https://huggingface.co/numind/NuExtract-v1.5">numind/NuExtract-v1.5</a></li>
</ul>
<i>NOTE: in this space we restrict the model inputs to a maximum length of 10k tokens, with anything over 4k being processed in a sliding window. For full model performance, self-host the model or contact us.</i>
</body>
</html>
"""
def highlight_words(input_text, json_output):
    """Return ``input_text`` with extracted values wrapped in coloured spans.

    ``extract_leaves(json_output)`` is iterated as ``(path, value)`` pairs.
    Every distinct path is assigned the next colour of a fixed, cycling
    palette, and each case-insensitive occurrence of ``value`` in the text is
    replaced by ``<span style='background-color: ...'>`` containing the
    URL-unquoted value.

    An occurrence only counts when it is delimited on the left by a space,
    newline, tab or the start of the text, and on the right by a space,
    newline, tab, one of ``. , ? : ;`` or the end of the text.

    Args:
        input_text: The text to annotate.
        json_output: Parsed prediction whose leaves are highlighted.

    Returns:
        The annotated text as an HTML string.
    """
    palette = cycle(["#90ee90", "#add8e6", "#ffb6c1", "#ffff99", "#ffa07a", "#20b2aa", "#87cefa", "#b0e0e6", "#dda0dd", "#ffdead"])
    color_map = {}
    highlighted_text = input_text

    for path, value in extract_leaves(json_output):
        # Assign the colour before any filtering so that the palette position
        # of each field does not depend on which fields happen to be empty.
        path_key = tuple(path)
        if path_key not in color_map:
            color_map[path_key] = next(palette)
        color = color_map[path_key]

        # An empty value would compile to a zero-width pattern that inserts
        # empty spans all over the text; a non-string cannot be escaped.
        if not isinstance(value, str) or not value:
            continue

        # Negative look-arounds ("not next to a non-delimiter") also accept
        # the start and end of the text, unlike a positive look-around.
        pattern = rf"(?<![^ \n\t]){re.escape(value)}(?![^ \n\t.,?:;])"
        replacement = f"<span style='background-color: {color};'>{unquote(value)}</span>"
        # A callable is used so that backslashes in the value are inserted
        # literally instead of being parsed as re.sub template escapes.
        highlighted_text = re.sub(
            pattern,
            lambda _match: replacement,
            highlighted_text,
            flags=re.IGNORECASE,
        )

    return highlighted_text
def predict_chunk(text, template, current, model, tokenizer):
    """Run the model on one chunk of text and return its cleaned output.

    The prompt contains the template, the current state (passed through
    ``clean_json_text``) and the chunk text, followed by the ``<|output|>``
    marker and an opening brace.

    Args:
        text: The chunk of the document to extract from.
        template: The JSON template as a string.
        current: The current extraction state as a string.
        model: The causal language model used for generation.
        tokenizer: The tokenizer matching ``model``.

    Returns:
        The decoded text found after the first ``<|output|>`` marker, passed
        through ``clean_json_text``.

    Raises:
        IndexError: If the decoded sequence contains no ``<|output|>`` marker.
    """
    current = clean_json_text(current)
    prompt = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE)
    # Follow the model's own device instead of assuming "cuda": the model is
    # placed by device_map="auto", which may not put it on a GPU.
    encoded = encoded.to(model.device)
    generated = model.generate(**encoded, max_new_tokens=MAX_NEW_TOKENS)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # NOTE(review): a literal "<|output|>" inside `text` would make index 1
    # point at prompt content rather than the generation -- confirm whether
    # inputs need sanitising.
    return clean_json_text(decoded.split("<|output|>")[1])
def sliding_window_prediction(template, text, model, tokenizer, window_size=4000, overlap=128):
    """Extract from ``text`` chunk by chunk, yielding progress after each one.

    The text is split with ``split_document(text, window_size, overlap,
    tokenizer)``. Each chunk is predicted with ``predict_chunk``, using the
    previous chunk's prediction as the current state (the template itself for
    the first chunk), and then passed through ``handle_broken_output``.

    Args:
        template: The JSON template as a string.
        text: The full input text.
        model: The causal language model used for generation.
        tokenizer: The tokenizer matching ``model``.
        window_size: Chunk size forwarded to ``split_document``.
        overlap: Chunk overlap forwarded to ``split_document``.

    Yields:
        Tuples ``(progress, prediction, html)``: a progress message, the
        prediction after ``sync_empty_fields`` serialised as indented JSON,
        and the full text with the predicted values highlighted.
    """
    # split_document receives the raw text and the tokenizer, so the document
    # is not tokenized a second time here.
    chunks = split_document(text, window_size, overlap, tokenizer)

    prev = template
    for i, chunk in enumerate(chunks):
        print(f"Processing chunk {i}...")
        pred = predict_chunk(chunk, template, prev, model, tokenizer)
        # Handle broken output
        pred = handle_broken_output(pred, prev)

        # Highlight against the whole text, not only the current chunk.
        highlighted_pred = highlight_words(text, json.loads(pred))

        # Parse afresh for each helper so neither receives an object the
        # other one has already seen.
        synced_pred = sync_empty_fields(json.loads(pred), json.loads(template))
        synced_pred = json.dumps(synced_pred, indent=4, ensure_ascii=False)

        yield f"Processed chunk {i+1}/{len(chunks)}", synced_pred, highlighted_pred

        # The next chunk continues from the un-synced prediction.
        prev = pred
######
# Load the model and tokenizer
model_name = "numind/NuExtract-v1.5"
# Fall back to True when HF_TOKEN is unset or empty.
auth_token = os.environ.get("HF_TOKEN") or True
# `token` replaces the deprecated `use_auth_token` keyword of from_pretrained.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=auth_token,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=auth_token)
model.eval()
def gradio_interface_function(template, text, is_example):
    """Gradio callback: validate the inputs, then stream chunk predictions.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        template: The JSON template entered by the user, as a string.
        text: The input text to extract from.
        is_example: When true, the finished request is not logged.

    Yields:
        Tuples ``(progress, model_output, html)``. For an invalid template or
        an over-long text a single tuple is yielded with the error message in
        the second position and empty strings elsewhere.
    """
    # Reject invalid JSON. TypeError covers a non-string template and
    # RecursionError an excessively nested one.
    try:
        json.loads(template)
    except (TypeError, ValueError, RecursionError):
        yield "", "Invalid JSON template", ""
        return

    if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
        yield "", "Input text too long for space. Download model to use unrestricted.", ""
        return

    # Stays None when the generator yields nothing, so the logging step below
    # cannot hit an unbound name.
    full_pred = None
    prediction_generator = sliding_window_prediction(
        template, text, model, tokenizer, window_size=MAX_WINDOW_SIZE
    )
    for progress, full_pred, html_content in prediction_generator:
        yield progress, full_pred, html_content

    # Only the final prediction of a real (non-example) request is logged.
    if not is_example and full_pred is not None:
        log_event(text, template, full_pred)
# Build the Gradio UI around gradio_interface_function and start serving it.
# The three inputs map positionally onto (template, text, is_example); the
# checkbox is hidden from the user.
_input_components = [
    gr.Textbox(lines=2, placeholder="Enter Template here...", label="Template"),
    gr.Textbox(lines=2, placeholder="Enter input Text here...", label="Input Text"),
    gr.Checkbox(label="Is Example?", visible=False),
]
# The three outputs receive the (progress, prediction, html) tuples yielded
# by the callback.
_output_components = [
    gr.Textbox(label="Progress"),
    gr.Textbox(label="Model Output"),
    gr.HTML(label="Model Output with Highlighted Words"),
]
iface = gr.Interface(
    fn=gradio_interface_function,
    inputs=_input_components,
    outputs=_output_components,
    examples=input_examples,
    description=markdown_description,
)
iface.launch(debug=True, share=True)