nsethi610 commited on
Commit
f0a38b0
1 Parent(s): 7aa37cd

Create pipeline_utils.py

Browse files
Files changed (1) hide show
  1. pipeline_utils.py +51 -0
pipeline_utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from task import tasks_config
3
+ from transformers import pipeline
4
+
5
+
6
+ def review_training_choices(choice):
7
+ print(choice)
8
+ if choice == "Use Pipeline":
9
+ return gr.Row(visible=True)
10
+ else:
11
+ return gr.Row(visible=False)
12
+
13
+
14
+ def handle_task_change(task):
15
+ visibility = task == "question-answering"
16
+ models = tasks_config[task]["config"]["models"]
17
+ model_choices = [(model, model) for model in models]
18
+ return gr.update(visible=visibility), gr.Dropdown(
19
+ choices=model_choices,
20
+ label="Model",
21
+ allow_custom_value=True,
22
+ interactive=True
23
+ ), gr.Dropdown(info=tasks_config[task]["info"])
24
+
25
+
26
+ def test_pipeline(task, model=None, prompt=None, context=None):
27
+ # configure additional options for each model
28
+ options = {"ner": {"grouped_entities": True}, "question-answering": {},
29
+ "text-generation": {}, "fill-mask": {}, "summarization": {}}
30
+ # configure pipeline
31
+ test = pipeline(task, model=model, **
32
+ options[task]) if model else pipeline(task, **options[task])
33
+ # call pipeline
34
+ if task == "question-answering":
35
+ if not context:
36
+ return "Context is required"
37
+ else:
38
+ result = test(question=prompt, context=context)
39
+ else:
40
+ result = test(prompt)
41
+
42
+ # generated ouput based on task and return
43
+ output_mapping = {
44
+ "text-generation": lambda x: x[0]["generated_text"],
45
+ "fill-mask": lambda x: x[0]["sequence"],
46
+ "summarization": lambda x: x[0]["summary_text"],
47
+ "ner": lambda x: "\n".join(f"{k}={v}" for item in x for k, v in item.items() if k not in ["start", "end", "index"]).rstrip("\n"),
48
+ "question-answering": lambda x: x
49
+ }
50
+
51
+ return gr.TextArea(output_mapping[task](result))