query_analysis / app.py
chicham
Modify the way the results are shown (#4)
3ad6577 unverified
"""Demo gradio app for some text/query augmentation."""
from __future__ import annotations
import functools
from typing import Any
from typing import Callable
from typing import Mapping
from typing import Sequence
import attr
import environ
import fasttext # not working with python3.9
import gradio as gr
from transformers.pipelines import pipeline
from transformers.pipelines.base import Pipeline
from transformers.pipelines.token_classification import AggregationStrategy
def compose(*functions) -> Callable:
    """
    Chain single-argument functions into one left-to-right pipeline.

    The first function receives the input, every following function
    receives the previous result. Without any function the returned
    callable hands its argument back unchanged.

    Args:
        functions: functions to chain, in the order they must be applied.

    Returns:
        A one-argument callable running all the functions in sequence.
    """

    def composed(value):
        # Feed the running result through each step in declaration order.
        for step in functions:
            value = step(value)
        return value

    return composed
def mapped(fn) -> Callable:
    """
    Build a decorator that routes the decorated function through ``fn``.

    ``fn`` is expected to be a higher-order callable such as ``map`` or
    ``filter``: the decorated function is passed as its first argument,
    followed by whatever the wrapper is called with.

    Args:
        fn: higher-order callable taking a function as first argument.

    Returns:
        A decorator producing the wrapped function.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Same call as a partial of ``fn`` pre-bound to ``func``.
            return fn(func, *args, **kwargs)

        return wrapper

    return decorator
@attr.frozen
class Prediction:
    """Immutable (frozen attrs class) label/score pair for one prediction."""

    # Label attached to the prediction.
    label: str
    # Score associated with the label.
    score: float
@attr.frozen
class Models:
    """Immutable bundle of the predictors used by ``predict``."""

    # Language identification; called as ``identification(query, k=...)``.
    identification: Predictor
    # Translation; called as ``translation(query, language)``.
    translation: Predictor
    # Classification; called as ``classification(query, categories)``.
    classification: Predictor
    # Entity extraction run on the original query.
    ner: Predictor
    # Entity extraction run on the translated query.
    recipe: Predictor
@attr.frozen
class Predictor:
    """
    Couple a model loader with the function used to query the model.

    The model is loaded once, when the instance is created, by calling
    ``load_fn`` without arguments. Calling the instance forwards every
    argument to ``predict_fn`` right after the loaded model.
    """

    # Zero-argument callable returning the model.
    load_fn: Callable
    # Called as ``predict_fn(model, *args, **kwargs)``; by default the model
    # itself is called with the query.
    predict_fn: Callable = attr.field(default=lambda model, query: model(query))
    # Filled in by ``__attrs_post_init__``; not a constructor argument.
    model: Any = attr.field(init=False)

    def __attrs_post_init__(self):
        """Load the model once the other fields have been set."""
        loaded = self.load_fn()
        # The class is frozen, so plain attribute assignment is rejected;
        # object.__setattr__ bypasses that guard.
        object.__setattr__(self, "model", loaded)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``predict_fn`` on the loaded model with the given arguments."""
        return self.predict_fn(self.model, *args, **kwargs)
@environ.config(prefix="QUERY_INTERPRETATION")
class AppConfig:
    """Application settings declared with environ-config.

    The configuration is declared with the ``QUERY_INTERPRETATION`` prefix
    and is split into one group per model family.
    """

    @environ.config
    class Identification:
        """Identification model configuration."""

        # Location of the language identification model file.
        model = environ.var(default="./models/lid.176.ftz")
        # Number of top classes displayed for the language output (int).
        max_results = environ.var(default=3, converter=int)

    @environ.config
    class Translation:
        """Translation models configuration."""

        # NOTE(review): not read by the code visible in this file.
        model = environ.var(default="t5-small")
        # Comma-separated source language codes.
        sources = environ.var(default="de,fr")
        # Target language code.
        target = environ.var(default="en")

    @environ.config
    class Classification:
        """Classification model configuration."""

        # Model name given to the zero-shot classification pipeline.
        model = environ.var(default="typeform/distilbert-base-uncased-mnli")
        # Number of top classes displayed for the categories output (int).
        max_results = environ.var(default=5, converter=int)

    @environ.config
    class NER:
        """Named entity recognition models configuration."""

        # Model name for the general-purpose entity extraction.
        general = environ.var(
            default="asahi417/tner-xlm-roberta-large-uncased-wnut2017",
        )
        # Model name for the recipe-specific entity extraction.
        recipe = environ.var(default="adamlin/recipe-tag-model")

    # Nested groups exposed as attributes of the configuration object.
    identification: Identification = environ.group(Identification)
    translation: Translation = environ.group(Translation)
    classification: Classification = environ.group(Classification)
    ner: NER = environ.group(NER)
def predict(
    models: Models,
    query: str,
    categories: Sequence[str],
    supported_languages: tuple[str, ...] = ("fr", "de"),
) -> tuple[
    Mapping[str, float],
    Mapping[str, float],
    str,
    Sequence[tuple[str, str | None]],
    Sequence[tuple[str, str | None]],
]:
    """Run the whole analysis chain on a textual query.

    Steps:
        - identify the language of the query,
        - translate the query when its best-scored language is one of
          ``supported_languages`` (otherwise it is kept as is),
        - score the translated query against ``categories``,
        - extract entities with the general model (on the original query)
          and with the recipe model (on the translated query).

    Args:
        models: predictors used by each step.
        query: text to analyse.
        categories: candidate labels for the classification step.
        supported_languages: language codes that can be translated; "en" is
            always kept next to them in the identification output.

    Returns:
        The tuple ``(languages, classifications, translation,
        general_entities, recipe_entities)``.
    """
    # Built once instead of on every filtered prediction; unpacking also
    # accepts any sequence of codes, not only tuples.
    accepted_languages = (*supported_languages, "en")

    def predict_lang(text) -> Mapping[str, float]:
        """Return the score of every accepted language for ``text``."""

        def predict_fn(text) -> Sequence[Prediction]:
            # The identification predictor returns (labels, scores).
            return tuple(
                Prediction(label=label, score=score)
                for label, score in zip(*models.identification(text, k=176))
            )

        @mapped(map)
        def format_label(prediction: Prediction) -> Prediction:
            # Drop the "__label__" marker so only the language code remains.
            return attr.evolve(
                prediction,
                label=prediction.label.replace("__label__", ""),
            )

        def filter_labels(prediction: Prediction) -> bool:
            return prediction.label in accepted_languages

        def format_output(predictions: Sequence[Prediction]) -> dict:
            return {pred.label: pred.score for pred in predictions}

        apply_fn = compose(
            predict_fn,
            format_label,
            functools.partial(filter, filter_labels),
            format_output,
        )
        return apply_fn(text)

    def translate_query(text: str, languages: Mapping[str, float]) -> str:
        """Translate ``text`` when its best language is a supported one."""
        # Without any accepted language there is nothing to pick from:
        # max() on an empty mapping would raise, so keep the text untouched.
        if not languages:
            return text
        best_language = max(languages, key=languages.get)
        if best_language in supported_languages:
            return models.translation(text, best_language)[0]["translation_text"]
        return text

    def classify_query(text, categories) -> Mapping[str, float]:
        """Map every category to the score given by the classifier."""
        predictions = models.classification(text, categories)
        return dict(zip(predictions["labels"], predictions["scores"]))

    def extract_entities(
        predict_fn: Callable,
        text: str,
    ) -> Sequence[tuple[str, str | None]]:
        """Return ``(word, entity)`` pairs, or the bare text if none is found."""
        predictions = predict_fn(text)
        if not predictions:
            return [(text, None)]
        # Prefer "entity_group" and fall back to "entity" when it is absent.
        return [
            (pred["word"], pred.get("entity_group", pred.get("entity")))
            for pred in predictions
        ]

    languages = predict_lang(query)
    translation = translate_query(query, languages)
    classifications = classify_query(translation, categories)
    general_entities = extract_entities(models.ner, query)
    recipe_entities = extract_entities(models.recipe, translation)
    return languages, classifications, translation, general_entities, recipe_entities
def main():
    """Read the configuration, load every model and launch the Gradio demo."""
    cfg: AppConfig = AppConfig.from_environ()

    def load_translation_models(
        sources: Sequence[str],
        target: str,
        models: Sequence[str],
    ) -> Mapping[str, Pipeline]:
        """Build one translation pipeline per source language.

        ``sources`` and ``models`` are paired by position; the result maps
        each source language code to its pipeline.
        """
        return {
            src: pipeline(f"translation_{src}_to_{target}", model_name)
            for src, model_name in zip(sources, models)
        }

    def extract_commas_separated_values(value: str) -> Sequence[str]:
        """Split on commas, trimming whitespace and dropping empty items."""
        trimmed = (item.strip() for item in value.split(","))
        return tuple(item for item in trimmed if item)

    sources = extract_commas_separated_values(cfg.translation.sources)
    target = cfg.translation.target
    # Derive the model names from the configured languages so that sources
    # and models can never be paired in the wrong order. With the default
    # configuration ("de,fr" -> "en") these are the same two models as the
    # previously hard-coded list.
    translation_models = [f"Helsinki-NLP/opus-mt-{src}-{target}" for src in sources]

    models = Models(
        identification=Predictor(
            load_fn=lambda: fasttext.load_model(cfg.identification.model),
            predict_fn=lambda model, query, k: model.predict(query, k=k),
        ),
        translation=Predictor(
            load_fn=functools.partial(
                load_translation_models,
                sources=sources,
                target=target,
                models=translation_models,
            ),
            predict_fn=lambda models, query, src: models[src](query),
        ),
        classification=Predictor(
            load_fn=lambda: pipeline(
                "zero-shot-classification",
                model=cfg.classification.model,
            ),
            predict_fn=lambda model, query, categories: model(query, categories),
        ),
        ner=Predictor(
            load_fn=lambda: pipeline(
                "ner",
                model=cfg.ner.general,
                aggregation_strategy=AggregationStrategy.SIMPLE,
            ),
        ),
        recipe=Predictor(
            load_fn=lambda: pipeline("ner", model=cfg.ner.recipe),
        ),
    )
    iface = gr.Interface(
        # Only languages with a loaded translation pipeline are declared as
        # supported, so the translation lookup cannot miss.
        fn=lambda query, categories: predict(
            models,
            query.strip(),
            extract_commas_separated_values(categories),
            supported_languages=sources,
        ),
        examples=[["gateau au chocolat paris"], ["Newyork LA flight"]],
        inputs=[
            gr.inputs.Textbox(label="Query"),
            gr.inputs.Textbox(
                label="categories (commas separated and in english)",
                default="cooking and recipe,traveling,location,information,buy or sell",
            ),
        ],
        outputs=[
            gr.outputs.Label(
                num_top_classes=cfg.identification.max_results,
                type="auto",
                label="Language identification",
            ),
            gr.outputs.Label(
                num_top_classes=cfg.classification.max_results,
                type="auto",
                label="Predicted categories",
            ),
            gr.outputs.Textbox(
                label="English query",
                type="auto",
            ),
            gr.outputs.HighlightedText(label="NER generic"),
            gr.outputs.HighlightedText(label="NER Recipes"),
        ],
        interpretation="default",
    )
    iface.launch(debug=True)
# Start the demo only when the file is executed as a script.
if __name__ == "__main__":
    main()