import copy
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import murmurhash
from spacy.language import Language
from spacy.tokens.doc import SetEntsDefault  # type: ignore
from spacy.training import Example
from spacy.util import filter_spans

from prodigy.components.db import connect
from prodigy.components.decorators import support_both_streams
from prodigy.components.filters import filter_seen_before
from prodigy.components.preprocess import (
    add_annot_name,
    add_tokens,
    add_view_id,
    make_ner_suggestions,
    make_raw_doc,
    resolve_labels,
    split_sentences,
)
from prodigy.components.sorters import prefer_uncertain
from prodigy.components.source import GeneratorSource
from prodigy.components.stream import Stream, get_stream, load_noop
from prodigy.core import Arg, recipe
from prodigy.errors import RecipeError
from prodigy.models.matcher import PatternMatcher
from prodigy.models.ner import EntityRecognizerModel, ensure_sentencizer
from prodigy.protocols import ControllerComponentsDict
from prodigy.types import (
    ExistingFilePath,
    LabelsType,
    SourceType,
    StreamType,
    TaskType,
)
from prodigy.util import (
    ANNOTATOR_ID_ATTR,
    BINARY_ATTR,
    INPUT_HASH_ATTR,
    TASK_HASH_ATTR,
    combine_models,
    copy_nlp,
    get_pipe_labels,
    log,
    msg,
    set_hashes,
)


def modify_spans(document: TaskType) -> TaskType:
    """Reset the 'spans' key to an empty list so the task starts unannotated."""
    document["spans"] = []
    return document


def spans_equal(s1: Dict[str, Any], s2: Dict[str, Any]) -> bool:
    """Check whether two spans have the same start and end offsets."""
    return s1["start"] == s2["start"] and s1["end"] == s2["end"]


def labels_equal(s1: Dict[str, Any], s2: Dict[str, Any]) -> bool:
    """Check whether two spans have the same label."""
    return s1["label"] == s2["label"]


def ensure_span_text(eg: TaskType) -> TaskType:
    """Ensure that every span has a 'text' attribute sliced from the task text."""
    for span in eg.get("spans", []):
        if "text" not in span:
            span["text"] = eg["text"][span["start"] : span["end"]]
    return eg
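
# Illustrative sketch with hypothetical data: a span missing "text" gets it
# filled in from the character offsets.
#
#     eg = {"text": "Ada Lovelace wrote notes.",
#           "spans": [{"start": 0, "end": 12, "label": "PERSON"}]}
#     ensure_span_text(eg)
#     eg["spans"][0]["text"]  # "Ada Lovelace"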


def validate_answer(answer: TaskType, *, known_answers_map: Dict[int, TaskType]) -> None:
    """Validate a submitted answer against the known (gold) answers."""
    known_answer = known_answers_map.get(answer[INPUT_HASH_ATTR])
    if known_answer is None:
        print(
            f"Skipping validation for answer {answer[INPUT_HASH_ATTR]}, "
            "no known answer found to validate against."
        )
        return
    known_answer = ensure_span_text(known_answer)
    errors = []
    known_spans = known_answer.get("spans", [])
    answer_spans = answer.get("spans", [])
    explanation_label = known_answer.get("meta", {}).get("explanation_label")
    explanation_boundaries = known_answer.get("meta", {}).get("explanation_boundaries")
    if not explanation_boundaries:
        explanation_boundaries = "No explanation boundaries"
    if len(known_spans) > len(answer_spans):
        errors.append(
            "You noted fewer entities than expected for this answer. "
            "All mentions must be annotated."
        )
    elif len(known_spans) < len(answer_spans):
        errors.append("You noted more entities than expected for this answer.")
    if not known_spans and answer_spans:
        # No annotations are expected for this example, so explain why the
        # submitted spans are wrong
        errors.append(explanation_label)
    for known_span, span in zip(known_spans, answer_spans):
        if not labels_equal(known_span, span):
            # Label error: report the label explanation and skip the boundary check
            errors.append(explanation_label)
            continue
        if not spans_equal(known_span, span):
            # Boundary error
            errors.append(explanation_boundaries)
    if errors:
        error_msg = "\n".join(errors)
        error_msg += "\n\nExpected annotations:"
        if known_spans:
            error_msg += "\n"
            for s in known_spans:
                error_msg += f'[{s["text"]}]: {s["label"]}\n'
        else:
            error_msg += "\n\nNone."
        raise ValueError(error_msg)
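
# Illustrative sketch of the validation flow, with hypothetical data. A missing
# entity raises a ValueError whose message is shown to the annotator:
#
#     known = {
#         "text": "Ada Lovelace wrote notes.",
#         INPUT_HASH_ATTR: 1,
#         "spans": [{"start": 0, "end": 12, "label": "PERSON"}],
#         "meta": {"explanation_label": "Annotate every PERSON mention."},
#     }
#     answer = {"text": "Ada Lovelace wrote notes.", INPUT_HASH_ATTR: 1, "spans": []}
#     validate_answer(answer, known_answers_map={1: known})  # raises ValueError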


def manual(
    dataset: str,
    nlp: Language,
    source: SourceType,
    loader: Optional[str] = None,
    label: Optional[LabelsType] = None,
    patterns: Optional[ExistingFilePath] = None,
    exclude: List[str] = [],
    highlight_chars: bool = False,
) -> ControllerComponentsDict:
    """
    Mark spans by token. Requires only a tokenizer and no entity recognizer,
    and doesn't do any active learning. If patterns are provided, their matches
    are highlighted in the example, if available. The recipe will present
    all examples in order, so even examples without matches are shown. If
    character highlighting is enabled, no "tokens" are saved to the database.
    """
    log("RECIPE: Starting recipe ner.manual", locals())
    labels = get_pipe_labels(label, nlp.pipe_labels.get("ner", []))
    stream = get_stream(
        source,
        loader=loader,
        rehash=True,
        dedup=True,
        input_key="text",
        is_binary=False,
    )
    if patterns is not None:
        pattern_matcher = PatternMatcher(nlp, combine_matches=True, all_examples=True)
        pattern_matcher = pattern_matcher.from_disk(patterns)
        stream.apply(lambda examples: (eg for _, eg in pattern_matcher(examples)))
    # Add "tokens" key to the tasks, either with words or characters
    stream.apply(add_tokens, nlp=nlp, stream=stream, use_chars=highlight_chars)
    # Reset the "spans" key on each task so annotators start from a blank slate
    stream.apply(lambda examples: (modify_spans(eg) for eg in examples))
    exclude_names = exclude if exclude else None
    # Re-read the source with its gold "spans" intact to validate answers against
    known_answers = get_stream(
        source,
        loader=loader,
        rehash=True,
        dedup=True,
        input_key="text",
        is_binary=False,
    )
    known_answers_map = {eg[INPUT_HASH_ATTR]: eg for eg in known_answers}
    return {
        "view_id": "ner_manual",
        "dataset": dataset,
        "stream": list(stream),
        "exclude": exclude_names,
        "validate_answer": partial(validate_answer, known_answers_map=known_answers_map),
        "config": {
            "lang": nlp.lang,
            "labels": labels,
            "exclude_by": "input",
            "ner_manual_highlight_chars": highlight_chars,
        },
    }
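
# Illustrative usage, assuming this function is registered as a Prodigy recipe
# (e.g. via an @recipe("ner.manual", ...) decorator, which is not shown in this
# file):
#
#     prodigy ner.manual my_dataset blank:en ./examples.jsonl --label PERSON,ORG
#
# The source is read twice: once as the annotator-facing stream (spans cleared)
# and once as the gold standard consulted by validate_answer.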


def preprocess_stream(
    stream: StreamType,
    nlp: Language,
    *,
    labels: Optional[List[str]],
    unsegmented: bool,
    set_annotations: bool = True,
) -> StreamType:
    if not unsegmented:
        stream = split_sentences(nlp, stream)  # type: ignore
    stream = add_tokens(nlp, stream)  # type: ignore
    if set_annotations:
        spacy_model = f"{nlp.meta['lang']}_{nlp.meta['name']}"
        # Add a "spans" key to each example, with predicted entities
        texts = ((eg["text"], eg) for eg in stream)
        for doc, eg in nlp.pipe(texts, as_tuples=True, batch_size=10):
            task = copy.deepcopy(eg)
            spans = []
            for ent in doc.ents:
                if labels and ent.label_ not in labels:
                    continue
                spans.append(ent)
            for span in eg.get("spans", []):
                # doc.char_span returns None if the offsets don't align to
                # token boundaries, so guard against appending None
                char_span = doc.char_span(span["start"], span["end"], label=span["label"])
                if char_span is not None:
                    spans.append(char_span)
            spans = filter_spans(spans)
            span_dicts = []
            for ent in spans:
                span_dicts.append(
                    {
                        "token_start": ent.start,
                        "token_end": ent.end - 1,
                        "start": ent.start_char,
                        "end": ent.end_char,
                        "text": ent.text,
                        "label": ent.label_,
                        "source": spacy_model,
                        "input_hash": eg[INPUT_HASH_ATTR],
                    }
                )
            task["spans"] = span_dicts
            task[BINARY_ATTR] = False
            task = set_hashes(task)
            yield task
    else:
        yield from stream
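
# Illustrative sketch, assuming a pretrained pipeline such as en_core_web_sm is
# installed (not a requirement of this file):
#
#     import spacy
#     nlp = spacy.load("en_core_web_sm")
#     stream = [{"text": "Apple was founded by Steve Jobs.", "_input_hash": 123}]
#     tasks = list(preprocess_stream(stream, nlp, labels=["ORG", "PERSON"], unsegmented=True))
#     tasks[0]["spans"]  # predicted ORG and PERSON spans with token offsets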


def get_ner_labels(
    nlp: Language, *, label: Optional[List[str]], component: str = "ner"
) -> Tuple[List[str], bool]:
    model_labels = nlp.pipe_labels.get(component, [])
    labels = get_pipe_labels(label, model_labels)
    # Check if we're annotating all labels present in the model or only a subset
    no_missing = len(set(labels).intersection(set(model_labels))) == len(model_labels)
    return labels, no_missing
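
# Illustrative sketch: if the model predicts PERSON, ORG and GPE but only
# PERSON is requested, no_missing is False, so unannotated tokens are treated
# as missing (rather than outside an entity) when updating the model below.
#
#     labels, no_missing = get_ner_labels(nlp, label=["PERSON"])
#     # labels == ["PERSON"], no_missing == False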


def get_update(nlp: Language, *, no_missing: bool) -> Callable[[List[TaskType]], None]:
    def update(answers: List[TaskType]) -> None:
        log(f"RECIPE: Updating model with {len(answers)} answers")
        examples = []
        for eg in answers:
            if eg["answer"] == "accept":
                doc = make_raw_doc(nlp, eg)
                ref = make_raw_doc(nlp, eg)
                spans = [
                    doc.char_span(span["start"], span["end"], label=span["label"])
                    for span in eg.get("spans", [])
                ]
                # Unannotated tokens count as "outside" an entity only if all
                # model labels were annotated; otherwise they're left missing
                value = SetEntsDefault.outside if no_missing else SetEntsDefault.missing
                ref.set_ents(spans, default=value)
                examples.append(Example(doc, ref))
        nlp.update(examples)

    return update
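
# Illustrative sketch with hypothetical data: build the callback and update the
# pipeline from a single accepted answer.
#
#     labels, no_missing = get_ner_labels(nlp, label=None)
#     update = get_update(nlp, no_missing=no_missing)
#     update([{
#         "text": "Apple was founded by Steve Jobs.",
#         "answer": "accept",
#         "spans": [{"start": 0, "end": 5, "label": "ORG"}],
#     }])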