inflaton committed
Commit 251efe4
Parent: 8a58456

initial working code

Files changed (7)
  1. .env.example +81 -0
  2. .gitignore +149 -0
  3. app.py +109 -25
  4. eval_modules/calc_repetitions_v2e.py +1333 -0
  5. eval_modules/utils.py +262 -0
  6. ms_macro.json +0 -0
  7. requirements.txt +20 -1
.env.example ADDED
@@ -0,0 +1,81 @@
1
+ LLM_MODEL_TYPE=huggingface
2
+ # LLM_MODEL_TYPE=openai
3
+ # LLM_MODEL_TYPE=hftgi
4
+ # LLM_MODEL_TYPE=ollama
5
+ # LLM_MODEL_TYPE=google
6
+ # LLM_MODEL_TYPE=vllm
7
+
8
+ HUGGINGFACE_AUTH_TOKEN=
9
+
10
+ HFTGI_SERVER_URL=
11
+
12
+ OPENAI_API_KEY=
13
+
14
+ GOOGLE_API_KEY=
15
+
16
+ # if unset, default to "gpt-3.5-turbo"
17
+ OPENAI_MODEL_NAME=
18
+
19
+ # GEMINI_MODEL_NAME=gemini-1.5-pro-latest
20
+
21
+ # OLLAMA_MODEL_NAME=orca2:7b
22
+ # OLLAMA_MODEL_NAME=mistral:7b
23
+ # OLLAMA_MODEL_NAME=gemma:7b
24
+ # OLLAMA_MODEL_NAME=llama2:7b
25
+ OLLAMA_MODEL_NAME=llama3:8b
26
+
27
+ OLLAMA_RP=1.15
28
+ HF_RP=1.15
29
+
30
+ LANGCHAIN_DEBUG=false
31
+ BATCH_SIZE=1
32
+ APPLY_CHAT_TEMPLATE_FOR_RAG=true
33
+
34
+ # cpu, mps or cuda:0 - if unset, use whatever device is detected
35
+ HF_EMBEDDINGS_DEVICE_TYPE=
36
+ HF_PIPELINE_DEVICE_TYPE=
37
+
38
+ # uncomment one of the below to load corresponding quantized model
39
+ # LOAD_QUANTIZED_MODEL=4bit
40
+ # LOAD_QUANTIZED_MODEL=8bit
41
+
42
+ QA_WITH_RAG=true
43
+ # QA_WITH_RAG=false
44
+
45
+ RETRIEVER_TYPE=questions_file
46
+ # RETRIEVER_TYPE=vectorstore
47
+
48
+ QUESTIONS_FILE_PATH="./data/datasets/ms_macro.json"
49
+
50
+ DISABLE_MODEL_PRELOADING=true
51
+ CHAT_HISTORY_ENABLED=false
52
+ SHOW_PARAM_SETTINGS=false
53
+ SHARE_GRADIO_APP=false
54
+
55
+ # if unset, default to "hkunlp/instructor-xl"
56
+ HF_EMBEDDINGS_MODEL_NAME="hkunlp/instructor-large"
57
+
58
+ # number of cpu cores - used to set n_threads for GPT4ALL & LlamaCpp models
59
+ NUMBER_OF_CPU_CORES=
60
+
61
+ USING_TORCH_BFLOAT16=true
62
+
63
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-3b"
64
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-7b"
65
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-12b"
66
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="TheBloke/wizardLM-7B-HF"
67
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="TheBloke/vicuna-7B-1.1-HF"
68
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="nomic-ai/gpt4all-j"
69
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="nomic-ai/gpt4all-falcon"
70
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="lmsys/fastchat-t5-3b-v1.0"
71
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-7b-chat-hf"
72
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-13b-chat-hf"
73
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-70b-chat-hf"
74
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
75
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-70B-Instruct"
76
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Orca-2-7b"
77
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Orca-2-13b"
78
+ HUGGINGFACE_MODEL_NAME_OR_PATH="google/gemma-1.1-2b-it"
79
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="google/gemma-1.1-7b-it"
80
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Phi-3-mini-128k-instruct"
81
+ # HUGGINGFACE_MODEL_NAME_OR_PATH="mistralai/Mistral-7B-Instruct-v0.2"
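[Editor's note: a minimal sketch of how the settings above are typically consumed at application startup. It assumes python-dotenv is available (the requirements.txt changes are not shown in this view, so treat that import as an assumption); the variable names come from .env.example and the fallback values are illustrative only.]

import os

from dotenv import load_dotenv  # assumption: python-dotenv is a project dependency

# Load key=value pairs from a local .env file (copied from .env.example) into os.environ.
load_dotenv()

llm_model_type = os.getenv("LLM_MODEL_TYPE", "huggingface")
qa_with_rag = os.getenv("QA_WITH_RAG", "true").lower() == "true"
questions_file_path = os.getenv("QUESTIONS_FILE_PATH") or "./data/datasets/ms_macro.json"
hf_model = os.getenv("HUGGINGFACE_MODEL_NAME_OR_PATH", "google/gemma-1.1-2b-it")

print(f"model type: {llm_model_type}, RAG: {qa_with_rag}")
print(f"questions file: {questions_file_path}, HF model: {hf_model}")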
.gitignore ADDED
@@ -0,0 +1,149 @@
1
+ *.out
2
+ *.log
3
+ pdfs/
4
+ .vscode/
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ pip-wheel-metadata/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ # *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100
+ __pypackages__/
101
+
102
+ # Celery stuff
103
+ celerybeat-schedule
104
+ celerybeat.pid
105
+
106
+ # SageMath parsed files
107
+ *.sage.py
108
+
109
+ # Environments
110
+ .env
111
+ .venv
112
+ env/
113
+ venv/
114
+ ENV/
115
+ env.bak/
116
+ venv.bak/
117
+
118
+ # Spyder project settings
119
+ .spyderproject
120
+ .spyproject
121
+
122
+ # Rope project settings
123
+ .ropeproject
124
+
125
+ # mkdocs documentation
126
+ /site
127
+
128
+ # mypy
129
+ .mypy_cache/
130
+ .dmypy.json
131
+ dmypy.json
132
+
133
+ # Pyre type checker
134
+ .pyre/
135
+
136
+ # JetBrains
137
+ .idea
138
+
139
+ *.db
140
+
141
+ .DS_Store
142
+
143
+ vectorstore.pkl
144
+ langchain.readthedocs.io/
145
+
146
+ models/
147
+ data/logs/hftgi-2024-03-18.txt
148
+ qa_*_all_results.csv
149
+ qa_*_test_results.csv
app.py CHANGED
@@ -1,54 +1,141 @@
 
 
1
  import gradio as gr
2
  from huggingface_hub import InferenceClient
3
 
4
  """
5
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
  """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
8
 
9
 
10
- def respond(
11
  message,
12
  history: list[tuple[str, str]],
13
  system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
 
 
17
  ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
 
 
26
  messages.append({"role": "user", "content": message})
27
 
28
- response = ""
29
 
 
30
  for message in client.chat_completion(
31
  messages,
32
  max_tokens=max_tokens,
33
  stream=True,
34
  temperature=temperature,
 
 
35
  top_p=top_p,
 
36
  ):
37
- token = message.choices[0].delta.content
 
38
 
39
- response += token
40
- yield response
41
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
  demo = gr.ChatInterface(
47
- respond,
 
 
 
 
 
48
  additional_inputs=[
49
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
  gr.Slider(
53
  minimum=0.1,
54
  maximum=1.0,
@@ -58,7 +145,4 @@ demo = gr.ChatInterface(
58
  ),
59
  ],
60
  )
61
-
62
-
63
- if __name__ == "__main__":
64
- demo.launch()
 
1
+ import json
2
+ import os
3
  import gradio as gr
4
  from huggingface_hub import InferenceClient
5
+ from eval_modules.utils import calc_bleu_rouge_scores
6
+ from eval_modules.calc_repetitions_v2e import detect_repetitions
7
+
8
+ questions_file_path = os.getenv("QUESTIONS_FILE_PATH") or "./ms_macro.json"
9
+
10
+ questions = json.loads(open(questions_file_path).read())
11
+ examples = [[question["question"].strip()] for question in questions]
12
+ print(f"Loaded {len(examples)} examples")
13
+
14
+ qa_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
15
 
16
  """
17
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
18
  """
19
+ # client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
20
+ # client = InferenceClient("HuggingFaceH4/zephyr-7b-gemma-v0.1")
21
+ client = InferenceClient("microsoft/Phi-3.5-mini-instruct")
22
 
23
 
24
+ def chat(
25
  message,
26
  history: list[tuple[str, str]],
27
  system_message,
28
+ temperature=0,
29
+ frequency_penalty=0,
30
+ presence_penalty=0,
31
+ max_tokens=256,
32
+ top_p=0.95,
33
  ):
34
+ chat = []
35
+ for item in history:
36
+ chat.append({"role": "user", "content": item[0]})
37
+ if item[1] is not None:
38
+ chat.append({"role": "assistant", "content": item[1]})
39
+
40
+ index = -1
41
+ if [message] in examples:
42
+ index = examples.index([message])
43
+ message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
44
+ print("RAG prompt:", message)
45
 
46
+ chat.append({"role": "user", "content": message})
 
 
 
 
47
 
48
+ messages = [{"role": "system", "content": system_message}]
49
  messages.append({"role": "user", "content": message})
50
 
51
+ partial_text = ""
52
 
53
+ finish_reason = None
54
  for message in client.chat_completion(
55
  messages,
56
  max_tokens=max_tokens,
57
  stream=True,
58
  temperature=temperature,
59
+ frequency_penalty=None, # frequency_penalty,
60
+ presence_penalty=None, # presence_penalty,
61
  top_p=top_p,
62
+ seed=42,
63
  ):
64
+ finish_reason = message.choices[0].finish_reason
65
+ # print("finish_reason:", finish_reason)
66
 
67
+ if finish_reason is None:
68
+ new_text = message.choices[0].delta.content
69
+ partial_text += new_text
70
+ yield partial_text
71
+ else:
72
+ break
73
+
74
+ answer = partial_text
75
+ (whitespace_score, repetition_score, total_repetitions) = detect_repetitions(answer)
76
+ partial_text += "\n\nRepetition Metrics:\n"
77
+ partial_text += f"1. Whitespace Score: {whitespace_score:.3f}\n"
78
+ partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
79
+ partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
80
+ partial_text += (
81
+ f"1. Non-Repetitive Ratio: {1 - total_repetitions / len(answer):.3f}\n"
82
+ )
83
+
84
+ if index >= 0: # RAG
85
+ key = (
86
+ "wellFormedAnswers"
87
+ if "wellFormedAnswers" in questions[index]
88
+ else "answers"
89
+ )
90
+ scores = calc_bleu_rouge_scores([answer], [questions[index][key]], debug=True)
91
+
92
+ partial_text += "\n\n Performance Metrics:\n"
93
+ partial_text += f'1. BLEU-1: {scores["bleu_scores"]["bleu"]:.3f}\n'
94
+ partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
95
+
96
+ partial_text += f"\n\nGround truth: {questions[index][key][0]}\n"
97
+
98
+ partial_text += f"\n\nThe text generation has ended because: {finish_reason}\n"
99
+
100
+ yield partial_text
101
 
102
 
 
 
 
103
  demo = gr.ChatInterface(
104
+ fn=chat,
105
+ examples=examples,
106
+ cache_examples=False,
107
+ additional_inputs_accordion=gr.Accordion(
108
+ label="⚙️ Parameters", open=False, render=False
109
+ ),
110
  additional_inputs=[
111
  gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
112
+ gr.Slider(
113
+ minimum=0, maximum=2, step=0.1, value=0, label="Temperature", render=False
114
+ ),
115
+ gr.Slider(
116
+ minimum=-2,
117
+ maximum=2,
118
+ step=0.1,
119
+ value=0,
120
+ label="Frequency Penalty",
121
+ render=False,
122
+ ),
123
+ gr.Slider(
124
+ minimum=-2,
125
+ maximum=2,
126
+ step=0.1,
127
+ value=0,
128
+ label="Presence Penalty",
129
+ render=False,
130
+ ),
131
+ gr.Slider(
132
+ minimum=128,
133
+ maximum=4096,
134
+ step=1,
135
+ value=512,
136
+ label="Max new tokens",
137
+ render=False,
138
+ ),
139
  gr.Slider(
140
  minimum=0.1,
141
  maximum=1.0,
 
145
  ),
146
  ],
147
  )
148
+ demo.launch()
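[Editor's note: a short sketch, not part of the commit, of how the new app.py scores a completed answer. detect_repetitions is defined in eval_modules/calc_repetitions_v2e.py below; calc_bleu_rouge_scores lives in eval_modules/utils.py, whose body is truncated in this view, so its return structure is inferred from the way app.py uses it.]

from eval_modules.calc_repetitions_v2e import detect_repetitions
from eval_modules.utils import calc_bleu_rouge_scores

answer = "Paris is the capital of France. Paris is the capital of France."
ground_truth = ["Paris is the capital of France."]

# detect_repetitions returns (non-word-char repetition count, text repetition count, total)
whitespace_score, repetition_score, total_repetitions = detect_repetitions(answer)
non_repetitive_ratio = 1 - total_repetitions / len(answer)

# calc_bleu_rouge_scores takes lists of predictions and references, as in app.py
scores = calc_bleu_rouge_scores([answer], [ground_truth], debug=True)

print(f"repetitions: {total_repetitions}, non-repetitive ratio: {non_repetitive_ratio:.3f}")
print(f"BLEU-1: {scores['bleu_scores']['bleu']:.3f}, ROUGE-L: {scores['rouge_scores']['rougeL']:.3f}")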
 
 
 
eval_modules/calc_repetitions_v2e.py ADDED
@@ -0,0 +1,1333 @@
1
+ import os
2
+ import re
3
+ import math
4
+ import pandas as pd
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.ticker as mtick
8
+ import seaborn as sns
9
+ import nltk
10
+ import evaluate
11
+ import traceback
12
+
13
+ bert_score = evaluate.load("bertscore")
14
+ meteor = evaluate.load("meteor")
15
+
16
+ print(f"loading: {__file__}")
17
+
18
+ # pattern_non_word_char_repetition = re.compile(r"\s{5,}")
19
+ # pattern_text_repetitions = re.compile(r"(.{5}.*)\s*((\1)\s*)+", re.M | re.DOTALL)
20
+
21
+ # final version
22
+ pattern_non_word_char_repetition = re.compile(r"[\s\W]{5,}")
23
+ pattern_text_repetitions = re.compile(
24
+ r"(?P<repeat>.{5}.*?)(?:[\s\W]*(?P=repeat))+", re.M | re.DOTALL | re.IGNORECASE
25
+ )
26
+ # Explanation of the Regex Pattern:
27
+ # (?P<repeat>.{5}.*?): Captures any sequence of characters with minimal length of 5 and names this group repeat.
28
+ # .*?: Matches zero or more characters, non-greedily (as few as possible).
29
+ # (?:[\s\W]*(?P=repeat))+: A non-capturing group that matches one or more repetitions of:
30
+ # [\s\W]*: Zero or more whitespace or non-word characters (spaces, punctuation, etc.).
31
+ # (?P=repeat): A backreference to the named group repeat.
32
+
33
+
34
+ def del_non_word_char_repetition(text, debug=False):
35
+ count = 0
36
+
37
+ if isinstance(text, str):
38
+ if debug:
39
+ print("----detect non-word characters repetition----")
40
+ count = len(text)
41
+ text = pattern_non_word_char_repetition.sub("\t", text)
42
+ count -= len(text)
43
+ if debug and count:
44
+ print(f"removed non-word characters repetition: {count}")
45
+ return text, count
46
+
47
+
48
+ # final version for repetition detection
49
+ def detect_text_repetitions(text, debug=False):
50
+ count = 0
51
+
52
+ if isinstance(text, str):
53
+ if debug:
54
+ print("----detect text repetitions----")
55
+ matches = pattern_text_repetitions.finditer(text)
56
+ for match in matches:
57
+ if debug:
58
+ print(match)
59
+ for groupNum in range(0, len(match.groups())):
60
+ groupNum = groupNum + 1
61
+ print(
62
+ "Group {groupNum} found at {start}-{end}: `{group}`".format(
63
+ groupNum=groupNum,
64
+ start=match.start(groupNum),
65
+ end=match.end(groupNum),
66
+ group=match.group(groupNum),
67
+ )
68
+ )
69
+
70
+ start, end = match.span()
71
+ count += end - start - len(match.group(1))
72
+
73
+ return count
74
+
75
+
76
+ def detect_repetitions(text, debug=False):
77
+ if isinstance(text, str) is False:
78
+ return 0, 0, 0
79
+ text, count_non_word_char_repetition = del_non_word_char_repetition(
80
+ text, debug=debug
81
+ )
82
+ count_text_repetitions = detect_text_repetitions(text, debug=debug)
83
+ total_repetitions = count_non_word_char_repetition + count_text_repetitions
84
+
85
+ result = (count_non_word_char_repetition, count_text_repetitions, total_repetitions)
86
+
87
+ if debug:
88
+ print(result)
89
+ return result
90
+
91
+
92
+ def detect_scores(
93
+ row, debug=False, answer_col="answer", ground_truth_col="ground_truth"
94
+ ):
95
+ newline_score, repetition_score, total_repetitions = detect_repetitions(
96
+ row[answer_col], debug=debug
97
+ )
98
+
99
+ if ground_truth_col:
100
+ ground_truth_newline_score, ground_truth_repetition_score, _ = (
101
+ detect_repetitions(row[ground_truth_col], debug=debug)
102
+ )
103
+
104
+ newline_score -= ground_truth_newline_score
105
+ if newline_score < 0:
106
+ newline_score = 0
107
+
108
+ repetition_score -= ground_truth_repetition_score
109
+ if repetition_score < 0:
110
+ repetition_score = 0
111
+
112
+ total_repetitions = newline_score + repetition_score
113
+
114
+ return pd.Series([newline_score, repetition_score, total_repetitions])
115
+
116
+
117
+ def load_with_newline_and_repetition_scores(result_file, force_recalculate=False):
118
+ print(f"loading result file: {result_file}")
119
+ df = pd.read_csv(result_file, comment="#", on_bad_lines="warn")
120
+
121
+ if (
122
+ force_recalculate
123
+ or "newline_score" not in df.columns
124
+ or "repetition_score" not in df.columns
125
+ or "total_repetitions" not in df.columns
126
+ or "nrr" not in df.columns
127
+ or "rr" not in df.columns
128
+ ):
129
+ if (
130
+ force_recalculate
131
+ or "newline_score" not in df.columns
132
+ or "repetition_score" not in df.columns
133
+ or "total_repetitions" not in df.columns
134
+ ):
135
+ df[["newline_score", "repetition_score", "total_repetitions"]] = df.apply(
136
+ detect_scores, axis=1
137
+ )
138
+
139
+ df["answer_len"] = df["answer"].apply(
140
+ lambda x: len(x) if isinstance(x, str) else 0
141
+ )
142
+
143
+ df["nrr"] = df.apply(
144
+ lambda x: (
145
+ 1
146
+ if x["answer_len"] == 0
147
+ else 1 - (x["newline_score"] + x["repetition_score"]) / x["answer_len"]
148
+ ),
149
+ axis=1,
150
+ )
151
+
152
+ df["rr"] = df["nrr"].apply(lambda x: 1 - x)
153
+
154
+ df.to_csv(result_file, index=False)
155
+
156
+ return df
157
+
158
+
159
+ def replace_last(source_string, old_string, new_string):
160
+ head, _sep, tail = source_string.rpartition(old_string)
161
+ return head + new_string + tail
162
+
163
+
164
+ def load_for_repetition_penalty(
165
+ csv_result_file, repetition_penalty, force_recalculate=False
166
+ ):
167
+ result_file = replace_last(
168
+ csv_result_file, ".csv", f"_RP_{repetition_penalty:.3f}.csv"
169
+ )
170
+ return load_with_newline_and_repetition_scores(
171
+ result_file, force_recalculate=force_recalculate
172
+ )
173
+
174
+
175
+ rap_penalty_functions = {
176
+ "linear": lambda x: x,
177
+ "quadratic": lambda x: x * x,
178
+ "cubic": lambda x: x * x * x,
179
+ "logarithmic": lambda x: math.log(x + 1, 2),
180
+ "exponential": lambda x: math.exp(x - 1),
181
+ }
182
+
183
+
184
+ def calc_adjusted_performance(f, r, l=1, penalty_function="cubic"):
185
+ n = 1 - r / l if l > 0 else 0
186
+ return f * rap_penalty_functions[penalty_function](n)
187
+
188
+
189
+ def calculate_adjusted_performance(row):
190
+ r = row["total_repetitions"]
191
+ l = row["answer_len"]
192
+ adjusted_precision = calc_adjusted_performance(row["precision"], r, l)
193
+ adjusted_recall = calc_adjusted_performance(row["recall"], r, l)
194
+ return pd.Series([adjusted_precision, adjusted_recall])
195
+
196
+
197
+ def load_performance_df(csv_result_file, repetition_penalty):
198
+ result_file = replace_last(
199
+ csv_result_file, ".csv", f"_RP_{repetition_penalty:.3f}-t2_evaluated.json"
200
+ )
201
+ result_file = result_file.replace("/results/", "/eval/")
202
+ print(f"loading json file: {result_file}")
203
+ df = pd.read_json(result_file)
204
+
205
+ return df
206
+
207
+
208
+ def calculate_performance_score(
209
+ csv_result_file, repetition_penalty, force_recalculate=False
210
+ ):
211
+ result_file = replace_last(
212
+ csv_result_file, ".csv", f"_rpp_{repetition_penalty:.2f}.csv"
213
+ )
214
+
215
+ if os.path.exists(result_file):
216
+ print(f"loading result file: {result_file}")
217
+ df = load_with_newline_and_repetition_scores(
218
+ result_file, force_recalculate=force_recalculate
219
+ )
220
+ else:
221
+ print(f"re-creating result file: {result_file}")
222
+ df = pd.DataFrame()
223
+ force_recalculate = True
224
+
225
+ if force_recalculate or "f2" in df.columns or "f1" not in df.columns:
226
+ try:
227
+ perf_df = load_performance_df(csv_result_file, repetition_penalty)
228
+ df.drop(
229
+ columns=[
230
+ "precision",
231
+ "recall",
232
+ "f1",
233
+ "f2",
234
+ "entities_in_answer",
235
+ "entities_in_question",
236
+ "word_count",
237
+ ],
238
+ errors="ignore",
239
+ inplace=True,
240
+ )
241
+
242
+ df["id"] = perf_df["id"]
243
+ df["question"] = perf_df["question"]
244
+ df["answer"] = perf_df["pred_answer"]
245
+ df["word_count"] = df["answer"].apply(
246
+ lambda x: len(nltk.word_tokenize(x)) if isinstance(x, str) else 0
247
+ )
248
+ df["ground_truth"] = perf_df["ground_truth"]
249
+
250
+ df["eval_gemini_1.0_pro"] = perf_df["eval_gemini_1.0_pro"]
251
+ df["precision"] = perf_df["score"].apply(lambda x: x[0])
252
+ df["recall"] = perf_df["score"].apply(lambda x: x[1])
253
+ df["f1"] = perf_df["score"].apply(lambda x: x[2])
254
+ except Exception as e:
255
+ print(f"\tignored error: {e}")
256
+ # traceback.print_exc()
257
+
258
+ df[["newline_score", "repetition_score", "total_repetitions"]] = df.apply(
259
+ detect_scores, axis=1
260
+ )
261
+ df["answer_len"] = df["answer"].apply(
262
+ lambda x: len(x) if isinstance(x, str) else 0
263
+ )
264
+
265
+ df[["adjusted_precision", "adjusted_recall"]] = df.apply(
266
+ calculate_adjusted_performance, axis=1
267
+ )
268
+
269
+ df.to_csv(result_file, index=False)
270
+ print(f"performance scores saved to result file: {result_file}")
271
+
272
+ # print(f"df len: {len(df)}")
273
+
274
+ return df
275
+
276
+
277
+ def adjust_perf_scores_with_repetition_penalty(result, precision, recall):
278
+ newline_score = [
279
+ df["newline_score"].mean() for df in result["df_list_repetition_penalty"]
280
+ ]
281
+
282
+ repetition_score = [
283
+ df["repetition_score"].mean() for df in result["df_list_repetition_penalty"]
284
+ ]
285
+
286
+ answer_len = [
287
+ df["answer_len"].mean() for df in result["df_list_repetition_penalty"]
288
+ ]
289
+
290
+ precision = [
291
+ calc_adjusted_performance(f, n + r, l)
292
+ for f, n, r, l in zip(precision, newline_score, repetition_score, answer_len)
293
+ ]
294
+ recall = [
295
+ calc_adjusted_performance(f, n + r, l)
296
+ for f, n, r, l in zip(recall, newline_score, repetition_score, answer_len)
297
+ ]
298
+
299
+ return precision, recall
300
+
301
+
302
+ def plot_performance_scores(
303
+ result,
304
+ models=None,
305
+ title="Performance",
306
+ ):
307
+ if models is None:
308
+ models = result.keys()
309
+ for model in models:
310
+ print(f"model: {model}")
311
+ df = result[model]["df_overall"]
312
+
313
+ # Calculate the statistics
314
+ precision = [
315
+ df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
316
+ ]
317
+ recall = [
318
+ df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
319
+ ]
320
+ f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
321
+ best_f1 = max(f1)
322
+ best_f1_index = f1.index(best_f1)
323
+
324
+ precision, recall = adjust_perf_scores_with_repetition_penalty(
325
+ result[model], precision, recall
326
+ )
327
+ afrp = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
328
+
329
+ # f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
330
+ best_afrp = max(afrp)
331
+ best_afrp_index = afrp.index(best_afrp)
332
+
333
+ adjusted_precision = [
334
+ df["adjusted_precision"].mean()
335
+ for df in result[model]["df_list_repetition_penalty"]
336
+ ]
337
+ adjusted_recall = [
338
+ df["adjusted_recall"].mean()
339
+ for df in result[model]["df_list_repetition_penalty"]
340
+ ]
341
+ afrp2 = [
342
+ 2 * (p * r) / (p + r) for p, r in zip(adjusted_precision, adjusted_recall)
343
+ ]
344
+ best_afrp2 = max(afrp2)
345
+ best_afrp2_index = afrp2.index(best_afrp2)
346
+
347
+ repetition_penalties = list(df["repetition_penalty"])
348
+
349
+ # line plot for precision, recall, f1
350
+ plt.figure(figsize=(10, 6))
351
+
352
+ plt.axvspan(
353
+ repetition_penalties[best_f1_index] - 0.01,
354
+ repetition_penalties[best_f1_index] + 0.01,
355
+ alpha=0.5,
356
+ edgecolor="none",
357
+ facecolor="blue",
358
+ )
359
+
360
+ # plt.axvspan(
361
+ # repetition_penalties[best_afrp2_index] - 0.01,
362
+ # repetition_penalties[best_afrp2_index] + 0.01,
363
+ # alpha=0.5,
364
+ # edgecolor="none",
365
+ # facecolor="green",
366
+ # )
367
+
368
+ plt.axvspan(
369
+ repetition_penalties[best_afrp_index] - 0.01,
370
+ repetition_penalties[best_afrp_index] + 0.01,
371
+ alpha=0.5,
372
+ edgecolor="none",
373
+ facecolor="orange",
374
+ )
375
+
376
+ plt.plot(repetition_penalties, f1, label="F1", marker="D", color="blue")
377
+ # plt.plot(
378
+ # repetition_penalties,
379
+ # afrp2,
380
+ # label="Per-question RAP - F1",
381
+ # marker="s",
382
+ # color="green",
383
+ # )
384
+ plt.plot(
385
+ repetition_penalties,
386
+ afrp,
387
+ label="RAP - F1",
388
+ marker="o",
389
+ color="orange",
390
+ )
391
+ plt.xlabel("Repetition Penalties")
392
+ plt.ylabel("Score")
393
+ # plt.xlim(0.99, 1.31)
394
+ # y in percentage
395
+ plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
396
+ plt.title(f"{model} {title}")
397
+ plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
398
+
399
+ plt.show()
400
+
401
+
402
+ def plot_best_afrp(
403
+ result,
404
+ models=None,
405
+ title="Models with Best RAP - F1",
406
+ ref_result=None,
407
+ ):
408
+ # Initialize lists to store the statistics
409
+ model_names = []
410
+ best_f1 = []
411
+ best_afrp = []
412
+ best_repetition_penalty = []
413
+ best_mtr = []
414
+
415
+ if models is None:
416
+ models = result.keys()
417
+ for model in models:
418
+ print(f"model: {model}")
419
+ df = result[model]["df_overall"]
420
+
421
+ # Calculate the statistics
422
+ precision = [
423
+ df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
424
+ ]
425
+ recall = [
426
+ df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
427
+ ]
428
+ # f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
429
+ f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
430
+
431
+ newline_score = [
432
+ df["newline_score"].mean()
433
+ for df in result[model]["df_list_repetition_penalty"]
434
+ ]
435
+ # print(f"newline_score: {newline_score}")
436
+
437
+ repetition_score = [
438
+ df["repetition_score"].mean()
439
+ for df in result[model]["df_list_repetition_penalty"]
440
+ ]
441
+ # print(f"repetition_score: {repetition_score}")
442
+
443
+ answer_len = [
444
+ df["answer_len"].mean()
445
+ for df in result[model]["df_list_repetition_penalty"]
446
+ ]
447
+
448
+ afrp = [
449
+ calc_adjusted_performance(f, n + r, l)
450
+ for f, n, r, l in zip(f1, newline_score, repetition_score, answer_len)
451
+ ]
452
+
453
+ best_afrp.append(max(afrp))
454
+ best_afrp_index = afrp.index(best_afrp[-1])
455
+ best_repetition_penalty.append(df["repetition_penalty"][best_afrp_index])
456
+
457
+ best_f1.append(f1[best_afrp_index])
458
+ best_mtr.append(
459
+ newline_score[best_afrp_index] + repetition_score[best_afrp_index]
460
+ )
461
+
462
+ # print(
463
+ # f"best repetition penalty: {best_repetition_penalty[-1]}, best afrp: {best_afrp[-1]}, f1: {best_f1[-1]}"
464
+ # )
465
+
466
+ df = result[model]["df_list_repetition_penalty"][best_afrp_index]
467
+
468
+ model_names.append(
469
+ f"{model} (RP={best_repetition_penalty[-1]})"
470
+ ) # Add the model name to the list
471
+
472
+ if ref_result is not None:
473
+ print("ref_result:", ref_result)
474
+ for model in ref_result.keys():
475
+ model_names.append(model)
476
+ df = pd.read_csv(ref_result[model])
477
+ # df = df[df["id"].isin(wikidata_df["id"])]
478
+
479
+ p = df["precision"].mean()
480
+ r = df["recall"].mean()
481
+
482
+ f1 = 2 * p * r / (p + r) if p + r > 0 else 0
483
+ best_f1.append(f1)
484
+ best_afrp.append(f1)
485
+ best_mtr.append(0)
486
+
487
+ print("model_names:", model_names)
488
+ # print("best_f1:", best_f1)
489
+ # print("best_afrp:", best_afrp)
490
+
491
+ # Create a DataFrame with the statistics
492
+ data = pd.DataFrame(
493
+ {
494
+ "Model": model_names,
495
+ "RAP - F1": best_afrp,
496
+ "F1": best_f1,
497
+ }
498
+ )
499
+
500
+ # Melt the DataFrame to a long format
501
+ data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")
502
+
503
+ # Pivot the DataFrame to a wide format
504
+ data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")
505
+
506
+ # make sure the columns are following the order of the models
507
+ data_pivoted = data_pivoted[model_names]
508
+
509
+ # make sure three groups in the order of precision, recall, f1
510
+ data_pivoted = data_pivoted.reindex(["RAP - F1", "F1"])
511
+
512
+ # Plot the statistics
513
+ plt.figure(figsize=(15, 6))
514
+ ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
515
+ plt.title(title)
516
+ plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
517
+
518
+ # Set the rotation of the x-axis labels to 0 degrees
519
+ plt.xticks(rotation=0)
520
+
521
+ # Format the y-axis to display as percentage
522
+ ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
523
+
524
+ # get the max value of the y-axis
525
+ a1 = max(best_afrp)
526
+ a2 = max(best_f1)
527
+
528
+ max_value = max([a1, a2]) * 1.12
529
+ print("max_value:", max_value)
530
+
531
+ # Set the y-axis limit up to 70%
532
+ ax.set_ylim(0, max_value)
533
+
534
+ # Add the values above each bar
535
+ for p in ax.patches:
536
+ ax.annotate(
537
+ f"{p.get_height() * 100:.1f}",
538
+ (p.get_x() + p.get_width() / 2.0, p.get_height()),
539
+ ha="center",
540
+ va="bottom",
541
+ xytext=(0, 10),
542
+ textcoords="offset points",
543
+ rotation=90,
544
+ )
545
+
546
+ plt.show()
547
+ return data_pivoted, best_mtr
548
+
549
+
550
+ def plot_best_performance(
551
+ result,
552
+ models=None,
553
+ title="Models with Best F1 Score",
554
+ adjusted_f1=False,
555
+ ref_result=None,
556
+ ):
557
+ # Initialize lists to store the statistics
558
+ model_names = []
559
+ best_precision = []
560
+ best_recall = []
561
+ best_f1 = []
562
+ best_repetition_penalty = []
563
+ best_mtr = []
564
+
565
+ if models is None:
566
+ models = result.keys()
567
+ for model in models:
568
+ print(f"model: {model}")
569
+ df = result[model]["df_overall"]
570
+
571
+ # Calculate the statistics
572
+ precision = [
573
+ df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
574
+ ]
575
+ recall = [
576
+ df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
577
+ ]
578
+ newline_score = [
579
+ df["newline_score"].mean()
580
+ for df in result[model]["df_list_repetition_penalty"]
581
+ ]
582
+
583
+ repetition_score = [
584
+ df["repetition_score"].mean()
585
+ for df in result[model]["df_list_repetition_penalty"]
586
+ ]
587
+
588
+ if adjusted_f1:
589
+ precision, recall = adjust_perf_scores_with_repetition_penalty(
590
+ result[model], precision, recall
591
+ )
592
+
593
+ # f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
594
+ f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
595
+
596
+ best_f1.append(max(f1))
597
+ best_f1_index = f1.index(best_f1[-1])
598
+ best_repetition_penalty.append(df["repetition_penalty"][best_f1_index])
599
+
600
+ best_precision.append(precision[best_f1_index])
601
+ best_recall.append(recall[best_f1_index])
602
+ best_mtr.append(newline_score[best_f1_index] + repetition_score[best_f1_index])
603
+
604
+ print(
605
+ f"best repetition penalty: {best_repetition_penalty[-1]}, best f1: {best_f1[-1]}, precision: {best_precision[-1]}, recall: {best_recall[-1]}"
606
+ )
607
+
608
+ df = result[model]["df_list_repetition_penalty"][best_f1_index]
609
+
610
+ model_names.append(
611
+ f"{model} (RP={best_repetition_penalty[-1]})"
612
+ ) # Add the model name to the list
613
+
614
+ # print sum for columns: newline_score, repetition_score
615
+ print(
616
+ f"newline_score: {df['newline_score'].sum()}, repetition_score: {df['repetition_score'].sum()}"
617
+ )
618
+
619
+ if ref_result is not None:
620
+ print("ref_result:", ref_result)
621
+ for model in ref_result.keys():
622
+ model_names.append(model)
623
+ df = pd.read_csv(ref_result[model])
624
+ # df = df[df["id"].isin(wikidata_df["id"])]
625
+
626
+ best_precision.append(df["precision"].mean())
627
+ best_recall.append(df["recall"].mean())
628
+ f1 = (
629
+ 2
630
+ * (best_precision[-1] * best_recall[-1])
631
+ / (best_precision[-1] + best_recall[-1])
632
+ )
633
+ # best_f1.append(df["f1"].mean())
634
+ best_f1.append(f1)
635
+ best_mtr.append(0)
636
+
637
+ # Create a DataFrame with the statistics
638
+ data = (
639
+ pd.DataFrame(
640
+ {
641
+ "Model": model_names,
642
+ "Adjusted Precision with RP": best_precision,
643
+ "Adjusted Recall with RP": best_recall,
644
+ "Adjusted F1 with RP": best_f1,
645
+ }
646
+ )
647
+ if adjusted_f1
648
+ else pd.DataFrame(
649
+ {
650
+ "Model": model_names,
651
+ "Precision": best_precision,
652
+ "Recall": best_recall,
653
+ "F1": best_f1,
654
+ }
655
+ )
656
+ )
657
+ columns = list(data.columns)
658
+
659
+ # Melt the DataFrame to a long format
660
+ data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")
661
+
662
+ # Pivot the DataFrame to a wide format
663
+ data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")
664
+
665
+ # make sure the columns are following the order of the models
666
+ data_pivoted = data_pivoted[model_names]
667
+
668
+ # make sure three groups in the order of precision, recall, f1
669
+ data_pivoted = data_pivoted.reindex(columns[1:])
670
+
671
+ # Plot the statistics
672
+ plt.figure(figsize=(10, 6))
673
+ ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
674
+ plt.title(title)
675
+ plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
676
+
677
+ # Set the rotation of the x-axis labels to 0 degrees
678
+ plt.xticks(rotation=0)
679
+
680
+ # Format the y-axis to display as percentage
681
+ ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
682
+
683
+ # get the max value of the y-axis
684
+ a1 = max(best_precision)
685
+ a2 = max(best_recall)
686
+ a3 = max(best_f1)
687
+
688
+ max_value = max([a1, a2, a3]) * 1.12
689
+ print("max_value:", max_value)
690
+
691
+ # Set the y-axis limit up to 70%
692
+ ax.set_ylim(0, max_value)
693
+
694
+ # Add the values above each bar
695
+ for p in ax.patches:
696
+ ax.annotate(
697
+ f"{p.get_height() * 100:.1f}",
698
+ (p.get_x() + p.get_width() / 2.0, p.get_height()),
699
+ ha="center",
700
+ va="bottom",
701
+ xytext=(0, 10),
702
+ textcoords="offset points",
703
+ rotation=90,
704
+ )
705
+
706
+ plt.show()
707
+ return data_pivoted, best_mtr
708
+
709
+
710
+ def plot_best_performance_ms_macro(
711
+ result,
712
+ models=None,
713
+ title="Models with Best RAP - Performance",
714
+ ref_result=None,
715
+ skip_generic_prompt=False,
716
+ include_adjusted_performance=True,
717
+ ):
718
+ # Initialize lists to store the statistics
719
+ model_names = []
720
+ best_f1 = []
721
+ best_afrp = []
722
+ best_repetition_penalty = []
723
+ best_bleu1 = []
724
+ best_rougeL = []
725
+ best_mtr = []
726
+
727
+ if models is None:
728
+ models = result.keys()
729
+ for model in models:
730
+ if skip_generic_prompt and "generic prompt" in model:
731
+ continue
732
+ print(f"model: {model}")
733
+ df = result[model]["df_overall"]
734
+
735
+ # Calculate the statistics
736
+ bleu1 = [x for x in df["bleu1"]]
737
+ rougeL = [x for x in df["rougeL"]]
738
+ f1 = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]
739
+
740
+ newline_score = [
741
+ df["newline_score"].mean()
742
+ for df in result[model]["df_list_repetition_penalty"]
743
+ ]
744
+ # print(f"newline_score: {newline_score}")
745
+
746
+ repetition_score = [
747
+ df["repetition_score"].mean()
748
+ for df in result[model]["df_list_repetition_penalty"]
749
+ ]
750
+ # print(f"repetition_score: {repetition_score}")
751
+
752
+ answer_len = [
753
+ df["answer_len"].mean()
754
+ for df in result[model]["df_list_repetition_penalty"]
755
+ ]
756
+
757
+ afrp = [
758
+ calc_adjusted_performance(f, n + r, l)
759
+ for f, n, r, l in zip(f1, newline_score, repetition_score, answer_len)
760
+ ]
761
+
762
+ best_afrp.append(max(afrp if include_adjusted_performance else f1))
763
+ best_afrp_index = (
764
+ afrp.index(best_afrp[-1])
765
+ if include_adjusted_performance
766
+ else f1.index(best_afrp[-1])
767
+ )
768
+ best_repetition_penalty.append(df["repetition_penalty"][best_afrp_index])
769
+
770
+ best_f1.append(f1[best_afrp_index])
771
+ best_bleu1.append(bleu1[best_afrp_index])
772
+ best_rougeL.append(rougeL[best_afrp_index])
773
+ best_mtr.append(
774
+ newline_score[best_afrp_index] + repetition_score[best_afrp_index]
775
+ )
776
+
777
+ # print(
778
+ # f"best repetition penalty: {best_repetition_penalty[-1]}, best afrp: {best_afrp[-1]}, f1: {best_f1[-1]}"
779
+ # )
780
+
781
+ df = result[model]["df_list_repetition_penalty"][best_afrp_index]
782
+
783
+ model_names.append(
784
+ f"{model} (RP={best_repetition_penalty[-1]})"
785
+ ) # Add the model name to the list
786
+
787
+ if ref_result is not None:
788
+ print("ref_result:", ref_result)
789
+ for model in ref_result.keys():
790
+ model_names.append(model)
791
+ df = pd.read_csv(ref_result[model], comment="#", on_bad_lines="warn")
792
+ # df = df[df["id"].isin(wikidata_df["id"])]
793
+
794
+ p = df["bleu1"][0]
795
+ best_bleu1.append(p)
796
+
797
+ r = df["rougeL"][0]
798
+ best_rougeL.append(r)
799
+
800
+ f1 = 2 * p * r / (p + r) if p + r > 0 else 0
801
+ best_f1.append(f1)
802
+ best_afrp.append(f1)
803
+ best_mtr.append(0)
804
+
805
+ # print("model_names:", model_names)
806
+ # print("best_f1:", best_f1)
807
+ # print("best_afrp:", best_afrp)
808
+
809
+ # Create a DataFrame with the statistics
810
+ data = (
811
+ pd.DataFrame(
812
+ {
813
+ "Model": model_names,
814
+ "RAP - Perf Score": best_afrp,
815
+ "Overall Perf Score": best_f1,
816
+ }
817
+ )
818
+ if include_adjusted_performance
819
+ else pd.DataFrame(
820
+ {
821
+ "Model": model_names,
822
+ "Bleu-1": best_bleu1,
823
+ "Rouge-L": best_rougeL,
824
+ "Overall Perf Score": best_f1,
825
+ }
826
+ )
827
+ )
828
+
829
+ # Melt the DataFrame to a long format
830
+ data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")
831
+
832
+ # Pivot the DataFrame to a wide format
833
+ data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")
834
+
835
+ # make sure the columns are following the order of the models
836
+ data_pivoted = data_pivoted[model_names]
837
+
838
+ columns = list(data.columns)
839
+ data_pivoted = data_pivoted.reindex(columns[1:])
840
+
841
+ # Plot the statistics
842
+ plt.figure(figsize=(10, 6))
843
+ ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
844
+ plt.title(title)
845
+ plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
846
+
847
+ # Set the rotation of the x-axis labels to 0 degrees
848
+ plt.xticks(rotation=0)
849
+
850
+ # Format the y-axis to display as percentage
851
+ ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
852
+
853
+ # get the max value of the y-axis
854
+ a1 = max(best_afrp)
855
+ a2 = max(best_f1)
856
+ a3 = max(best_bleu1)
857
+ a4 = max(best_rougeL)
858
+
859
+ max_value = (
860
+ max([a1, a2] if include_adjusted_performance else [a1, a2, a3, a4]) * 1.12
861
+ )
862
+ print("max_value:", max_value)
863
+
864
+ # Set the y-axis limit up to 70%
865
+ ax.set_ylim(0, max_value)
866
+
867
+ # Add the values above each bar
868
+ for p in ax.patches:
869
+ ax.annotate(
870
+ f"{p.get_height() * 100:.1f}",
871
+ (p.get_x() + p.get_width() / 2.0, p.get_height()),
872
+ ha="center",
873
+ va="bottom",
874
+ xytext=(0, 10),
875
+ textcoords="offset points",
876
+ rotation=90,
877
+ )
878
+
879
+ plt.show()
880
+ return data_pivoted, best_mtr
881
+
882
+
883
+ all_open_source_models = [
884
+ "gemma-1.1-2b-it",
885
+ "Phi-3-mini-128k-instruct",
886
+ "gemma-1.1-7b-it",
887
+ "Llama-2-7b-chat-hf",
888
+ "Mistral-7B-Instruct-v0.2",
889
+ "Meta-Llama-3-8B-Instruct",
890
+ "Llama-2-13b-chat-hf",
891
+ "Llama-2-70b-chat-hf",
892
+ "Meta-Llama-3-70B-Instruct",
893
+ ]
894
+
895
+
896
+ def load_for_repetition_penalty_ms_macro(
897
+ csv_result_file, repetition_penalty, force_recalculate=False
898
+ ):
899
+ result_file = replace_last(
900
+ csv_result_file, ".csv", f"_rpp_{repetition_penalty:.2f}.csv"
901
+ )
902
+ df = load_with_newline_and_repetition_scores(
903
+ result_file, force_recalculate=force_recalculate
904
+ )
905
+
906
+ return df
907
+
908
+
909
+ # MS MACRO
910
+ def plot_performance_scores_ms_macro(
911
+ result,
912
+ models=None,
913
+ title="Performance",
914
+ ):
915
+ if models is None:
916
+ models = result.keys()
917
+ for model in models:
918
+ print(f"model: {model}")
919
+ df = result[model]["df_overall"]
920
+ # print(result[model]["df_list_repetition_penalty"][0].describe())
921
+
922
+ # Calculate the statistics
923
+ bleu1 = list(df["bleu1"])
924
+ rougeL = list(df["rougeL"])
925
+ f1 = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]
926
+ best_f1 = max(f1)
927
+ best_f1_index = f1.index(best_f1)
928
+
929
+ bleu1, rougeL = adjust_perf_scores_with_repetition_penalty(
930
+ result[model], bleu1, rougeL
931
+ )
932
+ afrp = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]
933
+
934
+ # f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
935
+ best_afrp = max(afrp)
936
+ best_afrp_index = afrp.index(best_afrp)
937
+
938
+ repetition_penalties = list(df["repetition_penalty"])
939
+
940
+ # line plot for precision, recall, f1
941
+ plt.figure(figsize=(10, 6))
942
+
943
+ plt.axvspan(
944
+ repetition_penalties[best_f1_index] - 0.01,
945
+ repetition_penalties[best_f1_index] + 0.01,
946
+ alpha=0.5,
947
+ edgecolor="none",
948
+ facecolor="blue",
949
+ )
950
+
951
+ plt.axvspan(
952
+ repetition_penalties[best_afrp_index] - 0.01,
953
+ repetition_penalties[best_afrp_index] + 0.01,
954
+ alpha=0.5,
955
+ edgecolor="none",
956
+ facecolor="orange",
957
+ )
958
+
959
+ plt.plot(
960
+ repetition_penalties,
961
+ f1,
962
+ label="Overall Perf Score",
963
+ marker="D",
964
+ color="blue",
965
+ )
966
+ plt.plot(
967
+ repetition_penalties,
968
+ afrp,
969
+ label="RAP - Perf Score",
970
+ marker="o",
971
+ color="orange",
972
+ )
973
+
974
+ plt.xlabel("Repetition Penalties")
975
+ plt.ylabel("Score")
976
+ # plt.xlim(0.99, 1.31)
977
+ # y in percentage
978
+ plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
979
+ plt.title(f"{model} {title}")
980
+ plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
981
+
982
+ plt.show()
983
+
984
+
985
+ def plot_repetition_factors(result, groups):
986
+ for group in groups:
987
+ # Plot the statistics
988
+ plt.figure(figsize=(10, 6))
989
+
990
+ max_value = 0
991
+ for model in result.keys():
992
+ if not group in model.lower():
993
+ continue
994
+ print(f"model: {model}")
995
+ df = result[model]["df_overall"]
996
+ repetition_panelties = [
997
+ repetition_penalty for repetition_penalty in df["repetition_penalty"]
998
+ ]
999
+
1000
+ mean_score = [
1001
+ df["total_repetitions"].mean()
1002
+ for df in result[model]["df_list_repetition_penalty"]
1003
+ ]
1004
+
1005
+ sns.lineplot(x=repetition_panelties, y=mean_score, label=model)
1006
+
1007
+ new_max = max(mean_score)
1008
+ if new_max > max_value:
1009
+ max_value = new_max
1010
+
1011
+ max_value = max_value * 1.05
1012
+ # if max_value < 1.5:
1013
+ # max_value = 1.5
1014
+ # set ylimit
1015
+ plt.ylim(0, max_value)
1016
+
1017
+ # show grid
1018
+ plt.grid(True)
1019
+ plt.xlabel("Repetition Penalties")
1020
+ plt.ylabel("Mean Total Repetitions")
1021
+ plt.title("Mean Total Repetitions vs Repetition Penalties")
1022
+ plt.legend()
1023
+
1024
+ plt.show()
1025
+
1026
+
1027
+ def plot_repetition_factors_by_group(result, group_filter=None):
1028
+ markers = ["D", "o", "s", "x"]
1029
+ colors = ["blue", "orange", "green", "red"]
1030
+
1031
+ # Plot the statistics
1032
+ plt.figure(figsize=(10, 6))
1033
+ index = 0
1034
+ max_value = 0
1035
+
1036
+ for model in result.keys():
1037
+ if group_filter is not None and group_filter not in model:
1038
+ continue
1039
+
1040
+ print(f"model: {model}")
1041
+
1042
+ df = result[model]["df_overall"]
1043
+ repetition_panelties = [
1044
+ repetition_penalty for repetition_penalty in df["repetition_penalty"]
1045
+ ]
1046
+
1047
+ # Calculate the statistics
1048
+ mean_score = [
1049
+ df["total_repetitions"].mean()
1050
+ for df in result[model]["df_list_repetition_penalty"]
1051
+ ]
1052
+ if len(mean_score) != len(repetition_panelties):
1053
+ print(
1054
+ f"model: {model} has different length of repetition penalties and mean score"
1055
+ )
1056
+ print("repetition_panelties:", len(repetition_panelties))
1057
+ print("mean_score:", len(mean_score))
1058
+ continue
1059
+
1060
+ new_max = max(mean_score)
1061
+ if new_max > max_value:
1062
+ max_value = new_max
1063
+
1064
+ sns.lineplot(
1065
+ x=repetition_panelties,
1066
+ y=mean_score,
1067
+ label=model,
1068
+ marker=markers[index],
1069
+ color=colors[index],
1070
+ )
1071
+
1072
+ index += 1
1073
+
1074
+ max_value = max_value * 1.05
1075
+ # if max_value < 1.5:
1076
+ # max_value = 1.5
1077
+ # set ylimit
1078
+ plt.ylim(0, max_value)
1079
+ max_value = 0
1080
+
1081
+ plt.xlabel("Repetition Penalties")
1082
+ plt.ylabel("Mean Total Repetitions")
1083
+ plt.title("Mean Total Repetitions vs Repetition Penalties")
1084
+ plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
1085
+
1086
+ plt.show()
1087
+
1088
+
1089
+ ms_marco_csv_result_files = [
1090
+ "data/results_v2/gemma-1.1-2b-it(RAG - Generic Prompt)_mm.csv",
1091
+ "data/results_v2/gemma-1.1-2b-it(RAG - Chat Template)_mm.csv",
1092
+ "data/results_v2/gemma-1.1-2b-it(Non-RAG)_mm.csv",
1093
+ "data/results_v2/Phi-3-mini-128k-instruct(RAG - Generic Prompt)_mm.csv",
1094
+ "data/results_v2/Phi-3-mini-128k-instruct(RAG - Chat Template)_mm.csv",
1095
+ "data/results_v2/Phi-3-mini-128k-instruct(Non-RAG)_mm.csv",
1096
+ "data/results_v2/gemma-1.1-7b-it(RAG - Generic Prompt)_mm.csv",
1097
+ "data/results_v2/gemma-1.1-7b-it(RAG - Chat Template)_mm.csv",
1098
+ "data/results_v2/gemma-1.1-7b-it(Non-RAG)_mm.csv",
1099
+ "data/results_v2/Llama-2-7b-chat-hf(RAG - Generic Prompt)_mm.csv",
1100
+ "data/results_v2/Llama-2-7b-chat-hf(RAG - Chat Template)_mm.csv",
1101
+ "data/results_v2/Llama-2-7b-chat-hf(Non-RAG)_mm.csv",
1102
+ "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Generic Prompt)_mm.csv",
1103
+ "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Chat Template)_mm.csv",
1104
+ "data/results_v2/Mistral-7B-Instruct-v0.2(Non-RAG)_mm.csv",
1105
+ "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Generic Prompt)_mm.csv",
1106
+ "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Chat Template)_mm.csv",
1107
+ "data/results_v2/Meta-Llama-3-8B-Instruct(Non-RAG)_mm.csv",
1108
+ "data/results_v2/Llama-2-13b-chat-hf(RAG - Generic Prompt)_mm.csv",
1109
+ "data/results_v2/Llama-2-13b-chat-hf(RAG - Chat Template)_mm.csv",
1110
+ "data/results_v2/Llama-2-13b-chat-hf(Non-RAG)_mm.csv",
1111
+ "data/results_v2/Llama-2-70b-chat-hf(RAG - Generic Prompt)_mm.csv",
1112
+ "data/results_v2/Llama-2-70b-chat-hf(RAG - Chat Template)_mm.csv",
1113
+ "data/results_v2/Llama-2-70b-chat-hf(Non-RAG)_mm.csv",
1114
+ "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Generic Prompt)_mm.csv",
1115
+ "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Chat Template)_mm.csv",
1116
+ "data/results_v2/Meta-Llama-3-70B-Instruct(Non-RAG)_mm.csv",
1117
+ ]
1118
+
1119
+ webqsp_csv_result_files = [
1120
+ "data/results_v2/gemma-1.1-2b-it(RAG - Generic Prompt)_wd.csv",
1121
+ "data/results_v2/gemma-1.1-2b-it(RAG - Chat Template)_wd.csv",
1122
+ "data/results_v2/gemma-1.1-2b-it(Non-RAG)_wd.csv",
1123
+ "data/results_v2/Phi-3-mini-128k-instruct(RAG - Generic Prompt)_wd.csv",
1124
+ "data/results_v2/Phi-3-mini-128k-instruct(RAG - Chat Template)_wd.csv",
1125
+ "data/results_v2/Phi-3-mini-128k-instruct(Non-RAG)_wd.csv",
1126
+ "data/results_v2/gemma-1.1-7b-it(RAG - Generic Prompt)_wd.csv",
1127
+ "data/results_v2/gemma-1.1-7b-it(RAG - Chat Template)_wd.csv",
1128
+ "data/results_v2/gemma-1.1-7b-it(Non-RAG)_wd.csv",
1129
+ "data/results_v2/Llama-2-7b-chat-hf(RAG - Generic Prompt)_wd.csv",
1130
+ "data/results_v2/Llama-2-7b-chat-hf(RAG - Chat Template)_wd.csv",
1131
+ "data/results_v2/Llama-2-7b-chat-hf(Non-RAG)_wd.csv",
1132
+ "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Generic Prompt)_wd.csv",
1133
+ "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Chat Template)_wd.csv",
1134
+ "data/results_v2/Mistral-7B-Instruct-v0.2(Non-RAG)_wd.csv",
1135
+ "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Generic Prompt)_wd.csv",
1136
+ "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Chat Template)_wd.csv",
1137
+ "data/results_v2/Meta-Llama-3-8B-Instruct(Non-RAG)_wd.csv",
1138
+ "data/results_v2/Llama-2-13b-chat-hf(RAG - Generic Prompt)_wd.csv",
1139
+ "data/results_v2/Llama-2-13b-chat-hf(RAG - Chat Template)_wd.csv",
1140
+ "data/results_v2/Llama-2-13b-chat-hf(Non-RAG)_wd.csv",
1141
+ "data/results_v2/Llama-2-70b-chat-hf(RAG - Generic Prompt)_wd.csv",
1142
+ "data/results_v2/Llama-2-70b-chat-hf(RAG - Chat Template)_wd.csv",
1143
+ "data/results_v2/Llama-2-70b-chat-hf(Non-RAG)_wd.csv",
1144
+ "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Generic Prompt)_wd.csv",
1145
+ "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Chat Template)_wd.csv",
1146
+ "data/results_v2/Meta-Llama-3-70B-Instruct(Non-RAG)_wd.csv",
1147
+ ]
1148
+
1149
+
1150
+ def calc_rap_scores(
1151
+ result, precision="precision", recall="recall", penalty_function="cubic"
1152
+ ):
1153
+ newline_score = [
1154
+ df["newline_score"].mean() for df in result["df_list_repetition_penalty"]
1155
+ ]
1156
+
1157
+ repetition_score = [
1158
+ df["repetition_score"].mean() for df in result["df_list_repetition_penalty"]
1159
+ ]
1160
+
1161
+ if precision in result["df_list_repetition_penalty"][0].columns:
1162
+ precision = [
1163
+ df[precision].mean() for df in result["df_list_repetition_penalty"]
1164
+ ]
1165
+ recall = [df[recall].mean() for df in result["df_list_repetition_penalty"]]
1166
+ else:
1167
+ precision = result["df_overall"][precision]
1168
+ recall = result["df_overall"][recall]
1169
+
1170
+ f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
1171
+
1172
+ nrr = [
1173
+ 1 - (n + r) / s
1174
+ for f, n, r, s in zip(
1175
+ f1, newline_score, repetition_score, result["df_overall"]["answer_len"]
1176
+ )
1177
+ ]
1178
+
1179
+ rap = [
1180
+ calc_adjusted_performance(f, 1 - n, penalty_function=penalty_function)
1181
+ for f, n in zip(f1, nrr)
1182
+ ]
1183
+
1184
+ return newline_score, repetition_score, f1, rap, nrr
1185
+
1186
+
1187
+ def get_model_name(csv_result_file):
1188
+ parts = re.split(r"[_/]", csv_result_file)
1189
+ print(f"parts: {parts}")
1190
+ model_name = parts[3]
1191
+ return model_name
1192
+
1193
+
1194
+ def load_webqsp_result(
1195
+ csv_result_files, force_recalculate=False, save=False, penalty_function="cubic"
1196
+ ):
1197
+ result = {}
1198
+ for i, csv_result_file in enumerate(csv_result_files):
1199
+ try:
1200
+ df = pd.read_csv(csv_result_file)
1201
+ model_name = get_model_name(csv_result_file)
1202
+ print(f"\tmodel_name: {model_name}")
1203
+
1204
+ dfs = [
1205
+ calculate_performance_score(
1206
+ csv_result_file,
1207
+ repetition_penalty,
1208
+ force_recalculate=force_recalculate,
1209
+ )
1210
+ for repetition_penalty in df["repetition_penalty"]
1211
+ ]
1212
+
1213
+ answer_lens = []
1214
+ for df_rpp in dfs:
1215
+ answer_lens.append(df_rpp["answer_len"].mean())
1216
+ df["answer_len"] = answer_lens
1217
+
1218
+ result[model_name] = {
1219
+ "df_overall": df,
1220
+ "df_list_repetition_penalty": dfs,
1221
+ "file": csv_result_file,
1222
+ }
1223
+ newline_score, repetition_score, perf, rap, nrr = calc_rap_scores(
1224
+ result[model_name], penalty_function=penalty_function
1225
+ )
1226
+ df["newline_score"] = newline_score
1227
+ df["repetition_score"] = repetition_score
1228
+ df["total_repetitions"] = df["newline_score"] + df["repetition_score"]
1229
+ df["perf"] = perf
1230
+ df["nrr"] = nrr
1231
+ df["rap"] = rap
1232
+ df["rr"] = df["nrr"].apply(lambda x: 1 - x)
1233
+ df["rrp"] = df["rr"].apply(lambda x: x * 100)
1234
+ if save:
1235
+ df.to_csv(csv_result_file, index=False)
1236
+ except Exception as e:
1237
+ print(f"Error: {e}")
1238
+ traceback.print_exc()
1239
+
1240
+ return result
1241
+
1242
+
+ def load_ms_marco_result(
+     csv_result_files,
+     force_recalculate=False,
+     calc_bertscore=True,
+     save=False,
+     penalty_function="cubic",
+ ):
+     result = {}
+     for csv_result_file in csv_result_files:
+         try:
+             df = pd.read_csv(csv_result_file)
+             model_name = get_model_name(csv_result_file)
+             print(f"\tmodel_name: {model_name}")
+
+             dfs = [
+                 load_for_repetition_penalty_ms_macro(
+                     csv_result_file,
+                     repetition_penalty,
+                     force_recalculate=force_recalculate,
+                 )
+                 for repetition_penalty in df["repetition_penalty"]
+             ]
+
+             answer_lens = []
+             for df_rpp in dfs:
+                 answer_lens.append(df_rpp["answer_len"].mean())
+             df["answer_len"] = answer_lens
+
+             col = "bert_score" if calc_bertscore else "meteor"
+             score_unavailable = col not in df.columns
+
+             if score_unavailable:
+                 save = True
+                 bert_meteor_scores = []
+                 bert_score_references = None
+                 for df_rpp in dfs:
+                     if calc_bertscore:
+                         bert_meteor_score = 0
+
+                         for i, row in df_rpp.iterrows():
+                             answer = row["answer"]
+                             if not isinstance(answer, str):
+                                 answer = ""
+                             bert_meteor_score += bert_score.compute(
+                                 predictions=[answer],
+                                 references=[row["ground_truth"][0]],
+                                 lang="en",
+                                 model_type="microsoft/deberta-large-mnli",
+                             )["f1"][0]
+                         # get average of bertscore
+                         bert_meteor_score = bert_meteor_score / len(df_rpp)
+
+                         print(f"bert_score: {bert_meteor_score}")
+                     else:
+                         bert_meteor_score = meteor.compute(
+                             predictions=df_rpp["answer"],
+                             references=df_rpp["ground_truth"],
+                         )["meteor"]
+
+                     bert_meteor_scores.append(bert_meteor_score)
+
+                 df[col] = bert_meteor_scores
+
+             result[model_name] = {
+                 "df_overall": df,
+                 "df_list_repetition_penalty": dfs,
+                 "file": csv_result_file,
+             }
+             newline_score, repetition_score, perf, rap, nrr = calc_rap_scores(
+                 result[model_name],
+                 precision=col,
+                 recall=col,
+                 penalty_function=penalty_function,
+             )
+             df["newline_score"] = newline_score
+             df["repetition_score"] = repetition_score
+             df["total_repetitions"] = df["newline_score"] + df["repetition_score"]
+             df["perf"] = perf
+             df["nrr"] = nrr
+             df["rap"] = rap
+             df["rr"] = df["nrr"].apply(lambda x: 1 - x)
+             df["rrp"] = df["rr"].apply(lambda x: x * 100)
+
+             if save:
+                 df.to_csv(csv_result_file, index=False)
+         except Exception as e:
+             print("An error occurred:", e)
+             traceback.print_exc()
+             print(f"csv_result_file: {csv_result_file}")
+
+     return result
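
For orientation, a minimal sketch of how these loaders might be driven from a notebook; the CSV path below is a hypothetical placeholder, and only the keys and columns shown come from the code above:

    csv_files = ["data/results/ms-macro_gemma-1.1-2b-it_rp1.10.csv"]  # hypothetical path
    result = load_ms_marco_result(csv_files, calc_bertscore=True, penalty_function="cubic")
    summary = result["gemma-1.1-2b-it"]["df_overall"]  # one row per repetition_penalty value
    print(summary[["repetition_penalty", "perf", "rap", "rrp"]])

load_webqsp_result follows the same calling pattern but scores each run with calculate_performance_score instead of BERTScore/METEOR.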
eval_modules/utils.py ADDED
@@ -0,0 +1,262 @@
+ # -*- coding:utf-8 -*-
+ from __future__ import annotations
+
+ import json
+ import logging
+ import os
+ import platform
+ import re
+ from pathlib import Path
+ import evaluate
+ import pandas as pd
+ import requests
+ import torch
+ from tqdm import tqdm
+
+
+ class LogRecord(logging.LogRecord):
+     def getMessage(self):
+         msg = self.msg
+         if self.args:
+             if isinstance(self.args, dict):
+                 msg = msg.format(**self.args)
+             else:
+                 msg = msg.format(*self.args)
+         return msg
+
+
+ class Logger(logging.Logger):
+     def makeRecord(
+         self,
+         name,
+         level,
+         fn,
+         lno,
+         msg,
+         args,
+         exc_info,
+         func=None,
+         extra=None,
+         sinfo=None,
+     ):
+         rv = LogRecord(name, level, fn, lno, msg, args, exc_info, func, sinfo)
+         if extra is not None:
+             for key in extra:
+                 rv.__dict__[key] = extra[key]
+         return rv
+
+
+ def init_settings():
+     logging.setLoggerClass(Logger)
+     logging.basicConfig(
+         level=logging.WARNING,
+         format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
+     )
+
+
+ def remove_extra_spaces(text):
+     return re.sub(" +", " ", text.strip())
+
+
+ def print_llm_response(llm_response, debug_retrieval=True):
+     answer = llm_response["answer"] if "answer" in llm_response else None
+     if answer is None:
+         answer = llm_response["response"] if "response" in llm_response else None
+
+     if answer is not None:
+         print("\n\n***Answer:")
+         print(answer)
+
+     source_documents = (
+         llm_response["source_documents"] if "source_documents" in llm_response else None
+     )
+     if source_documents is None:
+         source_documents = (
+             llm_response["sourceDocs"] if "sourceDocs" in llm_response else None
+         )
+
+     if debug_retrieval and source_documents is not None:
+         print("\nSources:")
+         for index, source in enumerate(source_documents):
+             metadata = source["metadata"] if "metadata" in source else source.metadata
+             if "page" in metadata:
+                 print(f" Page: {metadata['page']}", end="")
+
+             print(
+                 f" Source {index + 1}: "
+                 + str(metadata["url"] if "url" in metadata else metadata["source"])
+             )
+             print(
+                 source["page_content"]
+                 if "page_content" in source
+                 else source.page_content
+             )
+
+     if "chat_history" in llm_response:
+         print("\nChat History:")
+         print(llm_response["chat_history"])
+
+
+ def get_device_types():
+     print("Running on: ", platform.platform())
+     print("MPS is", "NOT" if not torch.backends.mps.is_available() else "", "available")
+     print("CUDA is", "NOT" if not torch.cuda.is_available() else "", "available")
+     device_type_available = "cpu"
+
+     if not torch.backends.mps.is_available():
+         if not torch.backends.mps.is_built():
+             print(
+                 "MPS not available because the current PyTorch install was not "
+                 "built with MPS enabled."
+             )
+         else:
+             print(
+                 "MPS not available because the current MacOS version is not 12.3+ "
+                 "and/or you do not have an MPS-enabled device on this machine."
+             )
+     else:
+         device_type_available = "mps"
+
+     if torch.cuda.is_available():
+         print("CUDA is available, we have found ", torch.cuda.device_count(), " GPU(s)")
+         print(torch.cuda.get_device_name(0))
+         print("CUDA version: " + torch.version.cuda)
+         device_type_available = f"cuda:{torch.cuda.current_device()}"
+
+     return (
+         os.environ.get("HF_EMBEDDINGS_DEVICE_TYPE") or device_type_available,
+         os.environ.get("HF_PIPELINE_DEVICE_TYPE") or device_type_available,
+     )
+
+
+ def ensure_model_is_downloaded(llm_model_type):
+     if llm_model_type.startswith("gpt4all"):
+         local_path = (
+             os.environ.get("GPT4ALL_J_MODEL_PATH")
+             if llm_model_type == "gpt4all-j"
+             else os.environ.get("GPT4ALL_MODEL_PATH")
+         )
+         url = (
+             os.environ.get("GPT4ALL_J_DOWNLOAD_LINK")
+             if llm_model_type == "gpt4all-j"
+             else os.environ.get("GPT4ALL_DOWNLOAD_LINK")
+         )
+     elif llm_model_type == "llamacpp":
+         local_path = os.environ.get("LLAMACPP_MODEL_PATH")
+         url = os.environ.get("LLAMACPP_DOWNLOAD_LINK")
+     elif llm_model_type == "ctransformers":
+         local_path = os.environ.get("CTRANSFORMERS_MODEL_PATH")
+         url = os.environ.get("CTRANSFORMERS_DOWNLOAD_LINK")
+     else:
+         raise ValueError(f"wrong model type: {llm_model_type}")
+
+     path = Path(local_path)
+
+     if path.is_file():
+         print(f"model: {local_path} exists")
+     else:
+         print(f"downloading model: {local_path} from {url} ...")
+         path.parent.mkdir(parents=True, exist_ok=True)
+
+         # send a GET request to the URL to download the file. Stream since it's large
+         response = requests.get(url, stream=True)
+
+         # open the file in binary mode and write the contents of the response to it in chunks
+         # This is a large file, so be prepared to wait.
+         with open(local_path, "wb") as f:
+             for chunk in tqdm(response.iter_content(chunk_size=8192)):
+                 if chunk:
+                     f.write(chunk)
+
+     return local_path
+
+
+ bleu = evaluate.load("bleu")
+ rouge = evaluate.load("rouge")
+
+
+ def calc_bleu_rouge_scores(predictions, references, debug=False):
+     if debug:
+         print("predictions:", predictions)
+         print("references:", references)
+
+     bleu_scores = bleu.compute(
+         predictions=predictions, references=references, max_order=1
+     )
+     rouge_scores = rouge.compute(predictions=predictions, references=references)
+     result = {"bleu_scores": bleu_scores, "rouge_scores": rouge_scores}
+
+     if debug:
+         print("result:", result)
+
+     return result
+
+
+ def calc_metrics(df):
+     predictions = [df["answer"][i] for i in range(len(df))]
+     references = [df["ground_truth"][i] for i in range(len(df))]
+
+     return calc_bleu_rouge_scores(predictions, references)
+
+
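+ # The three patterns below drive detect_repetitions: pattern_abnormal_newlines flags
+ # runs of five or more consecutive newlines, pattern_text_repetitions flags a
+ # word-starting span that is immediately repeated, and exception_pattern exempts
+ # simple duplicated "word." sequences from being counted.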
+ pattern_abnormal_newlines = re.compile(r"\n{5,}")
+ pattern_text_repetitions = re.compile(r"\b(\w.+?)\b(\1+)", re.M | re.DOTALL)
+ exception_pattern = re.compile(r"(\w+\.)\1")
+
+
+ # final version for repetition detection
+ def detect_repetitions(
+     text, debug=False, pattern_text_repetitions=pattern_text_repetitions
+ ):
+     subtotals = [0, 0]
+
+     if isinstance(text, str):
+         patterns = [pattern_abnormal_newlines, pattern_text_repetitions]
+         for i, pattern in enumerate(patterns):
+             if debug:
+                 print(
+                     f"----detect {'abnormal newlines' if i == 0 else 'text repetitions'}----"
+                 )
+             matches = pattern.finditer(text)
+             for match in matches:
+                 if debug:
+                     print(match)
+                     for groupNum in range(0, len(match.groups())):
+                         groupNum = groupNum + 1
+                         print(
+                             "Group {groupNum} found at {start}-{end}: `{group}`".format(
+                                 groupNum=groupNum,
+                                 start=match.start(groupNum),
+                                 end=match.end(groupNum),
+                                 group=match.group(groupNum),
+                             )
+                         )
+
+                 if exception_pattern.match(match[0]):
+                     if debug:
+                         print("ignored: ", match[0])
+                     continue
+
+                 start, end = match.span()
+                 subtotals[i] += end - start
+
+     result = (subtotals[0], subtotals[1], subtotals[0] + subtotals[1])
+
+     if debug:
+         print(result)
+     return result
+
+
+ def detect_abnormal_newlines(text, debug=False):
+     return detect_repetitions(text, debug=debug)[0]
+
+
+ def detect_text_repetitions(text, debug=False):
+     return detect_repetitions(text, debug=debug)[1]
+
+
+ def detect_repetition_scores(text, debug=False):
+     newline_score, repetition_score, total_repetitions = detect_repetitions(
+         text, debug=debug
+     )
+     return pd.Series([newline_score, repetition_score, total_repetitions])
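
A minimal usage sketch for the repetition detector above (the sample string is illustrative, and the import assumes eval_modules is importable as a package):

    from eval_modules.utils import detect_repetitions

    # counts characters covered by abnormal newlines (5+ in a row) and by repeated text
    newline_chars, repeated_chars, total = detect_repetitions("yes yes yes yes")
    # here the newline count stays 0, while the repeated "yes " span is counted
    # in characters, giving roughly (0, 12, 12)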
ms_macro.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1 +1,20 @@
- huggingface_hub==0.25.2
+ huggingface_hub==0.26.0
+ nltk==3.8.1
+ langchain==0.1.16
+ langchain-openai==0.1.3
+ langchain_google_genai==1.0.2
+ transformers==4.40.1
+ accelerate==0.29.3
+ python-dotenv==1.0.1
+ gradio==4.44.1
+ black==24.4.0
+ InstructorEmbedding==1.0.1
+ sentence-transformers==2.2.2
+ chardet==5.2.0
+ sentencepiece==0.1.98
+ evaluate==0.4.3
+ rouge_score==0.1.2
+ pytest==8.2.1
+ seaborn==0.13.2
+ tenacity==8.3.0
+ bert_score==0.3.13