Spaces:
Runtime error
Runtime error
File size: 8,813 Bytes
3b2a392 54ab090 3b2a392 54ab090 3b2a392 ead4e9e 54ab090 ead4e9e 3b2a392 54ab090 3b2a392 3ad6577 54ab090 3b2a392 54ab090 3b2a392 54ab090 3b2a392 3ad6577 3b2a392 3ad6577 3b2a392 3ad6577 3b2a392 54ab090 3b2a392 54ab090 3b2a392 3ad6577 3b2a392 54ab090 3b2a392 3ad6577 3b2a392 7bae3ea 3b2a392 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 |
"""Demo gradio app for some text/query augmentation."""
from __future__ import annotations
import functools
from typing import Any
from typing import Callable
from typing import Mapping
from typing import Sequence
import attr
import environ
import fasttext # not working with python3.9
import gradio as gr
from transformers.pipelines import pipeline
from transformers.pipelines.base import Pipeline
from transformers.pipelines.token_classification import AggregationStrategy
def compose(*functions) -> Callable:
    """
    Chain functions left to right into a single one-argument callable.

    ``compose(f, g, h)(x)`` evaluates ``h(g(f(x)))``: the first function
    receives the input and every later function receives the previous
    result. Called with no functions, the result is the identity.

    Args:
        functions: one-argument callables, in application order.
    Returns:
        The chained callable.
    """
    def chained(x):
        # Fold the value through each function in the order given.
        return functools.reduce(lambda acc, fn: fn(acc), functions, x)

    return chained
def mapped(fn) -> Callable:
    """
    Build a decorator that lifts a per-item function through ``fn``.

    ``mapped(map)(f)`` behaves like ``functools.partial(map, f)``, so
    calling it with an iterable maps ``f`` over it. The result keeps
    ``f``'s name and docstring through ``functools.wraps``.

    Args:
        fn: higher-order function that takes the decorated function as its
            first argument, e.g. ``map`` or ``filter``.
    Returns:
        The decorator.
    """
    def decorator(func):
        bound = functools.partial(fn, func)

        def lifted(*args, **kwargs):
            return bound(*args, **kwargs)

        return functools.wraps(func)(lifted)

    return decorator
@attr.frozen
class Prediction:
    """Immutable (label, score) pair; holds language-identification results."""

    # Label as emitted by the model. fastText labels carry a "__label__"
    # prefix until it is stripped downstream.
    label: str = attr.field()
    # Score reported by the model for ``label``.
    score: float = attr.field()
@attr.frozen
class Models:
    """Immutable bundle of every predictor the app uses, passed to ``predict``."""

    # Language identification of the raw query.
    identification: Predictor = attr.field()
    # Translation of the query into the target language.
    translation: Predictor = attr.field()
    # Zero-shot classification of the translated query.
    classification: Predictor = attr.field()
    # General-purpose named-entity recognition.
    ner: Predictor = attr.field()
    # Recipe-specific entity tagging.
    recipe: Predictor = attr.field()
@attr.frozen
class Predictor:
    """Callable wrapper around a model that is loaded eagerly on construction.

    Calling the instance forwards to ``predict_fn(model, *args, **kwds)``.
    """

    # Zero-argument factory returning the loaded model.
    load_fn: Callable
    # Invoked as predict_fn(model, *args, **kwds). The default forwards a
    # single query straight to the model.
    predict_fn: Callable = attr.field(default=lambda model, query: model(query))
    # Filled from load_fn() after __init__; not a constructor argument.
    model: Any = attr.field(init=False)

    def __attrs_post_init__(self):
        loaded = self.load_fn()
        # The class is frozen, so bypass the attrs setattr guard.
        object.__setattr__(self, "model", loaded)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        run = self.predict_fn
        return run(self.model, *args, **kwds)
