faori committed
Commit e370330
1 Parent(s): 5cd3a6a

Upload folder using huggingface_hub

app1.py CHANGED
@@ -19,13 +19,30 @@ def load_model(config_path):
     return RetroReader.load(config_file=config_path)

 # Loading models
-model_base = load_model("configs/inference_en_electra_base.yaml")
-model_large = load_model("configs/inference_en_electra_large.yaml")
+model_electra_base = load_model("configs/inference_en_electra_base.yaml")
+model_electra_large = load_model("configs/inference_en_electra_large.yaml")
+model_roberta = load_model("configs/inference_en_roberta.yaml")
+model_distilbert = load_model("configs/inference_en_distilbert.yaml")

 def retro_reader_demo(query, context, model_choice):
-    model = model_base if model_choice == "Base" else model_large
+    # Choose the model based on the model_choice
+    if model_choice == "Electra Base":
+        model = model_electra_base
+    elif model_choice == "Electra Large":
+        model = model_electra_large
+    elif model_choice == "Roberta":
+        model = model_roberta
+    elif model_choice == "DistilBERT":
+        model = model_distilbert
+    else:
+        return "Invalid model choice"
+
+    # Generate outputs using the chosen model
     outputs = model(query=query, context=context, return_submodule_outputs=True)
+
+    # Extract the answer
     answer = outputs[0]["id-01"] if outputs[0]["id-01"] else "No answer found"
+
     return answer

 # Gradio app interface
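The diff stops at the `# Gradio app interface` comment, so the commit does not show how `model_choice` reaches `retro_reader_demo`. Below is a minimal sketch, not part of the commit, of how that interface could be wired: the dropdown labels and the `retro_reader_demo` function come from the diff above, while the widget layout, labels, and title are assumptions.

import gradio as gr

# Hypothetical wiring -- app1.py's actual Gradio interface is not shown in this diff.
demo = gr.Interface(
    fn=retro_reader_demo,
    inputs=[
        gr.Textbox(label="Question"),
        gr.Textbox(label="Context", lines=10),
        gr.Dropdown(
            choices=["Electra Base", "Electra Large", "Roberta", "DistilBERT"],
            value="Electra Base",
            label="Model",
        ),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Retro Reader Demo",
)

if __name__ == "__main__":
    demo.launch()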
configs/inference_en_distilbert.yaml ADDED
@@ -0,0 +1,43 @@
+RetroDataModelArguments:
+
+  # DataArguments
+  max_seq_length: 512
+  max_answer_length: 30
+  doc_stride: 128
+  return_token_type_ids: True
+  pad_to_max_length: True
+  preprocessing_num_workers: 5
+  overwrite_cache: False
+  version_2_with_negative: True
+  null_score_diff_threshold: 0.0
+  rear_threshold: 0.0
+  n_best_size: 20
+  use_choice_logits: False
+  start_n_top: -1
+  end_n_top: -1
+  beta1: 1
+  beta2: 1
+  best_cof: 1
+
+  # ModelArguments
+  use_auth_token: False
+
+  # SketchModelArguments
+  sketch_revision: en-distilbert-sketch
+  sketch_model_name: faori/retro_reeader
+  # sketch_model_mode: transfer
+  sketch_architectures: ElectraForSequenceClassification
+
+  # IntensiveModelArguments
+  intensive_revision: en-distilbert-intensive1
+  intensive_model_name: faori/retro_reeader
+  # intensive_model_mode: transfer
+  intensive_architectures: ElectraForQuestionAnsweringAVPool
+
+
+TrainingArguments:
+  output_dir: outputs
+  no_cuda: True  # If you want to use cuda,
+                 # change `no_cuda` to False and `fp16` to True
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 12
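As a quick sanity check for the new config, here is a minimal sketch (not part of the commit) that loads configs/inference_en_distilbert.yaml through the same RetroReader API app1.py uses; the import path and the example question/context are assumptions, while the load and call signatures are taken from the diff above.

from retro_reader import RetroReader  # import path assumed; app1.py only shows RetroReader.load

# Load the DistilBERT sketch/intensive pair described by the new config.
reader = RetroReader.load(config_file="configs/inference_en_distilbert.yaml")

# Same call pattern as retro_reader_demo in app1.py.
outputs = reader(
    query="Where is the Eiffel Tower?",              # placeholder inputs, not from the repo
    context="The Eiffel Tower is in Paris, France.",
    return_submodule_outputs=True,
)
answer = outputs[0]["id-01"] if outputs[0]["id-01"] else "No answer found"
print(answer)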
configs/inference_en_roberta.yaml ADDED
@@ -0,0 +1,43 @@
+RetroDataModelArguments:
+
+  # DataArguments
+  max_seq_length: 512
+  max_answer_length: 30
+  doc_stride: 128
+  return_token_type_ids: True
+  pad_to_max_length: True
+  preprocessing_num_workers: 5
+  overwrite_cache: False
+  version_2_with_negative: True
+  null_score_diff_threshold: 0.0
+  rear_threshold: 0.0
+  n_best_size: 20
+  use_choice_logits: False
+  start_n_top: -1
+  end_n_top: -1
+  beta1: 1
+  beta2: 1
+  best_cof: 1
+
+  # ModelArguments
+  use_auth_token: False
+
+  # SketchModelArguments
+  sketch_revision: en-roberta-sketch
+  sketch_model_name: faori/retro_reeader
+  # sketch_model_mode: transfer
+  sketch_architectures: ElectraForSequenceClassification
+
+  # IntensiveModelArguments
+  intensive_revision: en-roberta-intensive
+  intensive_model_name: faori/retro_reeader
+  # intensive_model_mode: transfer
+  intensive_architectures: ElectraForQuestionAnsweringAVPool
+
+
+TrainingArguments:
+  output_dir: outputs
+  no_cuda: True  # If you want to use cuda,
+                 # change `no_cuda` to False and `fp16` to True
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 12
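With four configs now in the repo, the if/elif chain in retro_reader_demo could also be collapsed into a small lookup table. The sketch below shows that alternative (not what the commit does); it reuses the load_model helper from app1.py and the dropdown labels from the diff, so only the registry itself is new.

# Map the UI labels from the diff to the config files added in this commit.
CONFIGS = {
    "Electra Base": "configs/inference_en_electra_base.yaml",
    "Electra Large": "configs/inference_en_electra_large.yaml",
    "Roberta": "configs/inference_en_roberta.yaml",
    "DistilBERT": "configs/inference_en_distilbert.yaml",
}

# load_model is the helper defined earlier in app1.py.
MODELS = {name: load_model(path) for name, path in CONFIGS.items()}

def retro_reader_demo(query, context, model_choice):
    model = MODELS.get(model_choice)
    if model is None:
        return "Invalid model choice"
    outputs = model(query=query, context=context, return_submodule_outputs=True)
    return outputs[0]["id-01"] if outputs[0]["id-01"] else "No answer found"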