import copy
from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Union, Dict, Any
import murmurhash
from spacy.language import Language
from spacy.tokens.doc import SetEntsDefault # type: ignore
from spacy.training import Example
from spacy.util import filter_spans
from prodigy.components.db import connect
from prodigy.components.decorators import support_both_streams
from prodigy.components.filters import filter_seen_before
from prodigy.components.preprocess import (
add_annot_name,
add_tokens,
add_view_id,
make_ner_suggestions,
make_raw_doc,
resolve_labels,
split_sentences,
)
from prodigy.components.sorters import prefer_uncertain
from prodigy.components.source import GeneratorSource
from prodigy.components.stream import Stream, get_stream, load_noop
from prodigy.core import Arg, recipe
from prodigy.errors import RecipeError
from prodigy.models.matcher import PatternMatcher
from prodigy.models.ner import EntityRecognizerModel, ensure_sentencizer
from prodigy.protocols import ControllerComponentsDict
from prodigy.types import (
ExistingFilePath,
LabelsType,
SourceType,
StreamType,
TaskType,
)
from prodigy.util import (
ANNOTATOR_ID_ATTR,
BINARY_ATTR,
INPUT_HASH_ATTR,
TASK_HASH_ATTR,
combine_models,
copy_nlp,
get_pipe_labels,
log,
msg,
set_hashes,
)
def modify_spans(document: TaskType) -> TaskType:
    """Reset the task's 'spans' key to an empty list."""
    document["spans"] = []
    return document
def spans_equal(s1: Dict[str, Any], s2: Dict[str, Any]) -> bool:
"""Checks if two spans are equal"""
return s1["start"] == s2["start"] and s1["end"] == s2["end"]
def labels_equal(s1: Dict[str, Any], s2: Dict[str, Any]) -> bool:
"""Checks if two spans have the same label"""
return s1["label"] == s2["label"]
def ensure_span_text(eg: TaskType) -> TaskType:
"""Ensure that all spans have a text attribute"""
for span in eg.get("spans", []):
if "text" not in span:
span["text"] = eg["text"][span["start"] : span["end"]]
return eg
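# Illustrative example (hypothetical task data): given
#   {"text": "Acme hired Bo", "spans": [{"start": 0, "end": 4, "label": "ORG"}]}
# ensure_span_text fills in span["text"] == "Acme", so the expected-annotation
# messages built in validate_answer below can display the span text.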
def validate_answer(answer: TaskType, *, known_answers_map: Dict[int, TaskType]) -> None:
"""Validate the answer against the known answers"""
known_answer = known_answers_map.get(answer[INPUT_HASH_ATTR])
if known_answer is None:
print(f"Skipping validation for answer {answer[INPUT_HASH_ATTR]}, no known answer found to validate against.")
return
known_answer = ensure_span_text(known_answer)
    errors = []
    known_spans = known_answer.get("spans", [])
    answer_spans = answer.get("spans", [])
    meta = known_answer.get("meta", {})
    # Fall back to placeholder strings so the error message never contains None
    explanation_label = meta.get("explanation_label") or "No explanation label"
    explanation_boundaries = meta.get("explanation_boundaries") or "No explanation boundaries"
if len(known_spans) > len(answer_spans):
errors.append(
"You noted fewer entities than expected for this answer. All mentions must be annotated"
)
elif len(known_spans) < len(answer_spans):
errors.append(
"You noted more entities than expected for this answer."
)
    if not known_spans and answer_spans:
        # Annotations were made where none are expected
        errors.append(explanation_label)
for known_span, span in zip(known_spans, answer_spans):
if not labels_equal(known_span, span):
# label error
errors.append(explanation_label)
continue
if not spans_equal(known_span, span):
# boundary error
errors.append(explanation_boundaries)
continue
    if errors:
        error_msg = "\n".join(errors)
        error_msg += "\n\nExpected annotations:"
        if known_spans:
            error_msg += "\n"
            for s in known_spans:
                error_msg += f'[{s["text"]}]: {s["label"]}\n'
        else:
            error_msg += "\n\nNone."
        raise ValueError(error_msg)
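# Illustrative example (hypothetical task data): if the known answer has spans
# [{"start": 0, "end": 4, "label": "ORG", "text": "Acme"}] and the submitted
# answer has none, validate_answer raises ValueError with "You noted fewer
# entities than expected..." plus the expected annotation "[Acme]: ORG".
# Prodigy surfaces the message to the annotator and the answer isn't accepted
# until it is corrected.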
@recipe(
"ner.qa.manual",
# fmt: off
dataset=Arg(help="Dataset to save annotations to"),
nlp=Arg(help="Loadable spaCy pipeline for tokenization or blank:lang (e.g. blank:en)"),
source=Arg(help="Data to annotate (file path or '-' to read from standard input)"),
loader=Arg("--loader", "-lo", help="Loader (guessed from file extension if not set)"),
label=Arg("--label", "-l", help="Comma-separated label(s) to annotate or text file with one label per line"),
patterns=Arg("--patterns", "-pt", help="Path to match patterns file"),
exclude=Arg("--exclude", "-e", help="Comma-separated list of dataset IDs whose annotations to exclude"),
highlight_chars=Arg("--highlight-chars", "-C", help="Allow highlighting individual characters instead of tokens"),
# fmt: on
)
def manual(
dataset: str,
nlp: Language,
source: SourceType,
loader: Optional[str] = None,
label: Optional[LabelsType] = None,
patterns: Optional[ExistingFilePath] = None,
exclude: List[str] = [],
highlight_chars: bool = False,
) -> ControllerComponentsDict:
"""
Mark spans by token. Requires only a tokenizer and no entity recognizer,
and doesn't do any active learning. If patterns are provided, their matches
are highlighted in the example, if available. The recipe will present
all examples in order, so even examples without matches are shown. If
character highlighting is enabled, no "tokens" are saved to the database.
"""
log("RECIPE: Starting recipe ner.manual", locals())
labels = get_pipe_labels(label, nlp.pipe_labels.get("ner", []))
stream = get_stream(
source,
loader=loader,
rehash=True,
dedup=True,
input_key="text",
is_binary=False,
)
if patterns is not None:
pattern_matcher = PatternMatcher(nlp, combine_matches=True, all_examples=True)
pattern_matcher = pattern_matcher.from_disk(patterns)
stream.apply(lambda examples: (eg for _, eg in pattern_matcher(examples)))
# Add "tokens" key to the tasks, either with words or characters
stream.apply(lambda examples: (modify_spans(eg) for eg in examples))
exclude_names = [ds.name for ds in exclude] if exclude is not None else None
known_answers = get_stream(
source,
loader=loader,
rehash=True,
dedup=True,
input_key="text",
is_binary=False,
)
known_answers_map = {eg[INPUT_HASH_ATTR]: eg for eg in known_answers}
return {
"view_id": "ner_manual",
"dataset": dataset,
"stream": [_ for _ in stream],
"exclude": exclude_names,
"validate_answer": partial(validate_answer, known_answers_map=known_answers_map),
"config": {
"lang": nlp.lang,
"labels": labels,
"exclude_by": "input",
"ner_manual_highlight_chars": highlight_chars,
},
}
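# Example invocation (hypothetical dataset, file, and recipe-file names; the
# source JSONL is expected to carry the gold spans used for validation):
#
#   prodigy ner.qa.manual qa_eval blank:en ./gold_examples.jsonl \
#       --label PERSON,ORG -F ner_qa_manual.py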
@support_both_streams(stream_arg="stream")
def preprocess_stream(
stream: StreamType,
nlp: Language,
*,
labels: Optional[List[str]],
unsegmented: bool,
set_annotations: bool = True,
) -> StreamType:
if not unsegmented:
stream = split_sentences(nlp, stream) # type: ignore
stream = add_tokens(nlp, stream) # type: ignore
if set_annotations:
spacy_model = f"{nlp.meta['lang']}_{nlp.meta['name']}"
# Add a 'spans' key to each example, with predicted entities
texts = ((eg["text"], eg) for eg in stream)
for doc, eg in nlp.pipe(texts, as_tuples=True, batch_size=10):
task = copy.deepcopy(eg)
spans = []
for ent in doc.ents:
if labels and ent.label_ not in labels:
continue
spans.append(ent)
            for span in eg.get("spans", []):
                char_span = doc.char_span(span["start"], span["end"], span["label"])
                # char_span returns None if the offsets don't align to token boundaries
                if char_span is not None:
                    spans.append(char_span)
spans = filter_spans(spans)
span_dicts = []
for ent in spans:
span_dicts.append(
{
"token_start": ent.start,
"token_end": ent.end - 1,
"start": ent.start_char,
"end": ent.end_char,
"text": ent.text,
"label": ent.label_,
"source": spacy_model,
"input_hash": eg[INPUT_HASH_ATTR],
}
)
task["spans"] = span_dicts
task[BINARY_ATTR] = False
task = set_hashes(task)
yield task
else:
yield from stream
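# A minimal usage sketch (assumes en_core_web_sm is installed; the task and
# its hash are hypothetical):
#
#   import spacy
#   nlp = spacy.load("en_core_web_sm")
#   tasks = [{"text": "Acme hired Bo.", "_input_hash": 0}]
#   for task in preprocess_stream(tasks, nlp, labels=None, unsegmented=True):
#       print(task["spans"])  # model predictions as span dicts with token offsets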
def get_ner_labels(
nlp: Language, *, label: Optional[List[str]], component: str = "ner"
) -> Tuple[List[str], bool]:
model_labels = nlp.pipe_labels.get(component, [])
labels = get_pipe_labels(label, model_labels)
# Check if we're annotating all labels present in the model or a subset
no_missing = len(set(labels).intersection(set(model_labels))) == len(model_labels)
return labels, no_missing
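# For example, if the model's "ner" component predicts PERSON, ORG and GPE,
# passing --label PERSON returns (["PERSON"], False): unannotated entities may
# still belong to one of the model's other labels, so get_update below treats
# unannotated tokens as missing rather than outside. Selecting all three
# labels returns no_missing=True.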
def get_update(nlp: Language, *, no_missing: bool) -> Callable[[List[TaskType]], None]:
def update(answers: List[TaskType]) -> None:
log(f"RECIPE: Updating model with {len(answers)} answers")
examples = []
for eg in answers:
if eg["answer"] == "accept":
doc = make_raw_doc(nlp, eg)
ref = make_raw_doc(nlp, eg)
                spans = [
                    doc.char_span(span["start"], span["end"], label=span["label"])
                    for span in eg.get("spans", [])
                ]
                # Drop spans whose offsets don't align to token boundaries
                spans = [span for span in spans if span is not None]
value = SetEntsDefault.outside if no_missing else SetEntsDefault.missing
ref.set_ents(spans, default=value)
examples.append(Example(doc, ref))
nlp.update(examples)
return update
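# The returned callback is meant to be wired into a recipe's components dict,
# e.g. (sketch): components["update"] = get_update(nlp, no_missing=no_missing),
# so the model is updated in the loop as annotations come in.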