@environ.config(prefix="QUERY_INTERPRETATION")
class AppConfig:
    """Application settings, loaded with ``AppConfig.from_environ()``.

    Environment variables are read under the ``QUERY_INTERPRETATION``
    prefix. Each nested group presumably adds its own name segment
    (e.g. ``QUERY_INTERPRETATION_IDENTIFICATION_MODEL``) — TODO confirm
    against environ-config's naming rules.
    """

    @environ.config
    class Identification:
        """Identification model configuration."""
        # Path of the model file handed to fasttext.load_model().
        model = environ.var(default="./models/lid.176.ftz")
        # How many languages the UI label shows; converted with int().
        max_results = environ.var(default=3, converter=int)

    @environ.config
    class Translation:
        """Translation models configuration."""
        # NOTE(review): not read anywhere in this file; translation models
        # are selected in main() instead — confirm this setting is still needed.
        model = environ.var(default="t5-small")
        # Comma-separated source language codes to translate from.
        sources = environ.var(default="de,fr")
        # Language code queries are translated into.
        target = environ.var(default="en")

    @environ.config
    class Classification:
        """Classification model configuration."""
        # Model for the zero-shot-classification pipeline.
        model = environ.var(default="typeform/distilbert-base-uncased-mnli")
        # How many categories the UI label shows; converted with int().
        max_results = environ.var(default=5, converter=int)

    @environ.config
    class NER:
        """Named-entity recognition models configuration."""
        # Model for general-purpose NER.
        general = environ.var(
            default="asahi417/tner-xlm-roberta-large-uncased-wnut2017",
        )
        # Model for recipe-specific entity tagging.
        recipe = environ.var(default="adamlin/recipe-tag-model")

    # Nested setting groups.
    identification: Identification = environ.group(Identification)
    translation: Translation = environ.group(Translation)
    classification: Classification = environ.group(Classification)
    ner: NER = environ.group(NER)
def predict(
    models: Models,
    query: str,
    categories: Sequence[str],
    supported_languages: tuple[str, ...] = ("fr", "de"),
) -> tuple[
    Mapping[str, float],
    Mapping[str, float],
    str,
    Sequence[tuple[str, str | None]],
    Sequence[tuple[str, str | None]],
]:
    """Interpret a textual query.

    Steps:
    - identify the query language, keeping only ``supported_languages`` and
      English;
    - translate the query when its most likely language is one of
      ``supported_languages``; otherwise keep it unchanged;
    - zero-shot classify the (translated) query against ``categories``;
    - extract general entities from the original query and recipe entities
      from the translated query.

    Args:
        models: loaded predictors.
        query: user query.
        categories: candidate labels for zero-shot classification. When it is
            empty, classification is skipped and an empty mapping is returned.
        supported_languages: language codes with a translation model in
            ``models.translation``. Any iterable of codes is accepted.

    Returns:
        ``(languages, classifications, translation, general_entities,
        recipe_entities)``.
    """
    # Built once rather than on every filter call. "en" is kept so an English
    # query can win the language vote and skip translation.
    translatable = frozenset(supported_languages)
    candidate_languages = translatable | {"en"}

    def predict_lang(query) -> Mapping[str, float]:
        def predict_fn(query) -> Sequence[Prediction]:
            # k=176 presumably requests every label of the lid.176 model.
            return tuple(
                Prediction(label=label, score=score)
                for label, score in zip(*models.identification(query, k=176))
            )

        @mapped(map)
        def format_label(prediction: Prediction) -> Prediction:
            # fastText prefixes every label with "__label__".
            return attr.evolve(
                prediction,
                label=prediction.label.replace("__label__", ""),
            )

        def filter_labels(prediction: Prediction) -> bool:
            return prediction.label in candidate_languages

        def format_output(predictions: Sequence[Prediction]) -> dict:
            return {pred.label: pred.score for pred in predictions}

        apply_fn = compose(
            predict_fn,
            format_label,
            functools.partial(filter, filter_labels),
            format_output,
        )
        return apply_fn(query)

    def translate_query(query: str, languages: Mapping[str, float]) -> str:
        # No candidate language survived the filter, so there is nothing to
        # base a translation on. Keep the query rather than failing in max().
        if not languages:
            return query
        lang = max(languages, key=languages.get)
        if lang not in translatable:
            return query
        return models.translation(query, lang)[0]["translation_text"]

    def classify_query(query, categories) -> Mapping[str, float]:
        # The zero-shot pipeline raises on an empty label list. Treat that
        # case as "no classification requested".
        if not categories:
            return {}
        predictions = models.classification(query, categories)
        return dict(zip(predictions["labels"], predictions["scores"]))

    def extract_entities(
        predict_fn: Callable,
        query: str,
    ) -> Sequence[tuple[str, str | None]]:
        predictions = predict_fn(query)
        if not predictions:
            # Nothing found: return the whole query with no label.
            return [(query, None)]
        # Aggregated pipelines report "entity_group"; token-level ones report
        # "entity".
        return [
            (pred["word"], pred.get("entity_group", pred.get("entity")))
            for pred in predictions
        ]

    languages = predict_lang(query)
    translation = translate_query(query, languages)
    classifications = classify_query(translation, categories)
    general_entities = extract_entities(models.ner, query)
    recipe_entities = extract_entities(models.recipe, translation)
    return languages, classifications, translation, general_entities, recipe_entities
