Hicham Randrianarivo commited on
Commit
3b2a392
·
1 Parent(s): 8858fdf
Files changed (7) hide show
  1. .gitignore +2 -0
  2. app.py +290 -0
  3. flagged/log.csv +3 -0
  4. models/lid.176.ftz +3 -0
  5. requirements.in +13 -0
  6. requirements.txt +230 -0
  7. tox.ini +4 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .envrc
2
+ .pytype/
app.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Demo gradio app for some text/query augmentation."""
2
+
3
+ from __future__ import annotations
4
+ from collections import defaultdict
5
+ import functools
6
+ from itertools import chain
7
+ from typing import Any, Callable, Mapping, Optional, Sequence, Tuple
8
+
9
+ import attr
10
+ import environ
11
+ import fasttext # not working with python3.9
12
+ import gradio as gr
13
+ from transformers.pipelines import pipeline
14
+ from transformers.pipelines.base import Pipeline
15
+ from transformers.pipelines.token_classification import AggregationStrategy
16
+ from tokenizers.pre_tokenizers import Whitespace
17
+
18
+
19
+ def compose(*functions) -> Callable:
20
+ """
21
+ Compose functions.
22
+
23
+ Args:
24
+ functions: functions to compose.
25
+ Returns:
26
+ Composed functions.
27
+ """
28
+
29
+ def apply(f, g):
30
+ return lambda x: f(g(x))
31
+
32
+ return functools.reduce(apply, functions[::-1], lambda x: x)
33
+
34
+
35
+ def mapped(fn) -> Callable:
36
+ """
37
+ Decorator to apply map/filter to a function
38
+ """
39
+
40
+ def inner(func):
41
+ partial_fn = functools.partial(fn, func)
42
+
43
+ @functools.wraps(func)
44
+ def wrapper(*args, **kwargs):
45
+ return partial_fn(*args, **kwargs)
46
+
47
+ return wrapper
48
+
49
+ return inner
50
+
51
+
52
+ @attr.frozen
53
+ class Prediction:
54
+ """Dataclass to store prediction results."""
55
+
56
+ label: str
57
+ score: float
58
+
59
+
60
+ @attr.frozen
61
+ class Models:
62
+ identification: Predictor
63
+ translation: Predictor
64
+ classification: Predictor
65
+ ner: Predictor
66
+ recipe: Predictor
67
+
68
+
69
+ @attr.frozen
70
+ class Predictor:
71
+ load_fn: Callable
72
+ predict_fn: Callable = attr.field(default=lambda model, query: model(query))
73
+ model: Any = attr.field(init=False)
74
+
75
+ def __attrs_post_init__(self):
76
+ object.__setattr__(self, "model", self.load_fn())
77
+
78
+ def __call__(self, *args: Any, **kwds: Any) -> Any:
79
+ return self.predict_fn(self.model, *args, **kwds)
80
+
81
+
82
+ @environ.config(prefix="QUERY_INTERPRETATION")
83
+ class AppConfig:
84
+ @environ.config
85
+ class Identification:
86
+ """Identification model configuration."""
87
+
88
+ model = environ.var(default="./models/lid.176.ftz")
89
+ max_results = environ.var(default=3, converter=int)
90
+
91
+ @environ.config
92
+ class Translation:
93
+ """Translation models configuration."""
94
+
95
+ model = environ.var(default="t5-small")
96
+ sources = environ.var(default="de,fr")
97
+ target = environ.var(default="en")
98
+
99
+ @environ.config
100
+ class Classification:
101
+ """Classification model configuration."""
102
+
103
+ model = environ.var(default="typeform/distilbert-base-uncased-mnli")
104
+ max_results = environ.var(default=5, converter=int)
105
+
106
+ @environ.config
107
+ class NER:
108
+ general = environ.var(default="Davlan/xlm-roberta-base-ner-hrl")
109
+ recipe = environ.var(default="adamlin/recipe-tag-model")
110
+
111
+ identification: Identification = environ.group(Identification)
112
+ translation: Translation = environ.group(Translation)
113
+ classification: Classification = environ.group(Classification)
114
+ ner: NER = environ.group(NER)
115
+
116
+
117
+ def predict(
118
+ models: Models,
119
+ query: str,
120
+ categories: Sequence[str],
121
+ supported_languages: Tuple[str, ...] = ("fr", "de"),
122
+ ) -> Tuple[
123
+ Mapping[str, float],
124
+ str,
125
+ Mapping[str, float],
126
+ Sequence[Tuple[str, Optional[str]]],
127
+ Sequence[Tuple[str, Optional[str]]],
128
+ ]:
129
+ """Predict from a textual query:
130
+ - the language
131
+ - classify as a recipe or not
132
+ - extract the recipe
133
+ """
134
+
135
+ def predict_lang(query) -> Mapping[str, float]:
136
+ def predict_fn(query) -> Sequence[Prediction]:
137
+ return tuple(
138
+ Prediction(label=label, score=score)
139
+ for label, score in zip(*models.identification(query, k=176))
140
+ )
141
+
142
+ @mapped(map)
143
+ def format_label(prediction: Prediction) -> Prediction:
144
+ return attr.evolve(
145
+ prediction, label=prediction.label.replace("__label__", "")
146
+ )
147
+
148
+ def filter_labels(prediction: Prediction) -> bool:
149
+ return prediction.label in supported_languages + ("en",)
150
+
151
+ def format_output(predictions: Sequence[Prediction]) -> dict:
152
+ return {pred.label: pred.score for pred in predictions}
153
+
154
+ apply_fn = compose(
155
+ predict_fn,
156
+ format_label,
157
+ functools.partial(filter, filter_labels),
158
+ format_output,
159
+ )
160
+ return apply_fn(query)
161
+
162
+ def translate_query(query: str, languages: Mapping[str, float]) -> str:
163
+ def predicted_language() -> str:
164
+ return max(languages.items(), key=lambda lang: lang[1])[0]
165
+
166
+ def translate(query):
167
+ lang = predicted_language()
168
+ if lang in supported_languages:
169
+ output = models.translation(query, lang)[0]["translation_text"]
170
+ else:
171
+ output = query
172
+
173
+ return output
174
+
175
+ return translate(query)
176
+
177
+ def classify_query(query, categories) -> Mapping[str, float]:
178
+ predictions = models.classification(query, categories)
179
+ return dict(zip(predictions["labels"], predictions["scores"]))
180
+
181
+ def extract_entities(
182
+ predict_fn: Callable, query: str
183
+ ) -> Sequence[Tuple[str, Optional[str]]]:
184
+ def get_entity(pred: Mapping[str, str]):
185
+ return pred.get("entity", pred.get("entity_group", None))
186
+
187
+ mapping = defaultdict(lambda: None)
188
+ mapping.update(**{pred["word"]: get_entity(pred) for pred in predict_fn(query)})
189
+
190
+ query_processed = Whitespace().pre_tokenize_str(query)
191
+ res = tuple(
192
+ chain.from_iterable(
193
+ ((word, mapping[word]), (" ", None)) for word, _ in query_processed
194
+ )
195
+ )
196
+ return res
197
+
198
+ languages = predict_lang(query)
199
+ translation = translate_query(query, languages)
200
+ classifications = classify_query(translation, categories)
201
+ general_entities = extract_entities(models.ner, query)
202
+ recipe_entities = extract_entities(models.recipe, translation)
203
+ return languages, translation, classifications, general_entities, recipe_entities
204
+
205
+
206
+ def main():
207
+ cfg: AppConfig = AppConfig.from_environ()
208
+
209
+ def load_translation_models(
210
+ sources: Sequence[str], target: str, models: Sequence[str]
211
+ ) -> Pipeline:
212
+ result = {
213
+ src: pipeline(f"translation_{src}_to_{target}", models)
214
+ for src, models in zip(sources, models)
215
+ }
216
+ return result
217
+
218
+ def extract_commas_separated_values(value: str) -> Sequence[str]:
219
+ return tuple(filter(None, value.split(",")))
220
+
221
+ models = Models(
222
+ identification=Predictor(
223
+ load_fn=lambda: fasttext.load_model(cfg.identification.model),
224
+ predict_fn=lambda model, query, k: model.predict(query, k=k),
225
+ ),
226
+ translation=Predictor(
227
+ load_fn=functools.partial(
228
+ load_translation_models,
229
+ sources=extract_commas_separated_values(cfg.translation.sources),
230
+ target=cfg.translation.target,
231
+ models=["Helsinki-NLP/opus-mt-de-en", "Helsinki-NLP/opus-mt-fr-en"],
232
+ ),
233
+ predict_fn=lambda models, query, src: models[src](query),
234
+ ),
235
+ classification=Predictor(
236
+ load_fn=lambda: pipeline(
237
+ "zero-shot-classification", model=cfg.classification.model
238
+ ),
239
+ predict_fn=lambda model, query, categories: model(query, categories),
240
+ ),
241
+ ner=Predictor(
242
+ load_fn=lambda: pipeline(
243
+ "ner",
244
+ model=cfg.ner.general,
245
+ aggregation_strategy=AggregationStrategy.SIMPLE,
246
+ ),
247
+ ),
248
+ recipe=Predictor(
249
+ load_fn=lambda: pipeline("ner", model=cfg.ner.recipe),
250
+ ),
251
+ )
252
+
253
+ iface = gr.Interface(
254
+ fn=lambda query, categories: predict(
255
+ models, query.strip(), extract_commas_separated_values(categories)
256
+ ),
257
+ examples=[["gateau au chocolat paris"], ["Newyork LA flight"]],
258
+ inputs=[
259
+ gr.inputs.Textbox(label="Query"),
260
+ gr.inputs.Textbox(
261
+ label="categories (commas separated and in english)",
262
+ default="cooking and recipe,traveling,location,information,buy or sell",
263
+ ),
264
+ ],
265
+ outputs=[
266
+ gr.outputs.Label(
267
+ num_top_classes=cfg.identification.max_results,
268
+ type="auto",
269
+ label="Language identification",
270
+ ),
271
+ gr.outputs.Textbox(
272
+ label="English query",
273
+ type="auto",
274
+ ),
275
+ gr.outputs.Label(
276
+ num_top_classes=cfg.classification.max_results,
277
+ type="auto",
278
+ label="Predicted categories",
279
+ ),
280
+ gr.outputs.HighlightedText(label="NER generic"),
281
+ gr.outputs.HighlightedText(label="NER Recipes"),
282
+ ],
283
+ interpretation="default",
284
+ )
285
+
286
+ iface.launch()
287
+
288
+
289
+ if __name__ == "__main__":
290
+ main()
flagged/log.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ name,Output,timestamp
2
+ ,Hello ,2021-12-26 11:28:41.922022
3
+ ,Hello ,2021-12-26 11:28:43.161869
models/lid.176.ftz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f3472cfe8738a7b6099e8e999c3cbfae0dcd15696aac7d7738a8039db603e83
3
+ size 938013
requirements.in ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -f https://download.pytorch.org/whl/cpu/torch_stable.html
2
+ gradio
3
+ transformers
4
+ fasttext
5
+ huggingface_hub
6
+ requests
7
+ datasets
8
+ tokenizers
9
+ torch
10
+ environ-config
11
+ sentencepiece
12
+ rich
13
+ protobuf
requirements.txt ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # This file is autogenerated by pip-compile with python 3.8
3
+ # To update, run:
4
+ #
5
+ # pip-compile requirements.in
6
+ #
7
+ --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
8
+
9
+ aiohttp==3.8.1
10
+ # via
11
+ # datasets
12
+ # fsspec
13
+ aiosignal==1.2.0
14
+ # via aiohttp
15
+ analytics-python==1.4.0
16
+ # via gradio
17
+ async-timeout==4.0.2
18
+ # via aiohttp
19
+ attrs==21.4.0
20
+ # via
21
+ # aiohttp
22
+ # environ-config
23
+ backoff==1.10.0
24
+ # via analytics-python
25
+ bcrypt==3.2.0
26
+ # via paramiko
27
+ certifi==2021.10.8
28
+ # via requests
29
+ cffi==1.15.0
30
+ # via
31
+ # bcrypt
32
+ # cryptography
33
+ # pynacl
34
+ charset-normalizer==2.0.9
35
+ # via
36
+ # aiohttp
37
+ # requests
38
+ click==8.0.3
39
+ # via
40
+ # flask
41
+ # sacremoses
42
+ colorama==0.4.4
43
+ # via rich
44
+ commonmark==0.9.1
45
+ # via rich
46
+ cryptography==36.0.1
47
+ # via paramiko
48
+ cycler==0.11.0
49
+ # via matplotlib
50
+ datasets==1.17.0
51
+ # via -r requirements.in
52
+ dill==0.3.4
53
+ # via
54
+ # datasets
55
+ # multiprocess
56
+ environ-config==21.2.0
57
+ # via -r requirements.in
58
+ fasttext==0.9.2
59
+ # via -r requirements.in
60
+ ffmpy==0.3.0
61
+ # via gradio
62
+ filelock==3.4.2
63
+ # via
64
+ # huggingface-hub
65
+ # transformers
66
+ flask==2.0.2
67
+ # via
68
+ # flask-cachebuster
69
+ # flask-cors
70
+ # flask-login
71
+ # gradio
72
+ flask-cachebuster==1.0.0
73
+ # via gradio
74
+ flask-cors==3.0.10
75
+ # via gradio
76
+ flask-login==0.5.0
77
+ # via gradio
78
+ fonttools==4.28.5
79
+ # via matplotlib
80
+ frozenlist==1.2.0
81
+ # via
82
+ # aiohttp
83
+ # aiosignal
84
+ fsspec[http]==2021.11.1
85
+ # via datasets
86
+ gradio==2.6.3
87
+ # via -r requirements.in
88
+ huggingface-hub==0.2.1
89
+ # via
90
+ # -r requirements.in
91
+ # datasets
92
+ # transformers
93
+ idna==3.3
94
+ # via
95
+ # requests
96
+ # yarl
97
+ itsdangerous==2.0.1
98
+ # via flask
99
+ jinja2==3.0.3
100
+ # via flask
101
+ joblib==1.1.0
102
+ # via sacremoses
103
+ kiwisolver==1.3.2
104
+ # via matplotlib
105
+ markdown2==2.4.2
106
+ # via gradio
107
+ markupsafe==2.0.1
108
+ # via jinja2
109
+ matplotlib==3.5.1
110
+ # via gradio
111
+ monotonic==1.6
112
+ # via analytics-python
113
+ multidict==5.2.0
114
+ # via
115
+ # aiohttp
116
+ # yarl
117
+ multiprocess==0.70.12.2
118
+ # via datasets
119
+ numpy==1.21.5
120
+ # via
121
+ # datasets
122
+ # fasttext
123
+ # gradio
124
+ # matplotlib
125
+ # pandas
126
+ # pyarrow
127
+ # transformers
128
+ packaging==21.3
129
+ # via
130
+ # datasets
131
+ # huggingface-hub
132
+ # matplotlib
133
+ # transformers
134
+ pandas==1.3.5
135
+ # via
136
+ # datasets
137
+ # gradio
138
+ paramiko==2.9.1
139
+ # via gradio
140
+ pillow==8.4.0
141
+ # via
142
+ # gradio
143
+ # matplotlib
144
+ protobuf==3.19.1
145
+ # via -r requirements.in
146
+ pyarrow==6.0.1
147
+ # via datasets
148
+ pybind11==2.9.0
149
+ # via fasttext
150
+ pycparser==2.21
151
+ # via cffi
152
+ pycryptodome==3.12.0
153
+ # via gradio
154
+ pydub==0.25.1
155
+ # via gradio
156
+ pygments==2.10.0
157
+ # via rich
158
+ pynacl==1.4.0
159
+ # via paramiko
160
+ pyparsing==3.0.6
161
+ # via
162
+ # matplotlib
163
+ # packaging
164
+ python-dateutil==2.8.2
165
+ # via
166
+ # analytics-python
167
+ # matplotlib
168
+ # pandas
169
+ pytz==2021.3
170
+ # via pandas
171
+ pyyaml==6.0
172
+ # via
173
+ # huggingface-hub
174
+ # transformers
175
+ regex==2021.11.10
176
+ # via
177
+ # sacremoses
178
+ # transformers
179
+ requests==2.26.0
180
+ # via
181
+ # -r requirements.in
182
+ # analytics-python
183
+ # datasets
184
+ # fsspec
185
+ # gradio
186
+ # huggingface-hub
187
+ # transformers
188
+ rich==10.16.1
189
+ # via -r requirements.in
190
+ sacremoses==0.0.46
191
+ # via transformers
192
+ sentencepiece==0.1.96
193
+ # via -r requirements.in
194
+ six==1.16.0
195
+ # via
196
+ # analytics-python
197
+ # bcrypt
198
+ # flask-cors
199
+ # pynacl
200
+ # python-dateutil
201
+ # sacremoses
202
+ tokenizers==0.10.3
203
+ # via
204
+ # -r requirements.in
205
+ # transformers
206
+ torch==1.10.1+cpu
207
+ # via -r requirements.in
208
+ tqdm==4.62.3
209
+ # via
210
+ # datasets
211
+ # huggingface-hub
212
+ # sacremoses
213
+ # transformers
214
+ transformers==4.15.0
215
+ # via -r requirements.in
216
+ typing-extensions==4.0.1
217
+ # via
218
+ # huggingface-hub
219
+ # torch
220
+ urllib3==1.26.7
221
+ # via requests
222
+ werkzeug==2.0.2
223
+ # via flask
224
+ xxhash==2.0.2
225
+ # via datasets
226
+ yarl==1.7.2
227
+ # via aiohttp
228
+
229
+ # The following packages are considered to be unsafe in a requirements file:
230
+ # setuptools
tox.ini ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [flake8]
2
+ docstring-convention=google
3
+ max-line-length = 88
4
+ extend-ignore = E203