"""Demo gradio app for some text/query augmentation.""" | |
from __future__ import annotations | |
import functools | |
from typing import Any | |
from typing import Callable | |
from typing import Mapping | |
from typing import Sequence | |
import attr | |
import environ | |
import fasttext # not working with python3.9 | |
import gradio as gr | |
from transformers.pipelines import pipeline | |
from transformers.pipelines.base import Pipeline | |
from transformers.pipelines.token_classification import AggregationStrategy | |
def compose(*functions) -> Callable:
    """
    Compose functions, applying them left to right.

    Args:
        functions: functions to compose; the first one listed is applied first.

    Returns:
        The composed function.
    """

    def apply(f, g):
        return lambda x: f(g(x))

    return functools.reduce(apply, functions[::-1], lambda x: x)

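# A minimal illustration of `compose` (example only, not part of the app logic):
# the first function listed is applied first. `str.strip` and `str.lower` are
# stand-ins used purely for this sketch.
#
#   normalize = compose(str.strip, str.lower)
#   normalize("  Hello ")  # -> "hello"
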
def mapped(fn) -> Callable:
    """
    Decorator factory that applies ``fn`` (e.g. ``map`` or ``filter``) to a function.

    Args:
        fn: higher-order function such as ``map`` or ``filter``.

    Returns:
        A decorator that wraps the decorated function with ``fn``.
    """

    def inner(func):
        partial_fn = functools.partial(fn, func)

        def wrapper(*args, **kwargs):
            return partial_fn(*args, **kwargs)

        return wrapper

    return inner

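# A minimal illustration of `mapped` (example only): decorating a single-item
# function with `mapped(map)` turns it into a function over iterables, which is
# how `format_label` is used inside `predict` below.
#
#   @mapped(map)
#   def double(x: int) -> int:
#       return x * 2
#
#   list(double([1, 2, 3]))  # -> [2, 4, 6]
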
@attr.frozen
class Prediction:
    """Dataclass to store prediction results."""

    label: str
    score: float

@attr.frozen
class Models:
    """Container for the loaded predictors used by the app."""

    identification: Predictor
    translation: Predictor
    classification: Predictor
    ner: Predictor
    recipe: Predictor

@attr.frozen
class Predictor:
    """Wrap a model: load it once at init time and expose prediction as a call."""

    load_fn: Callable
    predict_fn: Callable = attr.field(default=lambda model, query: model(query))
    model: Any = attr.field(init=False)

    def __attrs_post_init__(self):
        # The class is frozen, so the loaded model is attached via object.__setattr__.
        object.__setattr__(self, "model", self.load_fn())

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.predict_fn(self.model, *args, **kwds)

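# A minimal sketch of the Predictor pattern (example only, not used by the app):
# `load_fn` runs once when the instance is created, and every call is routed
# through `predict_fn`, which receives the loaded model plus the call arguments.
# `str.upper` stands in for a real model here.
#
#   echo = Predictor(load_fn=lambda: str.upper, predict_fn=lambda model, q: model(q))
#   echo("hello")  # -> "HELLO"
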
@environ.config
class AppConfig:
    """Application configuration, read from environment variables."""

    @environ.config
    class Identification:
        """Identification model configuration."""

        model = environ.var(default="./models/lid.176.ftz")
        max_results = environ.var(default=3, converter=int)

    @environ.config
    class Translation:
        """Translation models configuration."""

        model = environ.var(default="t5-small")
        sources = environ.var(default="de,fr")
        target = environ.var(default="en")

    @environ.config
    class Classification:
        """Classification model configuration."""

        model = environ.var(default="typeform/distilbert-base-uncased-mnli")
        max_results = environ.var(default=5, converter=int)

    @environ.config
    class NER:
        """NER models configuration."""

        general = environ.var(
            default="asahi417/tner-xlm-roberta-large-uncased-wnut2017",
        )
        recipe = environ.var(default="adamlin/recipe-tag-model")

    identification: Identification = environ.group(Identification)
    translation: Translation = environ.group(Translation)
    classification: Classification = environ.group(Classification)
    ner: NER = environ.group(NER)

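# Configuration override example (illustration only; it assumes environ-config's
# default "APP" prefix and that this module is saved as app.py): each default
# above can be replaced through an environment variable named after the prefix,
# the group, and the field, e.g.
#
#   APP_CLASSIFICATION_MAX_RESULTS=3 python app.py
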
def predict(
    models: Models,
    query: str,
    categories: Sequence[str],
    supported_languages: tuple[str, ...] = ("fr", "de"),
) -> tuple[
    Mapping[str, float],
    Mapping[str, float],
    str,
    Sequence[tuple[str, str | None]],
    Sequence[tuple[str, str | None]],
]:
    """Predict from a textual query:

    - the language,
    - an English translation (when the query is in a supported source language),
    - its classification into the provided categories,
    - generic and recipe-specific named entities.
    """
    def predict_lang(query) -> Mapping[str, float]:
        def predict_fn(query) -> Sequence[Prediction]:
            # lid.176 covers 176 languages; request them all and filter afterwards.
            return tuple(
                Prediction(label=label, score=score)
                for label, score in zip(*models.identification(query, k=176))
            )

        @mapped(map)
        def format_label(prediction: Prediction) -> Prediction:
            # fasttext labels look like "__label__fr"; keep only the language code.
            return attr.evolve(
                prediction,
                label=prediction.label.replace("__label__", ""),
            )

        def filter_labels(prediction: Prediction) -> bool:
            return prediction.label in supported_languages + ("en",)

        def format_output(predictions: Sequence[Prediction]) -> dict:
            return {pred.label: pred.score for pred in predictions}

        apply_fn = compose(
            predict_fn,
            format_label,
            functools.partial(filter, filter_labels),
            format_output,
        )
        return apply_fn(query)

    def translate_query(query: str, languages: Mapping[str, float]) -> str:
        def predicted_language() -> str:
            return max(languages.items(), key=lambda lang: lang[1])[0]

        def translate(query):
            lang = predicted_language()
            if lang in supported_languages:
                output = models.translation(query, lang)[0]["translation_text"]
            else:
                output = query
            return output

        return translate(query)

    def classify_query(query, categories) -> Mapping[str, float]:
        # The zero-shot pipeline returns parallel "labels" and "scores" lists.
        predictions = models.classification(query, categories)
        return dict(zip(predictions["labels"], predictions["scores"]))

    def extract_entities(
        predict_fn: Callable,
        query: str,
    ) -> Sequence[tuple[str, str | None]]:
        predictions = predict_fn(query)
        if len(predictions) == 0:
            return [(query, None)]
        else:
            # Aggregated NER pipelines expose "entity_group"; non-aggregated ones "entity".
            return [
                (pred["word"], pred.get("entity_group", pred.get("entity", None)))
                for pred in predictions
            ]

    languages = predict_lang(query)
    translation = translate_query(query, languages)
    classifications = classify_query(translation, categories)
    general_entities = extract_entities(models.ner, query)
    recipe_entities = extract_entities(models.recipe, translation)

    return languages, classifications, translation, general_entities, recipe_entities

def main():
    cfg: AppConfig = AppConfig.from_environ()

    def load_translation_models(
        sources: Sequence[str],
        target: str,
        models: Sequence[str],
    ) -> Mapping[str, Pipeline]:
        # One translation pipeline per source language, keyed by language code.
        result = {
            src: pipeline(f"translation_{src}_to_{target}", model)
            for src, model in zip(sources, models)
        }
        return result

    def extract_commas_separated_values(value: str) -> Sequence[str]:
        return tuple(filter(None, value.split(",")))

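    # Example (illustration only): extract_commas_separated_values("de,fr,")
    # returns ("de", "fr"); empty segments are dropped by filter(None, ...).
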
    models = Models(
        identification=Predictor(
            load_fn=lambda: fasttext.load_model(cfg.identification.model),
            predict_fn=lambda model, query, k: model.predict(query, k=k),
        ),
        translation=Predictor(
            load_fn=functools.partial(
                load_translation_models,
                sources=extract_commas_separated_values(cfg.translation.sources),
                target=cfg.translation.target,
                models=["Helsinki-NLP/opus-mt-de-en", "Helsinki-NLP/opus-mt-fr-en"],
            ),
            predict_fn=lambda models, query, src: models[src](query),
        ),
        classification=Predictor(
            load_fn=lambda: pipeline(
                "zero-shot-classification",
                model=cfg.classification.model,
            ),
            predict_fn=lambda model, query, categories: model(query, categories),
        ),
        ner=Predictor(
            load_fn=lambda: pipeline(
                "ner",
                model=cfg.ner.general,
                aggregation_strategy=AggregationStrategy.SIMPLE,
            ),
        ),
        recipe=Predictor(
            load_fn=lambda: pipeline("ner", model=cfg.ner.recipe),
        ),
    )

    iface = gr.Interface(
        fn=lambda query, categories: predict(
            models,
            query.strip(),
            extract_commas_separated_values(categories),
        ),
        examples=[["gateau au chocolat paris"], ["Newyork LA flight"]],
        inputs=[
            gr.inputs.Textbox(label="Query"),
            gr.inputs.Textbox(
                label="categories (comma separated and in english)",
                default="cooking and recipe,traveling,location,information,buy or sell",
            ),
        ],
        outputs=[
            gr.outputs.Label(
                num_top_classes=cfg.identification.max_results,
                type="auto",
                label="Language identification",
            ),
            gr.outputs.Label(
                num_top_classes=cfg.classification.max_results,
                type="auto",
                label="Predicted categories",
            ),
            gr.outputs.Textbox(
                label="English query",
                type="auto",
            ),
            gr.outputs.HighlightedText(label="NER generic"),
            gr.outputs.HighlightedText(label="NER Recipes"),
        ],
        interpretation="default",
    )
    iface.launch(debug=True)


if __name__ == "__main__":
    main()