def main():
    """Load every model from the environment config and launch the Gradio demo.

    Reads ``AppConfig`` from the environment and builds one ``Predictor`` per
    model (each loads its model on construction). It then serves ``predict``
    through a ``gr.Interface``.
    """
    cfg: AppConfig = AppConfig.from_environ()

    # Translation model id for each (src, target) pair. Deriving it from the
    # configured codes keeps every pipeline paired with its language key.
    # zip-ing the sources against a fixed model list would silently mis-pair
    # or drop languages whenever the sources are reordered or extended.
    # With the defaults this resolves to Helsinki-NLP/opus-mt-de-en and
    # Helsinki-NLP/opus-mt-fr-en.
    translation_model_template = "Helsinki-NLP/opus-mt-{src}-{target}"
    default_categories = (
        "cooking and recipe,traveling,location,information,buy or sell"
    )

    def load_translation_models(
        sources: Sequence[str],
        target: str,
    ) -> Mapping[str, Pipeline]:
        """Load one translation pipeline per source language, keyed by code."""
        return {
            src: pipeline(
                f"translation_{src}_to_{target}",
                translation_model_template.format(src=src, target=target),
            )
            for src in sources
        }

    def extract_commas_separated_values(value: str) -> Sequence[str]:
        """Split ``value`` on commas, trimming whitespace and dropping empties."""
        return tuple(filter(None, (item.strip() for item in value.split(","))))

    translation_sources = extract_commas_separated_values(cfg.translation.sources)

    models = Models(
        identification=Predictor(
            load_fn=lambda: fasttext.load_model(cfg.identification.model),
            predict_fn=lambda model, query, k: model.predict(query, k=k),
        ),
        translation=Predictor(
            load_fn=functools.partial(
                load_translation_models,
                sources=translation_sources,
                target=cfg.translation.target,
            ),
            predict_fn=lambda translators, query, src: translators[src](query),
        ),
        classification=Predictor(
            load_fn=lambda: pipeline(
                "zero-shot-classification",
                model=cfg.classification.model,
            ),
            predict_fn=lambda model, query, categories: model(query, categories),
        ),
        ner=Predictor(
            load_fn=lambda: pipeline(
                "ner",
                model=cfg.ner.general,
                aggregation_strategy=AggregationStrategy.SIMPLE,
            ),
        ),
        recipe=Predictor(
            load_fn=lambda: pipeline("ner", model=cfg.ner.recipe),
        ),
    )

    # NOTE(review): gr.inputs / gr.outputs, ``default=``, ``type="auto"`` and
    # ``interpretation=`` belong to the legacy Gradio API (deprecated in 3.x,
    # removed in 4.x). They are kept as-is; confirm the pinned gradio version.
    iface = gr.Interface(
        fn=lambda query, categories: predict(
            models,
            query.strip(),
            extract_commas_separated_values(categories),
            # Translate exactly the languages a model was loaded for, so a
            # detected language never indexes a missing translator.
            supported_languages=translation_sources,
        ),
        # One value per input component: (query, categories).
        examples=[
            ["gateau au chocolat paris", default_categories],
            ["Newyork LA flight", default_categories],
        ],
        inputs=[
            gr.inputs.Textbox(label="Query"),
            gr.inputs.Textbox(
                label="categories (commas separated and in english)",
                default=default_categories,
            ),
        ],
        outputs=[
            gr.outputs.Label(
                num_top_classes=cfg.identification.max_results,
                type="auto",
                label="Language identification",
            ),
            gr.outputs.Label(
                num_top_classes=cfg.classification.max_results,
                type="auto",
                label="Predicted categories",
            ),
            gr.outputs.Textbox(
                label="English query",
                type="auto",
            ),
            gr.outputs.HighlightedText(label="NER generic"),
            gr.outputs.HighlightedText(label="NER Recipes"),
        ],
        interpretation="default",
    )
    iface.launch(debug=True)


if __name__ == "__main__":
    main()
|