initial working code
Files changed:
- .env.example +81 -0
- .gitignore +149 -0
- app.py +109 -25
- eval_modules/calc_repetitions_v2e.py +1333 -0
- eval_modules/utils.py +262 -0
- ms_macro.json +0 -0
- requirements.txt +20 -1
.env.example
ADDED
@@ -0,0 +1,81 @@
+LLM_MODEL_TYPE=huggingface
+# LLM_MODEL_TYPE=openai
+# LLM_MODEL_TYPE=hftgi
+# LLM_MODEL_TYPE=ollama
+# LLM_MODEL_TYPE=google
+# LLM_MODEL_TYPE=vllm
+
+HUGGINGFACE_AUTH_TOKEN=
+
+HFTGI_SERVER_URL=
+
+OPENAI_API_KEY=
+
+GOOGLE_API_KEY=
+
+# if unset, default to "gpt-3.5-turbo"
+OPENAI_MODEL_NAME=
+
+# GEMINI_MODEL_NAME=gemini-1.5-pro-latest
+
+# OLLAMA_MODEL_NAME=orca2:7b
+# OLLAMA_MODEL_NAME=mistral:7b
+# OLLAMA_MODEL_NAME=gemma:7b
+# OLLAMA_MODEL_NAME=llama2:7b
+OLLAMA_MODEL_NAME=llama3:8b
+
+OLLAMA_RP=1.15
+HF_RP=1.15
+
+LANGCHAIN_DEBUG=false
+BATCH_SIZE=1
+APPLY_CHAT_TEMPLATE_FOR_RAG=true
+
+# cpu, mps or cuda:0 - if unset, use whatever detected
+HF_EMBEDDINGS_DEVICE_TYPE=
+HF_PIPELINE_DEVICE_TYPE=
+
+# uncomment one of the below to load corresponding quantized model
+# LOAD_QUANTIZED_MODEL=4bit
+# LOAD_QUANTIZED_MODEL=8bit
+
+QA_WITH_RAG=true
+# QA_WITH_RAG=false
+
+RETRIEVER_TYPE=questions_file
+# RETRIEVER_TYPE=vectorstore
+
+QUESTIONS_FILE_PATH="./data/datasets/ms_macro.json"
+
+DISABLE_MODEL_PRELOADING=true
+CHAT_HISTORY_ENABLED=false
+SHOW_PARAM_SETTINGS=false
+SHARE_GRADIO_APP=false
+
+# if unset, default to "hkunlp/instructor-xl"
+HF_EMBEDDINGS_MODEL_NAME="hkunlp/instructor-large"
+
+# number of cpu cores - used to set n_threads for GPT4ALL & LlamaCpp models
+NUMBER_OF_CPU_CORES=
+
+USING_TORCH_BFLOAT16=true
+
+# HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-3b"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-7b"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="databricks/dolly-v2-12b"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="TheBloke/wizardLM-7B-HF"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="TheBloke/vicuna-7B-1.1-HF"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="nomic-ai/gpt4all-j"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="nomic-ai/gpt4all-falcon"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="lmsys/fastchat-t5-3b-v1.0"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-7b-chat-hf"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-13b-chat-hf"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Llama-2-70b-chat-hf"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-8B-Instruct"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="meta-llama/Meta-Llama-3-70B-Instruct"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Orca-2-7b"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Orca-2-13b"
+HUGGINGFACE_MODEL_NAME_OR_PATH="google/gemma-1.1-2b-it"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="google/gemma-1.1-7b-it"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="microsoft/Phi-3-mini-128k-instruct"
+# HUGGINGFACE_MODEL_NAME_OR_PATH="mistralai/Mistral-7B-Instruct-v0.2"
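For context, these settings are consumed at runtime via `os.getenv` (see app.py below). A minimal sketch of the loading pattern, assuming `python-dotenv` is among the dependencies in requirements.txt (not shown in full here); the `QA_WITH_RAG` parsing line is a hypothetical illustration:

import os
from dotenv import load_dotenv  # assumption: python-dotenv handles .env loading

load_dotenv()  # copies KEY=value pairs from .env into os.environ

# Empty values such as OPENAI_MODEL_NAME= read back as "", which is falsy,
# so `or` supplies the default - the same pattern app.py uses below.
questions_file_path = os.getenv("QUESTIONS_FILE_PATH") or "./ms_macro.json"
qa_with_rag = os.getenv("QA_WITH_RAG", "true").lower() == "true"  # hypothetical flag parser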
.gitignore
ADDED
@@ -0,0 +1,149 @@
+*.out
+*.log
+pdfs/
+.vscode/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+# *.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# JetBrains
+.idea
+
+*.db
+
+.DS_Store
+
+vectorstore.pkl
+langchain.readthedocs.io/
+
+models/
+data/logs/hftgi-2024-03-18.txt
+qa_*_all_results.csv
+qa_*_test_results.csv
app.py
CHANGED
@@ -1,54 +1,141 @@
+import json
+import os
 import gradio as gr
 from huggingface_hub import InferenceClient
+from eval_modules.utils import calc_bleu_rouge_scores
+from eval_modules.calc_repetitions_v2e import detect_repetitions
+
+questions_file_path = os.getenv("QUESTIONS_FILE_PATH") or "./ms_macro.json"
+
+questions = json.loads(open(questions_file_path).read())
+examples = [[question["question"].strip()] for question in questions]
+print(f"Loaded {len(examples)} examples")
+
+qa_system_prompt = "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
 
 """
 For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 """
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+# client = InferenceClient("HuggingFaceH4/zephyr-7b-gemma-v0.1")
+client = InferenceClient("microsoft/Phi-3.5-mini-instruct")
 
 
-def respond(
+def chat(
     message,
     history: list[tuple[str, str]],
     system_message,
-    max_tokens,
-    temperature,
-    top_p,
+    temperature=0,
+    frequency_penalty=0,
+    presence_penalty=0,
+    max_tokens=256,
+    top_p=0.95,
 ):
-    messages = [{"role": "system", "content": system_message}]
+    chat = []
+    for item in history:
+        chat.append({"role": "user", "content": item[0]})
+        if item[1] is not None:
+            chat.append({"role": "assistant", "content": item[1]})
+
+    index = -1
+    if [message] in examples:
+        index = examples.index([message])
+        message = f"{qa_system_prompt}\n\n{questions[index]['context']}\n\nQuestion: {message}"
+        print("RAG prompt:", message)
 
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
+    chat.append({"role": "user", "content": message})
 
+    messages = [{"role": "system", "content": system_message}]
     messages.append({"role": "user", "content": message})
 
-    response = ""
+    partial_text = ""
 
+    finish_reason = None
     for message in client.chat_completion(
         messages,
         max_tokens=max_tokens,
         stream=True,
         temperature=temperature,
+        frequency_penalty=None,  # frequency_penalty,
+        presence_penalty=None,  # presence_penalty,
         top_p=top_p,
+        seed=42,
     ):
-        token = message.choices[0].delta.content
+        finish_reason = message.choices[0].finish_reason
+        # print("finish_reason:", finish_reason)
 
-        response += token
-        yield response
+        if finish_reason is None:
+            new_text = message.choices[0].delta.content
+            partial_text += new_text
+            yield partial_text
+        else:
+            break
+
+    answer = partial_text
+    (whitespace_score, repetition_score, total_repetitions) = detect_repetitions(answer)
+    partial_text += "\n\nRepetition Metrics:\n"
+    partial_text += f"1. Whitespace Score: {whitespace_score:.3f}\n"
+    partial_text += f"1. Repetition Score: {repetition_score:.3f}\n"
+    partial_text += f"1. Total Repetitions: {total_repetitions:.3f}\n"
+    partial_text += (
+        f"1. Non-Repetitive Ratio: {1 - total_repetitions / len(answer):.3f}\n"
+    )
+
+    if index >= 0:  # RAG
+        key = (
+            "wellFormedAnswers"
+            if "wellFormedAnswers" in questions[index]
+            else "answers"
+        )
+        scores = calc_bleu_rouge_scores([answer], [questions[index][key]], debug=True)
+
+        partial_text += "\n\n Performance Metrics:\n"
+        partial_text += f'1. BLEU-1: {scores["bleu_scores"]["bleu"]:.3f}\n'
+        partial_text += f'1. RougeL: {scores["rouge_scores"]["rougeL"]:.3f}\n'
+
+        partial_text += f"\n\nGround truth: {questions[index][key][0]}\n"
+
+    partial_text += f"\n\nThe text generation has ended because: {finish_reason}\n"
+
+    yield partial_text
 
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
 demo = gr.ChatInterface(
-    respond,
+    fn=chat,
+    examples=examples,
+    cache_examples=False,
+    additional_inputs_accordion=gr.Accordion(
+        label="⚙️ Parameters", open=False, render=False
+    ),
     additional_inputs=[
         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(
+            minimum=0, maximum=2, step=0.1, value=0, label="Temperature", render=False
+        ),
+        gr.Slider(
+            minimum=-2,
+            maximum=2,
+            step=0.1,
+            value=0,
+            label="Frequency Penalty",
+            render=False,
+        ),
+        gr.Slider(
+            minimum=-2,
+            maximum=2,
+            step=0.1,
+            value=0,
+            label="Presence Penalty",
+            render=False,
+        ),
+        gr.Slider(
+            minimum=128,
+            maximum=4096,
+            step=1,
+            value=512,
+            label="Max new tokens",
+            render=False,
+        ),
         gr.Slider(
             minimum=0.1,
             maximum=1.0,
@@ -58,7 +145,4 @@ demo = gr.ChatInterface(
         ),
     ],
 )
-
-
-if __name__ == "__main__":
-    demo.launch()
+demo.launch()
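To see what the metrics block at the end of `chat` produces, here is a minimal sketch exercising `detect_repetitions` (added in eval_modules/calc_repetitions_v2e.py below) on a made-up answer; note that importing the module also loads the bertscore and meteor metrics at import time, and the Non-Repetitive Ratio line mirrors the expression used in `chat`:

from eval_modules.calc_repetitions_v2e import detect_repetitions

answer = "The tower is in Paris. The tower is in Paris."  # hypothetical model output
whitespace_score, repetition_score, total_repetitions = detect_repetitions(answer)

# Non-Repetitive Ratio: share of characters not attributed to repetition
nrr = 1 - total_repetitions / len(answer)
print(whitespace_score, repetition_score, total_repetitions, nrr)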
eval_modules/calc_repetitions_v2e.py
ADDED
@@ -0,0 +1,1333 @@
+import os
+import re
+import math
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.ticker as mtick
+import seaborn as sns
+import nltk
+import evaluate
+import traceback
+
+bert_score = evaluate.load("bertscore")
+meteor = evaluate.load("meteor")
+
+print(f"loading: {__file__}")
+
+# pattern_non_word_char_repetition = re.compile(r"\s{5,}")
+# pattern_text_repetitions = re.compile(r"(.{5}.*)\s*((\1)\s*)+", re.M | re.DOTALL)
+
+# final version
+pattern_non_word_char_repetition = re.compile(r"[\s\W]{5,}")
+pattern_text_repetitions = re.compile(
+    r"(?P<repeat>.{5}.*?)(?:[\s\W]*(?P=repeat))+", re.M | re.DOTALL | re.IGNORECASE
+)
+# Explanation of the Regex Pattern:
+# (?P<repeat>.{5}.*?): Captures any sequence of characters with a minimal length of 5 and names this group repeat.
+#   .*?: Matches zero or more characters, non-greedily (as few as possible).
+# (?:[\s\W]*(?P=repeat))+: A non-capturing group that matches one or more repetitions of:
+#   [\s\W]*: Zero or more whitespace or non-word characters (spaces, punctuation, etc.).
+#   (?P=repeat): A backreference to the named group repeat.
+
+
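To make the pattern concrete, a quick check on a made-up repetitive string (the sample text is hypothetical):

sample = "The answer is Paris. The answer is Paris."  # hypothetical repetitive output
match = pattern_text_repetitions.search(sample)
print(match.group("repeat"))  # the repeated unit: "The answer is Paris"
# detect_text_repetitions() below counts the characters the extra copies occupy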
+def del_non_word_char_repetition(text, debug=False):
+    count = 0
+
+    if isinstance(text, str):
+        if debug:
+            print("----detect non-word characters repetition----")
+        count = len(text)
+        text = pattern_non_word_char_repetition.sub("\t", text)
+        count -= len(text)
+        if debug and count:
+            print(f"removed non-word characters repetition: {count}")
+    return text, count
+
+
+# final version for repetition detection
+def detect_text_repetitions(text, debug=False):
+    count = 0
+
+    if isinstance(text, str):
+        if debug:
+            print("----detect text repetitions----")
+        matches = pattern_text_repetitions.finditer(text)
+        for match in matches:
+            if debug:
+                print(match)
+                for groupNum in range(0, len(match.groups())):
+                    groupNum = groupNum + 1
+                    print(
+                        "Group {groupNum} found at {start}-{end}: `{group}`".format(
+                            groupNum=groupNum,
+                            start=match.start(groupNum),
+                            end=match.end(groupNum),
+                            group=match.group(groupNum),
+                        )
+                    )
+
+            start, end = match.span()
+            count += end - start - len(match.group(1))
+
+    return count
+
+
+def detect_repetitions(text, debug=False):
+    if isinstance(text, str) is False:
+        return 0, 0, 0
+    text, count_non_word_char_repetition = del_non_word_char_repetition(
+        text, debug=debug
+    )
+    count_text_repetitions = detect_text_repetitions(text, debug=debug)
+    total_repetitions = count_non_word_char_repetition + count_text_repetitions
+
+    result = (count_non_word_char_repetition, count_text_repetitions, total_repetitions)
+
+    if debug:
+        print(result)
+    return result
+
+
+def detect_scores(
+    row, debug=False, answer_col="answer", ground_truth_col="ground_truth"
+):
+    newline_score, repetition_score, total_repetitions = detect_repetitions(
+        row[answer_col], debug=debug
+    )
+
+    if ground_truth_col:
+        ground_truth_newline_score, ground_truth_repetition_score, _ = (
+            detect_repetitions(row[ground_truth_col], debug=debug)
+        )
+
+        newline_score -= ground_truth_newline_score
+        if newline_score < 0:
+            newline_score = 0
+
+        repetition_score -= ground_truth_repetition_score
+        if repetition_score < 0:
+            repetition_score = 0
+
+        total_repetitions = newline_score + repetition_score
+
+    return pd.Series([newline_score, repetition_score, total_repetitions])
+
+
+def load_with_newline_and_repetition_scores(result_file, force_recalculate=False):
+    print(f"loading result file: {result_file}")
+    df = pd.read_csv(result_file, comment="#", on_bad_lines="warn")
+
+    if (
+        force_recalculate
+        or "newline_score" not in df.columns
+        or "repetition_score" not in df.columns
+        or "total_repetitions" not in df.columns
+        or "nrr" not in df.columns
+        or "rr" not in df.columns
+    ):
+        if (
+            force_recalculate
+            or "newline_score" not in df.columns
+            or "repetition_score" not in df.columns
+            or "total_repetitions" not in df.columns
+        ):
+            df[["newline_score", "repetition_score", "total_repetitions"]] = df.apply(
+                detect_scores, axis=1
+            )
+
+        df["answer_len"] = df["answer"].apply(
+            lambda x: len(x) if isinstance(x, str) else 0
+        )
+
+        df["nrr"] = df.apply(
+            lambda x: (
+                1
+                if x["answer_len"] == 0
+                else 1 - (x["newline_score"] + x["repetition_score"]) / x["answer_len"]
+            ),
+            axis=1,
+        )
+
+        df["rr"] = df["nrr"].apply(lambda x: 1 - x)
+
+        df.to_csv(result_file, index=False)
+
+    return df
+
+
+def replace_last(source_string, old_string, new_string):
+    head, _sep, tail = source_string.rpartition(old_string)
+    return head + new_string + tail
+
+
+def load_for_repetition_penalty(
+    csv_result_file, repetition_penalty, force_recalculate=False
+):
+    result_file = replace_last(
+        csv_result_file, ".csv", f"_RP_{repetition_penalty:.3f}.csv"
+    )
+    return load_with_newline_and_repetition_scores(
+        result_file, force_recalculate=force_recalculate
+    )
+
+
+rap_penalty_functions = {
+    "linear": lambda x: x,
+    "quadratic": lambda x: x * x,
+    "cubic": lambda x: x * x * x,
+    "logarithmic": lambda x: math.log(x + 1, 2),
+    "exponential": lambda x: math.exp(x - 1),
+}
+
+
+def calc_adjusted_performance(f, r, l=1, penalty_function="cubic"):
+    n = 1 - r / l if l > 0 else 0
+    return f * rap_penalty_functions[penalty_function](n)
+
+
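A worked example of the repetition-adjusted performance (RAP) formula above, with made-up numbers (f: raw metric, r: repetition count, l: answer length):

# default cubic penalty: RAP = f * (1 - r / l) ** 3
score = calc_adjusted_performance(0.8, 20, 200)  # n = 1 - 20/200 = 0.9, so 0.8 * 0.9**3
print(round(score, 4))  # 0.5832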
189 |
+
def calculate_adjusted_performance(row):
|
190 |
+
r = row["total_repetitions"]
|
191 |
+
l = row["answer_len"]
|
192 |
+
adjusted_precision = calc_adjusted_performance(row["precision"], r, l)
|
193 |
+
adjusted_recall = calc_adjusted_performance(row["recall"], r, l)
|
194 |
+
return pd.Series([adjusted_precision, adjusted_recall])
|
195 |
+
|
196 |
+
|
197 |
+
def load_performance_df(csv_result_file, repetition_penalty):
|
198 |
+
result_file = replace_last(
|
199 |
+
csv_result_file, ".csv", f"_RP_{repetition_penalty:.3f}-t2_evaluated.json"
|
200 |
+
)
|
201 |
+
result_file = result_file.replace("/results/", "/eval/")
|
202 |
+
print(f"loading json file: {result_file}")
|
203 |
+
df = pd.read_json(result_file)
|
204 |
+
|
205 |
+
return df
|
206 |
+
|
207 |
+
|
208 |
+
def calculate_performance_score(
|
209 |
+
csv_result_file, repetition_penalty, force_recalculate=False
|
210 |
+
):
|
211 |
+
result_file = replace_last(
|
212 |
+
csv_result_file, ".csv", f"_rpp_{repetition_penalty:.2f}.csv"
|
213 |
+
)
|
214 |
+
|
215 |
+
if os.path.exists(result_file):
|
216 |
+
print(f"loading result file: {result_file}")
|
217 |
+
df = load_with_newline_and_repetition_scores(
|
218 |
+
result_file, force_recalculate=force_recalculate
|
219 |
+
)
|
220 |
+
else:
|
221 |
+
print(f"re-creating result file: {result_file}")
|
222 |
+
df = pd.DataFrame()
|
223 |
+
force_recalculate = True
|
224 |
+
|
225 |
+
if force_recalculate or "f2" in df.columns or "f1" not in df.columns:
|
226 |
+
try:
|
227 |
+
perf_df = load_performance_df(csv_result_file, repetition_penalty)
|
228 |
+
df.drop(
|
229 |
+
columns=[
|
230 |
+
"precision",
|
231 |
+
"recall",
|
232 |
+
"f1",
|
233 |
+
"f2",
|
234 |
+
"entities_in_answer",
|
235 |
+
"entities_in_question",
|
236 |
+
"word_count",
|
237 |
+
],
|
238 |
+
errors="ignore",
|
239 |
+
inplace=True,
|
240 |
+
)
|
241 |
+
|
242 |
+
df["id"] = perf_df["id"]
|
243 |
+
df["question"] = perf_df["question"]
|
244 |
+
df["answer"] = perf_df["pred_answer"]
|
245 |
+
df["word_count"] = df["answer"].apply(
|
246 |
+
lambda x: len(nltk.word_tokenize(x)) if isinstance(x, str) else 0
|
247 |
+
)
|
248 |
+
df["ground_truth"] = perf_df["ground_truth"]
|
249 |
+
|
250 |
+
df["eval_gemini_1.0_pro"] = perf_df["eval_gemini_1.0_pro"]
|
251 |
+
df["precision"] = perf_df["score"].apply(lambda x: x[0])
|
252 |
+
df["recall"] = perf_df["score"].apply(lambda x: x[1])
|
253 |
+
df["f1"] = perf_df["score"].apply(lambda x: x[2])
|
254 |
+
except Exception as e:
|
255 |
+
print(f"\tignored error: {e}")
|
256 |
+
# traceback.print_exc()
|
257 |
+
|
258 |
+
df[["newline_score", "repetition_score", "total_repetitions"]] = df.apply(
|
259 |
+
detect_scores, axis=1
|
260 |
+
)
|
261 |
+
df["answer_len"] = df["answer"].apply(
|
262 |
+
lambda x: len(x) if isinstance(x, str) else 0
|
263 |
+
)
|
264 |
+
|
265 |
+
df[["adjusted_precision", "adjusted_recall"]] = df.apply(
|
266 |
+
calculate_adjusted_performance, axis=1
|
267 |
+
)
|
268 |
+
|
269 |
+
df.to_csv(result_file, index=False)
|
270 |
+
print(f"performance scores saved to result file: {result_file}")
|
271 |
+
|
272 |
+
# print(f"df len: {len(df)}")
|
273 |
+
|
274 |
+
return df
|
275 |
+
|
276 |
+
|
277 |
+
def adjust_perf_scores_with_repetition_penalty(result, precision, recall):
|
278 |
+
newline_score = [
|
279 |
+
df["newline_score"].mean() for df in result["df_list_repetition_penalty"]
|
280 |
+
]
|
281 |
+
|
282 |
+
repetition_score = [
|
283 |
+
df["repetition_score"].mean() for df in result["df_list_repetition_penalty"]
|
284 |
+
]
|
285 |
+
|
286 |
+
answer_len = [
|
287 |
+
df["answer_len"].mean() for df in result["df_list_repetition_penalty"]
|
288 |
+
]
|
289 |
+
|
290 |
+
precision = [
|
291 |
+
calc_adjusted_performance(f, n + r, l)
|
292 |
+
for f, n, r, l in zip(precision, newline_score, repetition_score, answer_len)
|
293 |
+
]
|
294 |
+
recall = [
|
295 |
+
calc_adjusted_performance(f, n + r, l)
|
296 |
+
for f, n, r, l in zip(recall, newline_score, repetition_score, answer_len)
|
297 |
+
]
|
298 |
+
|
299 |
+
return precision, recall
|
300 |
+
|
301 |
+
|
302 |
+
def plot_performance_scores(
|
303 |
+
result,
|
304 |
+
models=None,
|
305 |
+
title="Performance",
|
306 |
+
):
|
307 |
+
if models is None:
|
308 |
+
models = result.keys()
|
309 |
+
for model in models:
|
310 |
+
print(f"model: {model}")
|
311 |
+
df = result[model]["df_overall"]
|
312 |
+
|
313 |
+
# Calculate the statistics
|
314 |
+
precision = [
|
315 |
+
df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
|
316 |
+
]
|
317 |
+
recall = [
|
318 |
+
df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
|
319 |
+
]
|
320 |
+
f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
|
321 |
+
best_f1 = max(f1)
|
322 |
+
best_f1_index = f1.index(best_f1)
|
323 |
+
|
324 |
+
precision, recall = adjust_perf_scores_with_repetition_penalty(
|
325 |
+
result[model], precision, recall
|
326 |
+
)
|
327 |
+
afrp = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
|
328 |
+
|
329 |
+
# f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
|
330 |
+
best_afrp = max(afrp)
|
331 |
+
best_afrp_index = afrp.index(best_afrp)
|
332 |
+
|
333 |
+
adjusted_precision = [
|
334 |
+
df["adjusted_precision"].mean()
|
335 |
+
for df in result[model]["df_list_repetition_penalty"]
|
336 |
+
]
|
337 |
+
adjusted_recall = [
|
338 |
+
df["adjusted_recall"].mean()
|
339 |
+
for df in result[model]["df_list_repetition_penalty"]
|
340 |
+
]
|
341 |
+
afrp2 = [
|
342 |
+
2 * (p * r) / (p + r) for p, r in zip(adjusted_precision, adjusted_recall)
|
343 |
+
]
|
344 |
+
best_afrp2 = max(afrp2)
|
345 |
+
best_afrp2_index = afrp2.index(best_afrp2)
|
346 |
+
|
347 |
+
repetition_penalties = list(df["repetition_penalty"])
|
348 |
+
|
349 |
+
# line plot for precision, recall, f1
|
350 |
+
plt.figure(figsize=(10, 6))
|
351 |
+
|
352 |
+
plt.axvspan(
|
353 |
+
repetition_penalties[best_f1_index] - 0.01,
|
354 |
+
repetition_penalties[best_f1_index] + 0.01,
|
355 |
+
alpha=0.5,
|
356 |
+
edgecolor="none",
|
357 |
+
facecolor="blue",
|
358 |
+
)
|
359 |
+
|
360 |
+
# plt.axvspan(
|
361 |
+
# repetition_penalties[best_afrp2_index] - 0.01,
|
362 |
+
# repetition_penalties[best_afrp2_index] + 0.01,
|
363 |
+
# alpha=0.5,
|
364 |
+
# edgecolor="none",
|
365 |
+
# facecolor="green",
|
366 |
+
# )
|
367 |
+
|
368 |
+
plt.axvspan(
|
369 |
+
repetition_penalties[best_afrp_index] - 0.01,
|
370 |
+
repetition_penalties[best_afrp_index] + 0.01,
|
371 |
+
alpha=0.5,
|
372 |
+
edgecolor="none",
|
373 |
+
facecolor="orange",
|
374 |
+
)
|
375 |
+
|
376 |
+
plt.plot(repetition_penalties, f1, label="F1", marker="D", color="blue")
|
377 |
+
# plt.plot(
|
378 |
+
# repetition_penalties,
|
379 |
+
# afrp2,
|
380 |
+
# label="Per-question RAP - F1",
|
381 |
+
# marker="s",
|
382 |
+
# color="green",
|
383 |
+
# )
|
384 |
+
plt.plot(
|
385 |
+
repetition_penalties,
|
386 |
+
afrp,
|
387 |
+
label="RAP - F1",
|
388 |
+
marker="o",
|
389 |
+
color="orange",
|
390 |
+
)
|
391 |
+
plt.xlabel("Repetition Penalties")
|
392 |
+
plt.ylabel("Score")
|
393 |
+
# plt.xlim(0.99, 1.31)
|
394 |
+
# y in percentage
|
395 |
+
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
|
396 |
+
plt.title(f"{model} {title}")
|
397 |
+
plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
|
398 |
+
|
399 |
+
plt.show()
|
400 |
+
|
401 |
+
|
402 |
+
def plot_best_afrp(
|
403 |
+
result,
|
404 |
+
models=None,
|
405 |
+
title="Models with Best RAP - F1",
|
406 |
+
ref_result=None,
|
407 |
+
):
|
408 |
+
# Initialize lists to store the statistics
|
409 |
+
model_names = []
|
410 |
+
best_f1 = []
|
411 |
+
best_afrp = []
|
412 |
+
best_repetition_penalty = []
|
413 |
+
best_mtr = []
|
414 |
+
|
415 |
+
if models is None:
|
416 |
+
models = result.keys()
|
417 |
+
for model in models:
|
418 |
+
print(f"model: {model}")
|
419 |
+
df = result[model]["df_overall"]
|
420 |
+
|
421 |
+
# Calculate the statistics
|
422 |
+
precision = [
|
423 |
+
df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
|
424 |
+
]
|
425 |
+
recall = [
|
426 |
+
df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
|
427 |
+
]
|
428 |
+
# f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
|
429 |
+
f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
|
430 |
+
|
431 |
+
newline_score = [
|
432 |
+
df["newline_score"].mean()
|
433 |
+
for df in result[model]["df_list_repetition_penalty"]
|
434 |
+
]
|
435 |
+
# print(f"newline_score: {newline_score}")
|
436 |
+
|
437 |
+
repetition_score = [
|
438 |
+
df["repetition_score"].mean()
|
439 |
+
for df in result[model]["df_list_repetition_penalty"]
|
440 |
+
]
|
441 |
+
# print(f"repetition_score: {repetition_score}")
|
442 |
+
|
443 |
+
answer_len = [
|
444 |
+
df["answer_len"].mean()
|
445 |
+
for df in result[model]["df_list_repetition_penalty"]
|
446 |
+
]
|
447 |
+
|
448 |
+
afrp = [
|
449 |
+
calc_adjusted_performance(f, n + r, l)
|
450 |
+
for f, n, r, l in zip(f1, newline_score, repetition_score, answer_len)
|
451 |
+
]
|
452 |
+
|
453 |
+
best_afrp.append(max(afrp))
|
454 |
+
best_afrp_index = afrp.index(best_afrp[-1])
|
455 |
+
best_repetition_penalty.append(df["repetition_penalty"][best_afrp_index])
|
456 |
+
|
457 |
+
best_f1.append(f1[best_afrp_index])
|
458 |
+
best_mtr.append(
|
459 |
+
newline_score[best_afrp_index] + repetition_score[best_afrp_index]
|
460 |
+
)
|
461 |
+
|
462 |
+
# print(
|
463 |
+
# f"best repetition penalty: {best_repetition_penalty[-1]}, best afrp: {best_afrp[-1]}, f1: {best_f1[-1]}"
|
464 |
+
# )
|
465 |
+
|
466 |
+
df = result[model]["df_list_repetition_penalty"][best_afrp_index]
|
467 |
+
|
468 |
+
model_names.append(
|
469 |
+
f"{model} (RP={best_repetition_penalty[-1]})"
|
470 |
+
) # Add the model name to the list
|
471 |
+
|
472 |
+
if ref_result is not None:
|
473 |
+
print("ref_result:", ref_result)
|
474 |
+
for model in ref_result.keys():
|
475 |
+
model_names.append(model)
|
476 |
+
df = pd.read_csv(ref_result[model])
|
477 |
+
# df = df[df["id"].isin(wikidata_df["id"])]
|
478 |
+
|
479 |
+
p = df["precision"].mean()
|
480 |
+
r = df["recall"].mean()
|
481 |
+
|
482 |
+
f1 = 2 * p * r / (p + r) if p + r > 0 else 0
|
483 |
+
best_f1.append(f1)
|
484 |
+
best_afrp.append(f1)
|
485 |
+
best_mtr.append(0)
|
486 |
+
|
487 |
+
print("model_names:", model_names)
|
488 |
+
# print("best_f1:", best_f1)
|
489 |
+
# print("best_afrp:", best_afrp)
|
490 |
+
|
491 |
+
# Create a DataFrame with the statistics
|
492 |
+
data = pd.DataFrame(
|
493 |
+
{
|
494 |
+
"Model": model_names,
|
495 |
+
"RAP - F1": best_afrp,
|
496 |
+
"F1": best_f1,
|
497 |
+
}
|
498 |
+
)
|
499 |
+
|
500 |
+
# Melt the DataFrame to a long format
|
501 |
+
data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")
|
502 |
+
|
503 |
+
# Pivot the DataFrame to a wide format
|
504 |
+
data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")
|
505 |
+
|
506 |
+
# make sure the columns are following the order of the models
|
507 |
+
data_pivoted = data_pivoted[model_names]
|
508 |
+
|
509 |
+
# make sure three groups in the order of precision, recall, f1
|
510 |
+
data_pivoted = data_pivoted.reindex(["RAP - F1", "F1"])
|
511 |
+
|
512 |
+
# Plot the statistics
|
513 |
+
plt.figure(figsize=(15, 6))
|
514 |
+
ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
|
515 |
+
plt.title(title)
|
516 |
+
plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
|
517 |
+
|
518 |
+
# Set the rotation of the x-axis labels to 0 degrees
|
519 |
+
plt.xticks(rotation=0)
|
520 |
+
|
521 |
+
# Format the y-axis to display as percentage
|
522 |
+
ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
|
523 |
+
|
524 |
+
# get the max value of the y-axis
|
525 |
+
a1 = max(best_afrp)
|
526 |
+
a2 = max(best_f1)
|
527 |
+
|
528 |
+
max_value = max([a1, a2]) * 1.12
|
529 |
+
print("max_value:", max_value)
|
530 |
+
|
531 |
+
# Set the y-axis limit up to 70%
|
532 |
+
ax.set_ylim(0, max_value)
|
533 |
+
|
534 |
+
# Add the values above each bar
|
535 |
+
for p in ax.patches:
|
536 |
+
ax.annotate(
|
537 |
+
f"{p.get_height() * 100:.1f}",
|
538 |
+
(p.get_x() + p.get_width() / 2.0, p.get_height()),
|
539 |
+
ha="center",
|
540 |
+
va="bottom",
|
541 |
+
xytext=(0, 10),
|
542 |
+
textcoords="offset points",
|
543 |
+
rotation=90,
|
544 |
+
)
|
545 |
+
|
546 |
+
plt.show()
|
547 |
+
return data_pivoted, best_mtr
|
548 |
+
|
549 |
+
|
550 |
+
def plot_best_performance(
|
551 |
+
result,
|
552 |
+
models=None,
|
553 |
+
title="Models with Best F1 Score",
|
554 |
+
adjusted_f1=False,
|
555 |
+
ref_result=None,
|
556 |
+
):
|
557 |
+
# Initialize lists to store the statistics
|
558 |
+
model_names = []
|
559 |
+
best_precision = []
|
560 |
+
best_recall = []
|
561 |
+
best_f1 = []
|
562 |
+
best_repetition_penalty = []
|
563 |
+
best_mtr = []
|
564 |
+
|
565 |
+
if models is None:
|
566 |
+
models = result.keys()
|
567 |
+
for model in models:
|
568 |
+
print(f"model: {model}")
|
569 |
+
df = result[model]["df_overall"]
|
570 |
+
|
571 |
+
# Calculate the statistics
|
572 |
+
precision = [
|
573 |
+
df["precision"].mean() for df in result[model]["df_list_repetition_penalty"]
|
574 |
+
]
|
575 |
+
recall = [
|
576 |
+
df["recall"].mean() for df in result[model]["df_list_repetition_penalty"]
|
577 |
+
]
|
578 |
+
newline_score = [
|
579 |
+
df["newline_score"].mean()
|
580 |
+
for df in result[model]["df_list_repetition_penalty"]
|
581 |
+
]
|
582 |
+
|
583 |
+
repetition_score = [
|
584 |
+
df["repetition_score"].mean()
|
585 |
+
for df in result[model]["df_list_repetition_penalty"]
|
586 |
+
]
|
587 |
+
|
588 |
+
if adjusted_f1:
|
589 |
+
precision, recall = adjust_perf_scores_with_repetition_penalty(
|
590 |
+
result[model], precision, recall
|
591 |
+
)
|
592 |
+
|
593 |
+
# f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
|
594 |
+
f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]
|
595 |
+
|
596 |
+
best_f1.append(max(f1))
|
597 |
+
best_f1_index = f1.index(best_f1[-1])
|
598 |
+
best_repetition_penalty.append(df["repetition_penalty"][best_f1_index])
|
599 |
+
|
600 |
+
best_precision.append(precision[best_f1_index])
|
601 |
+
best_recall.append(recall[best_f1_index])
|
602 |
+
best_mtr.append(newline_score[best_f1_index] + repetition_score[best_f1_index])
|
603 |
+
|
604 |
+
print(
|
605 |
+
f"best repetition penalty: {best_repetition_penalty[-1]}, best f1: {best_f1[-1]}, precision: {best_precision[-1]}, recall: {best_recall[-1]}"
|
606 |
+
)
|
607 |
+
|
608 |
+
df = result[model]["df_list_repetition_penalty"][best_f1_index]
|
609 |
+
|
610 |
+
model_names.append(
|
611 |
+
f"{model} (RP={best_repetition_penalty[-1]})"
|
612 |
+
) # Add the model name to the list
|
613 |
+
|
614 |
+
# print sum for columns: newline_score, repetition_score
|
615 |
+
print(
|
616 |
+
f"newline_score: {df['newline_score'].sum()}, repetition_score: {df['repetition_score'].sum()}"
|
617 |
+
)
|
618 |
+
|
619 |
+
if ref_result is not None:
|
620 |
+
print("ref_result:", ref_result)
|
621 |
+
for model in ref_result.keys():
|
622 |
+
model_names.append(model)
|
623 |
+
df = pd.read_csv(ref_result[model])
|
624 |
+
# df = df[df["id"].isin(wikidata_df["id"])]
|
625 |
+
|
626 |
+
best_precision.append(df["precision"].mean())
|
627 |
+
best_recall.append(df["recall"].mean())
|
628 |
+
f1 = (
|
629 |
+
2
|
630 |
+
* (best_precision[-1] * best_recall[-1])
|
631 |
+
/ (best_precision[-1] + best_recall[-1])
|
632 |
+
)
|
633 |
+
# best_f1.append(df["f1"].mean())
|
634 |
+
best_f1.append(f1)
|
635 |
+
best_mtr.append(0)
|
636 |
+
|
637 |
+
# Create a DataFrame with the statistics
|
638 |
+
data = (
|
639 |
+
pd.DataFrame(
|
640 |
+
{
|
641 |
+
"Model": model_names,
|
642 |
+
"Adjusted Precision with RP": best_precision,
|
643 |
+
"Adjusted Recall with RP": best_recall,
|
644 |
+
"Adjusted F1 with RP": best_f1,
|
645 |
+
}
|
646 |
+
)
|
647 |
+
if adjusted_f1
|
648 |
+
else pd.DataFrame(
|
649 |
+
{
|
650 |
+
"Model": model_names,
|
651 |
+
"Precision": best_precision,
|
652 |
+
"Recall": best_recall,
|
653 |
+
"F1": best_f1,
|
654 |
+
}
|
655 |
+
)
|
656 |
+
)
|
657 |
+
columns = list(data.columns)
|
658 |
+
|
659 |
+
# Melt the DataFrame to a long format
|
660 |
+
data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")
|
661 |
+
|
662 |
+
# Pivot the DataFrame to a wide format
|
663 |
+
data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")
|
664 |
+
|
665 |
+
# make sure the columns are following the order of the models
|
666 |
+
data_pivoted = data_pivoted[model_names]
|
667 |
+
|
668 |
+
# make sure three groups in the order of precision, recall, f1
|
669 |
+
data_pivoted = data_pivoted.reindex(columns[1:])
|
670 |
+
|
671 |
+
# Plot the statistics
|
672 |
+
plt.figure(figsize=(10, 6))
|
673 |
+
ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
|
674 |
+
plt.title(title)
|
675 |
+
plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
|
676 |
+
|
677 |
+
# Set the rotation of the x-axis labels to 0 degrees
|
678 |
+
plt.xticks(rotation=0)
|
679 |
+
|
680 |
+
# Format the y-axis to display as percentage
|
681 |
+
ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
|
682 |
+
|
683 |
+
# get the max value of the y-axis
|
684 |
+
a1 = max(best_precision)
|
685 |
+
a2 = max(best_recall)
|
686 |
+
a3 = max(best_f1)
|
687 |
+
|
688 |
+
max_value = max([a1, a2, a3]) * 1.12
|
689 |
+
print("max_value:", max_value)
|
690 |
+
|
691 |
+
# Set the y-axis limit up to 70%
|
692 |
+
ax.set_ylim(0, max_value)
|
693 |
+
|
694 |
+
# Add the values above each bar
|
695 |
+
for p in ax.patches:
|
696 |
+
ax.annotate(
|
697 |
+
f"{p.get_height() * 100:.1f}",
|
698 |
+
(p.get_x() + p.get_width() / 2.0, p.get_height()),
|
699 |
+
ha="center",
|
700 |
+
va="bottom",
|
701 |
+
xytext=(0, 10),
|
702 |
+
textcoords="offset points",
|
703 |
+
rotation=90,
|
704 |
+
)
|
705 |
+
|
706 |
+
plt.show()
|
707 |
+
return data_pivoted, best_mtr
|
708 |
+
|
709 |
+
|
710 |
+
def plot_best_performance_ms_macro(
|
711 |
+
result,
|
712 |
+
models=None,
|
713 |
+
title="Models with Best RAP - Performance",
|
714 |
+
ref_result=None,
|
715 |
+
skip_generic_prompt=False,
|
716 |
+
include_adjusted_performance=True,
|
717 |
+
):
|
718 |
+
# Initialize lists to store the statistics
|
719 |
+
model_names = []
|
720 |
+
best_f1 = []
|
721 |
+
best_afrp = []
|
722 |
+
best_repetition_penalty = []
|
723 |
+
best_bleu1 = []
|
724 |
+
best_rougeL = []
|
725 |
+
best_mtr = []
|
726 |
+
|
727 |
+
if models is None:
|
728 |
+
models = result.keys()
|
729 |
+
for model in models:
|
730 |
+
if skip_generic_prompt and "generic prompt" in model:
|
731 |
+
continue
|
732 |
+
print(f"model: {model}")
|
733 |
+
df = result[model]["df_overall"]
|
734 |
+
|
735 |
+
# Calculate the statistics
|
736 |
+
bleu1 = [x for x in df["bleu1"]]
|
737 |
+
rougeL = [x for x in df["rougeL"]]
|
738 |
+
f1 = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]
|
739 |
+
|
740 |
+
newline_score = [
|
741 |
+
df["newline_score"].mean()
|
742 |
+
for df in result[model]["df_list_repetition_penalty"]
|
743 |
+
]
|
744 |
+
# print(f"newline_score: {newline_score}")
|
745 |
+
|
746 |
+
repetition_score = [
|
747 |
+
df["repetition_score"].mean()
|
748 |
+
for df in result[model]["df_list_repetition_penalty"]
|
749 |
+
]
|
750 |
+
# print(f"repetition_score: {repetition_score}")
|
751 |
+
|
752 |
+
answer_len = [
|
753 |
+
df["answer_len"].mean()
|
754 |
+
for df in result[model]["df_list_repetition_penalty"]
|
755 |
+
]
|
756 |
+
|
757 |
+
afrp = [
|
758 |
+
calc_adjusted_performance(f, n + r, l)
|
759 |
+
for f, n, r, l in zip(f1, newline_score, repetition_score, answer_len)
|
760 |
+
]
|
761 |
+
|
762 |
+
best_afrp.append(max(afrp if include_adjusted_performance else f1))
|
763 |
+
best_afrp_index = (
|
764 |
+
afrp.index(best_afrp[-1])
|
765 |
+
if include_adjusted_performance
|
766 |
+
else f1.index(best_afrp[-1])
|
767 |
+
)
|
768 |
+
best_repetition_penalty.append(df["repetition_penalty"][best_afrp_index])
|
769 |
+
|
770 |
+
best_f1.append(f1[best_afrp_index])
|
771 |
+
best_bleu1.append(bleu1[best_afrp_index])
|
772 |
+
best_rougeL.append(rougeL[best_afrp_index])
|
773 |
+
best_mtr.append(
|
774 |
+
newline_score[best_afrp_index] + repetition_score[best_afrp_index]
|
775 |
+
)
|
776 |
+
|
777 |
+
# print(
|
778 |
+
# f"best repetition penalty: {best_repetition_penalty[-1]}, best afrp: {best_afrp[-1]}, f1: {best_f1[-1]}"
|
779 |
+
# )
|
780 |
+
|
781 |
+
df = result[model]["df_list_repetition_penalty"][best_afrp_index]
|
782 |
+
|
783 |
+
model_names.append(
|
784 |
+
f"{model} (RP={best_repetition_penalty[-1]})"
|
785 |
+
) # Add the model name to the list
|
786 |
+
|
787 |
+
if ref_result is not None:
|
788 |
+
print("ref_result:", ref_result)
|
789 |
+
for model in ref_result.keys():
|
790 |
+
model_names.append(model)
|
791 |
+
df = pd.read_csv(ref_result[model], comment="#", on_bad_lines="warn")
|
792 |
+
# df = df[df["id"].isin(wikidata_df["id"])]
|
793 |
+
|
794 |
+
p = df["bleu1"][0]
|
795 |
+
best_bleu1.append(p)
|
796 |
+
|
797 |
+
r = df["rougeL"][0]
|
798 |
+
best_rougeL.append(r)
|
799 |
+
|
800 |
+
f1 = 2 * p * r / (p + r) if p + r > 0 else 0
|
801 |
+
best_f1.append(f1)
|
802 |
+
best_afrp.append(f1)
|
803 |
+
best_mtr.append(0)
|
804 |
+
|
805 |
+
# print("model_names:", model_names)
|
806 |
+
# print("best_f1:", best_f1)
|
807 |
+
# print("best_afrp:", best_afrp)
|
808 |
+
|
809 |
+
# Create a DataFrame with the statistics
|
810 |
+
data = (
|
811 |
+
pd.DataFrame(
|
812 |
+
{
|
813 |
+
"Model": model_names,
|
814 |
+
"RAP - Perf Score": best_afrp,
|
815 |
+
"Overall Perf Score": best_f1,
|
816 |
+
}
|
817 |
+
)
|
818 |
+
if include_adjusted_performance
|
819 |
+
else pd.DataFrame(
|
820 |
+
{
|
821 |
+
"Model": model_names,
|
822 |
+
"Bleu-1": best_bleu1,
|
823 |
+
"Rouge-L": best_rougeL,
|
824 |
+
"Overall Perf Score": best_f1,
|
825 |
+
}
|
826 |
+
)
|
827 |
+
)
|
828 |
+
|
829 |
+
# Melt the DataFrame to a long format
|
830 |
+
data_melted = data.melt(id_vars="Model", var_name="Metric", value_name="Score")
|
831 |
+
|
832 |
+
# Pivot the DataFrame to a wide format
|
833 |
+
data_pivoted = data_melted.pivot(index="Metric", columns="Model", values="Score")
|
834 |
+
|
835 |
+
# make sure the columns are following the order of the models
|
836 |
+
data_pivoted = data_pivoted[model_names]
|
837 |
+
|
838 |
+
columns = list(data.columns)
|
839 |
+
data_pivoted = data_pivoted.reindex(columns[1:])
|
840 |
+
|
841 |
+
# Plot the statistics
|
842 |
+
plt.figure(figsize=(10, 6))
|
843 |
+
ax = data_pivoted.plot(kind="bar", ax=plt.gca(), width=0.9)
|
844 |
+
plt.title(title)
|
845 |
+
plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
|
846 |
+
|
847 |
+
# Set the rotation of the x-axis labels to 0 degrees
|
848 |
+
plt.xticks(rotation=0)
|
849 |
+
|
850 |
+
# Format the y-axis to display as percentage
|
851 |
+
ax.yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
|
852 |
+
|
853 |
+
# get the max value of the y-axis
|
854 |
+
a1 = max(best_afrp)
|
855 |
+
a2 = max(best_f1)
|
856 |
+
a3 = max(best_bleu1)
|
857 |
+
a4 = max(best_rougeL)
|
858 |
+
|
859 |
+
max_value = (
|
860 |
+
max([a1, a2] if include_adjusted_performance else [a1, a2, a3, a4]) * 1.12
|
861 |
+
)
|
862 |
+
print("max_value:", max_value)
|
863 |
+
|
864 |
+
# Set the y-axis limit up to 70%
|
865 |
+
ax.set_ylim(0, max_value)
|
866 |
+
|
867 |
+
# Add the values above each bar
|
868 |
+
for p in ax.patches:
|
869 |
+
ax.annotate(
|
870 |
+
f"{p.get_height() * 100:.1f}",
|
871 |
+
(p.get_x() + p.get_width() / 2.0, p.get_height()),
|
872 |
+
ha="center",
|
873 |
+
va="bottom",
|
874 |
+
xytext=(0, 10),
|
875 |
+
textcoords="offset points",
|
876 |
+
rotation=90,
|
877 |
+
)
|
878 |
+
|
879 |
+
plt.show()
|
880 |
+
return data_pivoted, best_mtr
|
881 |
+
|
882 |
+
|
883 |
+
all_open_source_models = [
|
884 |
+
"gemma-1.1-2b-it",
|
885 |
+
"Phi-3-mini-128k-instruct",
|
886 |
+
"gemma-1.1-7b-it",
|
887 |
+
"Llama-2-7b-chat-hf",
|
888 |
+
"Mistral-7B-Instruct-v0.2",
|
889 |
+
"Meta-Llama-3-8B-Instruct",
|
890 |
+
"Llama-2-13b-chat-hf",
|
891 |
+
"Llama-2-70b-chat-hf",
|
892 |
+
"Meta-Llama-3-70B-Instruct",
|
893 |
+
]
|
894 |
+
|
895 |
+
|
896 |
+
def load_for_repetition_penalty_ms_macro(
|
897 |
+
csv_result_file, repetition_penalty, force_recalculate=False
|
898 |
+
):
|
899 |
+
result_file = replace_last(
|
900 |
+
csv_result_file, ".csv", f"_rpp_{repetition_penalty:.2f}.csv"
|
901 |
+
)
|
902 |
+
df = load_with_newline_and_repetition_scores(
|
903 |
+
result_file, force_recalculate=force_recalculate
|
904 |
+
)
|
905 |
+
|
906 |
+
return df
|
907 |
+
|
908 |
+
|
909 |
+
# MS MACRO
|
910 |
+
def plot_performance_scores_ms_macro(
|
911 |
+
result,
|
912 |
+
models=None,
|
913 |
+
title="Performance",
|
914 |
+
):
|
915 |
+
if models is None:
|
916 |
+
models = result.keys()
|
917 |
+
for model in models:
|
918 |
+
print(f"model: {model}")
|
919 |
+
df = result[model]["df_overall"]
|
920 |
+
# print(result[model]["df_list_repetition_penalty"][0].describe())
|
921 |
+
|
922 |
+
# Calculate the statistics
|
923 |
+
bleu1 = list(df["bleu1"])
|
924 |
+
rougeL = list(df["rougeL"])
|
925 |
+
f1 = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]
|
926 |
+
best_f1 = max(f1)
|
927 |
+
best_f1_index = f1.index(best_f1)
|
928 |
+
|
929 |
+
bleu1, rougeL = adjust_perf_scores_with_repetition_penalty(
|
930 |
+
result[model], bleu1, rougeL
|
931 |
+
)
|
932 |
+
afrp = [2 * (p * r) / (p + r) for p, r in zip(bleu1, rougeL)]
|
933 |
+
|
934 |
+
# f1 = [df["f1"].mean() for df in result[model]["df_list_repetition_penalty"]]
|
935 |
+
best_afrp = max(afrp)
|
936 |
+
best_afrp_index = afrp.index(best_afrp)
|
937 |
+
|
938 |
+
repetition_penalties = list(df["repetition_penalty"])
|
939 |
+
|
940 |
+
# line plot for precision, recall, f1
|
941 |
+
plt.figure(figsize=(10, 6))
|
942 |
+
|
943 |
+
plt.axvspan(
|
944 |
+
repetition_penalties[best_f1_index] - 0.01,
|
945 |
+
repetition_penalties[best_f1_index] + 0.01,
|
946 |
+
alpha=0.5,
|
947 |
+
edgecolor="none",
|
948 |
+
facecolor="blue",
|
949 |
+
)
|
950 |
+
|
951 |
+
plt.axvspan(
|
952 |
+
repetition_penalties[best_afrp_index] - 0.01,
|
953 |
+
repetition_penalties[best_afrp_index] + 0.01,
|
954 |
+
alpha=0.5,
|
955 |
+
edgecolor="none",
|
956 |
+
facecolor="orange",
|
957 |
+
)
|
958 |
+
|
959 |
+
plt.plot(
|
960 |
+
repetition_penalties,
|
961 |
+
f1,
|
962 |
+
label="Overall Perf Score",
|
963 |
+
marker="D",
|
964 |
+
color="blue",
|
965 |
+
)
|
966 |
+
plt.plot(
|
967 |
+
repetition_penalties,
|
968 |
+
afrp,
|
969 |
+
label="RAP - Perf Score",
|
970 |
+
marker="o",
|
971 |
+
color="orange",
|
972 |
+
)
|
973 |
+
|
974 |
+
plt.xlabel("Repetition Penalties")
|
975 |
+
plt.ylabel("Score")
|
976 |
+
# plt.xlim(0.99, 1.31)
|
977 |
+
# y in percentage
|
978 |
+
plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(1.0))
|
979 |
+
plt.title(f"{model} {title}")
|
980 |
+
plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
|
981 |
+
|
982 |
+
plt.show()
|
983 |
+
|
984 |
+
|
985 |
+
def plot_repetition_factors(result, groups):
|
986 |
+
for group in groups:
|
987 |
+
# Plot the statistics
|
988 |
+
plt.figure(figsize=(10, 6))
|
989 |
+
|
990 |
+
max_value = 0
|
991 |
+
for model in result.keys():
|
992 |
+
if not group in model.lower():
|
993 |
+
continue
|
994 |
+
print(f"model: {model}")
|
995 |
+
df = result[model]["df_overall"]
|
996 |
+
repetition_panelties = [
|
997 |
+
repetition_penalty for repetition_penalty in df["repetition_penalty"]
|
998 |
+
]
|
999 |
+
|
1000 |
+
mean_score = [
|
1001 |
+
df["total_repetitions"].mean()
|
1002 |
+
for df in result[model]["df_list_repetition_penalty"]
|
1003 |
+
]
|
1004 |
+
|
1005 |
+
sns.lineplot(x=repetition_panelties, y=mean_score, label=model)
|
1006 |
+
|
1007 |
+
new_max = max(mean_score)
|
1008 |
+
if new_max > max_value:
|
1009 |
+
max_value = new_max
|
1010 |
+
|
1011 |
+
max_value = max_value * 1.05
|
1012 |
+
# if max_value < 1.5:
|
1013 |
+
# max_value = 1.5
|
1014 |
+
# set ylimit
|
1015 |
+
plt.ylim(0, max_value)
|
1016 |
+
|
1017 |
+
# show grid
|
1018 |
+
plt.grid(True)
|
1019 |
+
plt.xlabel("Repetition Penalties")
|
1020 |
+
plt.ylabel("Mean Total Repetitions")
|
1021 |
+
plt.title("Mean Total Repetitions vs Repetition Penalties")
|
1022 |
+
plt.legend()
|
1023 |
+
|
1024 |
+
plt.show()
|
1025 |
+
|
1026 |
+
|
1027 |
+
def plot_repetition_factors_by_group(result, group_filter=None):
|
1028 |
+
markers = ["D", "o", "s", "x"]
|
1029 |
+
colors = ["blue", "orange", "green", "red"]
|
1030 |
+
|
1031 |
+
# Plot the statistics
|
1032 |
+
plt.figure(figsize=(10, 6))
|
1033 |
+
index = 0
|
1034 |
+
max_value = 0
|
1035 |
+
|
1036 |
+
for model in result.keys():
|
1037 |
+
if group_filter is not None and group_filter not in model:
|
1038 |
+
continue
|
1039 |
+
|
1040 |
+
print(f"model: {model}")
|
1041 |
+
|
1042 |
+
df = result[model]["df_overall"]
|
1043 |
+
repetition_panelties = [
|
1044 |
+
repetition_penalty for repetition_penalty in df["repetition_penalty"]
|
1045 |
+
]
|
1046 |
+
|
1047 |
+
# Calculate the statistics
|
1048 |
+
mean_score = [
|
1049 |
+
df["total_repetitions"].mean()
|
1050 |
+
for df in result[model]["df_list_repetition_penalty"]
|
1051 |
+
]
|
1052 |
+
if len(mean_score) != len(repetition_panelties):
|
1053 |
+
print(
|
1054 |
+
f"model: {model} has different length of repetition penalties and mean score"
|
1055 |
+
)
|
1056 |
+
print("repetition_panelties:", len(repetition_panelties))
|
1057 |
+
print("mean_score:", len(mean_score))
|
1058 |
+
continue
|
1059 |
+
|
1060 |
+
new_max = max(mean_score)
|
1061 |
+
if new_max > max_value:
|
1062 |
+
max_value = new_max
|
1063 |
+
|
1064 |
+
sns.lineplot(
|
1065 |
+
x=repetition_panelties,
|
1066 |
+
y=mean_score,
|
1067 |
+
label=model,
|
1068 |
+
marker=markers[index],
|
1069 |
+
color=colors[index],
|
1070 |
+
)
|
1071 |
+
|
1072 |
+
index += 1
|
1073 |
+
|
1074 |
+
max_value = max_value * 1.05
|
1075 |
+
# if max_value < 1.5:
|
1076 |
+
# max_value = 1.5
|
1077 |
+
# set ylimit
|
1078 |
+
plt.ylim(0, max_value)
|
1079 |
+
max_value = 0
|
1080 |
+
|
1081 |
+
plt.xlabel("Repetition Penalties")
|
1082 |
+
plt.ylabel("Mean Total Repetitions")
|
1083 |
+
plt.title("Mean Total Repetitions vs Repetition Penalties")
|
1084 |
+
plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
|
1085 |
+
|
1086 |
+
plt.show()
|
1087 |
+
|
1088 |
+
|
1089 |
+
ms_marco_csv_result_files = [
|
1090 |
+
"data/results_v2/gemma-1.1-2b-it(RAG - Generic Prompt)_mm.csv",
|
1091 |
+
"data/results_v2/gemma-1.1-2b-it(RAG - Chat Template)_mm.csv",
|
1092 |
+
"data/results_v2/gemma-1.1-2b-it(Non-RAG)_mm.csv",
|
1093 |
+
"data/results_v2/Phi-3-mini-128k-instruct(RAG - Generic Prompt)_mm.csv",
|
1094 |
+
"data/results_v2/Phi-3-mini-128k-instruct(RAG - Chat Template)_mm.csv",
|
1095 |
+
"data/results_v2/Phi-3-mini-128k-instruct(Non-RAG)_mm.csv",
|
1096 |
+
"data/results_v2/gemma-1.1-7b-it(RAG - Generic Prompt)_mm.csv",
|
1097 |
+
"data/results_v2/gemma-1.1-7b-it(RAG - Chat Template)_mm.csv",
|
1098 |
+
"data/results_v2/gemma-1.1-7b-it(Non-RAG)_mm.csv",
|
1099 |
+
"data/results_v2/Llama-2-7b-chat-hf(RAG - Generic Prompt)_mm.csv",
|
1100 |
+
"data/results_v2/Llama-2-7b-chat-hf(RAG - Chat Template)_mm.csv",
|
1101 |
+
"data/results_v2/Llama-2-7b-chat-hf(Non-RAG)_mm.csv",
|
1102 |
+
"data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Generic Prompt)_mm.csv",
|
1103 |
+
"data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Chat Template)_mm.csv",
|
1104 |
+
"data/results_v2/Mistral-7B-Instruct-v0.2(Non-RAG)_mm.csv",
|
1105 |
+
"data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Generic Prompt)_mm.csv",
|
1106 |
+
"data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Chat Template)_mm.csv",
|
1107 |
+
"data/results_v2/Meta-Llama-3-8B-Instruct(Non-RAG)_mm.csv",
|
1108 |
+
"data/results_v2/Llama-2-13b-chat-hf(RAG - Generic Prompt)_mm.csv",
|
1109 |
+
"data/results_v2/Llama-2-13b-chat-hf(RAG - Chat Template)_mm.csv",
|
1110 |
+
"data/results_v2/Llama-2-13b-chat-hf(Non-RAG)_mm.csv",
|
1111 |
+
"data/results_v2/Llama-2-70b-chat-hf(RAG - Generic Prompt)_mm.csv",
|
1112 |
+
"data/results_v2/Llama-2-70b-chat-hf(RAG - Chat Template)_mm.csv",
|
1113 |
+
"data/results_v2/Llama-2-70b-chat-hf(Non-RAG)_mm.csv",
|
1114 |
+
"data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Generic Prompt)_mm.csv",
|
1115 |
+
"data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Chat Template)_mm.csv",
|
1116 |
+
"data/results_v2/Meta-Llama-3-70B-Instruct(Non-RAG)_mm.csv",
|
1117 |
+
]
|
1118 |
+
|
webqsp_csv_result_files = [
    "data/results_v2/gemma-1.1-2b-it(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/gemma-1.1-2b-it(RAG - Chat Template)_wd.csv",
    "data/results_v2/gemma-1.1-2b-it(Non-RAG)_wd.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(RAG - Chat Template)_wd.csv",
    "data/results_v2/Phi-3-mini-128k-instruct(Non-RAG)_wd.csv",
    "data/results_v2/gemma-1.1-7b-it(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/gemma-1.1-7b-it(RAG - Chat Template)_wd.csv",
    "data/results_v2/gemma-1.1-7b-it(Non-RAG)_wd.csv",
    "data/results_v2/Llama-2-7b-chat-hf(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Llama-2-7b-chat-hf(RAG - Chat Template)_wd.csv",
    "data/results_v2/Llama-2-7b-chat-hf(Non-RAG)_wd.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(RAG - Chat Template)_wd.csv",
    "data/results_v2/Mistral-7B-Instruct-v0.2(Non-RAG)_wd.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(RAG - Chat Template)_wd.csv",
    "data/results_v2/Meta-Llama-3-8B-Instruct(Non-RAG)_wd.csv",
    "data/results_v2/Llama-2-13b-chat-hf(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Llama-2-13b-chat-hf(RAG - Chat Template)_wd.csv",
    "data/results_v2/Llama-2-13b-chat-hf(Non-RAG)_wd.csv",
    "data/results_v2/Llama-2-70b-chat-hf(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Llama-2-70b-chat-hf(RAG - Chat Template)_wd.csv",
    "data/results_v2/Llama-2-70b-chat-hf(Non-RAG)_wd.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Generic Prompt)_wd.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(RAG - Chat Template)_wd.csv",
    "data/results_v2/Meta-Llama-3-70B-Instruct(Non-RAG)_wd.csv",
]
def calc_rap_scores(
    result, precision="precision", recall="recall", penalty_function="cubic"
):
    newline_score = [
        df["newline_score"].mean() for df in result["df_list_repetition_penalty"]
    ]

    repetition_score = [
        df["repetition_score"].mean() for df in result["df_list_repetition_penalty"]
    ]

    if precision in result["df_list_repetition_penalty"][0].columns:
        precision = [
            df[precision].mean() for df in result["df_list_repetition_penalty"]
        ]
        recall = [df[recall].mean() for df in result["df_list_repetition_penalty"]]
    else:
        precision = result["df_overall"][precision]
        recall = result["df_overall"][recall]

    f1 = [2 * (p * r) / (p + r) for p, r in zip(precision, recall)]

    # non-repetition ratio: the fraction of the answer that is not repetitions
    nrr = [
        1 - (n + r) / s
        for n, r, s in zip(
            newline_score, repetition_score, result["df_overall"]["answer_len"]
        )
    ]

    # repetition-adjusted performance: F1 discounted by the repetition ratio
    rap = [
        calc_adjusted_performance(f, 1 - n, penalty_function=penalty_function)
        for f, n in zip(f1, nrr)
    ]

    return newline_score, repetition_score, f1, rap, nrr
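# Worked example of the arithmetic above, with illustrative numbers
# (calc_adjusted_performance is defined earlier in this module):
#
#   precision, recall = 0.8, 0.6
#   f1 = 2 * (0.8 * 0.6) / (0.8 + 0.6)    # ~= 0.686
#   nrr = 1 - (5 + 15) / 200              # 0.9: a 200-char answer with
#                                         # 20 chars of repetitions
#   rap = calc_adjusted_performance(f1, 1 - nrr, penalty_function="cubic")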
def get_model_name(csv_result_file):
    parts = re.split(r"[_/]", csv_result_file)
    print(f"parts: {parts}")
    model_name = parts[3]
    return model_name
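# For example, re.split(r"[_/]", "data/results_v2/gemma-1.1-2b-it(RAG - Generic Prompt)_mm.csv")
# yields ['data', 'results', 'v2', 'gemma-1.1-2b-it(RAG - Generic Prompt)', 'mm.csv'],
# so parts[3] keeps the prompt variant as part of the model name, which keeps the
# three run variants of each model distinct in the result dict.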
def load_webqsp_result(
    csv_result_files, force_recalculate=False, save=False, penalty_function="cubic"
):
    result = {}
    for csv_result_file in csv_result_files:
        try:
            df = pd.read_csv(csv_result_file)
            model_name = get_model_name(csv_result_file)
            print(f"\tmodel_name: {model_name}")

            # one DataFrame per repetition-penalty setting
            dfs = [
                calculate_performance_score(
                    csv_result_file,
                    repetition_penalty,
                    force_recalculate=force_recalculate,
                )
                for repetition_penalty in df["repetition_penalty"]
            ]

            answer_lens = []
            for df_rpp in dfs:
                answer_lens.append(df_rpp["answer_len"].mean())
            df["answer_len"] = answer_lens

            result[model_name] = {
                "df_overall": df,
                "df_list_repetition_penalty": dfs,
                "file": csv_result_file,
            }
            newline_score, repetition_score, perf, rap, nrr = calc_rap_scores(
                result[model_name], penalty_function=penalty_function
            )
            df["newline_score"] = newline_score
            df["repetition_score"] = repetition_score
            df["total_repetitions"] = df["newline_score"] + df["repetition_score"]
            df["perf"] = perf
            df["nrr"] = nrr
            df["rap"] = rap
            # repetition ratio, as a fraction and as a percentage
            df["rr"] = 1 - df["nrr"]
            df["rrp"] = df["rr"] * 100
            if save:
                df.to_csv(csv_result_file, index=False)
        except Exception as e:
            print(f"Error: {e}")
            traceback.print_exc()

    return result
def load_ms_marco_result(
    csv_result_files,
    force_recalculate=False,
    calc_bertscore=True,
    save=False,
    penalty_function="cubic",
):
    result = {}
    for csv_result_file in csv_result_files:
        try:
            df = pd.read_csv(csv_result_file)
            model_name = get_model_name(csv_result_file)
            print(f"\tmodel_name: {model_name}")

            # one DataFrame per repetition-penalty setting
            dfs = [
                load_for_repetition_penalty_ms_macro(
                    csv_result_file,
                    repetition_penalty,
                    force_recalculate=force_recalculate,
                )
                for repetition_penalty in df["repetition_penalty"]
            ]

            answer_lens = []
            for df_rpp in dfs:
                answer_lens.append(df_rpp["answer_len"].mean())
            df["answer_len"] = answer_lens

            col = "bert_score" if calc_bertscore else "meteor"
            score_unavailable = col not in df.columns

            if score_unavailable:
                save = True
                bert_meteor_scores = []
                for df_rpp in dfs:
                    if calc_bertscore:
                        bert_meteor_score = 0

                        for _, row in df_rpp.iterrows():
                            answer = row["answer"]
                            if not isinstance(answer, str):
                                answer = ""
                            # score against the first ground-truth reference
                            bert_meteor_score += bert_score.compute(
                                predictions=[answer],
                                references=[row["ground_truth"][0]],
                                lang="en",
                                model_type="microsoft/deberta-large-mnli",
                            )["f1"][0]
                        # average BERTScore F1 over all rows
                        bert_meteor_score = bert_meteor_score / len(df_rpp)

                        print(f"bert_score: {bert_meteor_score}")
                    else:
                        bert_meteor_score = meteor.compute(
                            predictions=df_rpp["answer"],
                            references=df_rpp["ground_truth"],
                        )["meteor"]

                    bert_meteor_scores.append(bert_meteor_score)

                df[col] = bert_meteor_scores

            result[model_name] = {
                "df_overall": df,
                "df_list_repetition_penalty": dfs,
                "file": csv_result_file,
            }
            newline_score, repetition_score, perf, rap, nrr = calc_rap_scores(
                result[model_name],
                precision=col,
                recall=col,
                penalty_function=penalty_function,
            )
            df["newline_score"] = newline_score
            df["repetition_score"] = repetition_score
            df["total_repetitions"] = df["newline_score"] + df["repetition_score"]
            df["perf"] = perf
            df["nrr"] = nrr
            df["rap"] = rap
            # repetition ratio, as a fraction and as a percentage
            df["rr"] = 1 - df["nrr"]
            df["rrp"] = df["rr"] * 100

            if save:
                df.to_csv(csv_result_file, index=False)
        except Exception as e:
            print("An error occurred:", e)
            traceback.print_exc()
            print(f"csv_result_file: {csv_result_file}")

    return result
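# Minimal driver sketch (illustrative; assumes the CSVs under data/results_v2/
# exist so the load_* functions above can read them):
#
#   result = load_ms_marco_result(ms_marco_csv_result_files, calc_bertscore=True)
#   plot_repetition_factors(result, groups=["llama-2", "llama-3"])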
eval_modules/utils.py
ADDED
@@ -0,0 +1,262 @@
# -*- coding:utf-8 -*-
from __future__ import annotations

import json
import logging
import os
import platform
import re
from pathlib import Path

import evaluate
import pandas as pd
import requests
import torch
from tqdm import tqdm


class LogRecord(logging.LogRecord):
    def getMessage(self):
        msg = self.msg
        if self.args:
            if isinstance(self.args, dict):
                msg = msg.format(**self.args)
            else:
                msg = msg.format(*self.args)
        return msg


class Logger(logging.Logger):
    def makeRecord(
        self,
        name,
        level,
        fn,
        lno,
        msg,
        args,
        exc_info,
        func=None,
        extra=None,
        sinfo=None,
    ):
        rv = LogRecord(name, level, fn, lno, msg, args, exc_info, func, sinfo)
        if extra is not None:
            for key in extra:
                rv.__dict__[key] = extra[key]
        return rv


def init_settings():
    logging.setLoggerClass(Logger)
    logging.basicConfig(
        level=logging.WARNING,
        format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
    )
def remove_extra_spaces(text):
    # collapse runs of spaces into one and trim leading/trailing whitespace
    return re.sub(" +", " ", text.strip())
def print_llm_response(llm_response, debug_retrieval=True):
    answer = llm_response.get("answer")
    if answer is None:
        answer = llm_response.get("response")

    if answer is not None:
        print("\n\n***Answer:")
        print(answer)

    source_documents = llm_response.get("source_documents")
    if source_documents is None:
        source_documents = llm_response.get("sourceDocs")

    if debug_retrieval and source_documents is not None:
        print("\nSources:")
        for index, source in enumerate(source_documents):
            # sources may be plain dicts or Document-like objects
            metadata = source["metadata"] if "metadata" in source else source.metadata
            if "page" in metadata:
                print(f" Page: {metadata['page']}", end="")

            print(
                f" Source {index + 1}: "
                + str(metadata["url"] if "url" in metadata else metadata["source"])
            )
            print(
                source["page_content"]
                if "page_content" in source
                else source.page_content
            )

    if "chat_history" in llm_response:
        print("\nChat History:")
        print(llm_response["chat_history"])
def get_device_types():
    print("Running on: ", platform.platform())
    print("MPS is", "NOT" if not torch.backends.mps.is_available() else "", "available")
    print("CUDA is", "NOT" if not torch.cuda.is_available() else "", "available")
    device_type_available = "cpu"

    if not torch.backends.mps.is_available():
        if not torch.backends.mps.is_built():
            print(
                "MPS not available because the current PyTorch install was not "
                "built with MPS enabled."
            )
        else:
            print(
                "MPS not available because the current MacOS version is not 12.3+ "
                "and/or you do not have an MPS-enabled device on this machine."
            )
    else:
        device_type_available = "mps"

    if torch.cuda.is_available():
        print("CUDA is available, we have found ", torch.cuda.device_count(), " GPU(s)")
        print(torch.cuda.get_device_name(0))
        print("CUDA version: " + torch.version.cuda)
        device_type_available = f"cuda:{torch.cuda.current_device()}"

    return (
        os.environ.get("HF_EMBEDDINGS_DEVICE_TYPE") or device_type_available,
        os.environ.get("HF_PIPELINE_DEVICE_TYPE") or device_type_available,
    )
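# Example usage (illustrative variable names):
#
#   embeddings_device, pipeline_device = get_device_types()
#   print(f"embeddings on {embeddings_device}, pipeline on {pipeline_device}")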
def ensure_model_is_downloaded(llm_model_type):
    if llm_model_type.startswith("gpt4all"):
        local_path = (
            os.environ.get("GPT4ALL_J_MODEL_PATH")
            if llm_model_type == "gpt4all-j"
            else os.environ.get("GPT4ALL_MODEL_PATH")
        )
        url = (
            os.environ.get("GPT4ALL_J_DOWNLOAD_LINK")
            if llm_model_type == "gpt4all-j"
            else os.environ.get("GPT4ALL_DOWNLOAD_LINK")
        )
    elif llm_model_type == "llamacpp":
        local_path = os.environ.get("LLAMACPP_MODEL_PATH")
        url = os.environ.get("LLAMACPP_DOWNLOAD_LINK")
    elif llm_model_type == "ctransformers":
        local_path = os.environ.get("CTRANSFORMERS_MODEL_PATH")
        url = os.environ.get("CTRANSFORMERS_DOWNLOAD_LINK")
    else:
        raise ValueError(f"wrong model type: {llm_model_type}")

    path = Path(local_path)

    if path.is_file():
        print(f"model: {local_path} exists")
    else:
        print(f"downloading model: {local_path} from {url} ...")
        path.parent.mkdir(parents=True, exist_ok=True)

        # send a GET request to the URL; stream since the file is large
        response = requests.get(url, stream=True)

        # write the contents of the response to the file in chunks
        # this is a large file, so be prepared to wait
        with open(local_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=8192)):
                if chunk:
                    f.write(chunk)

    return local_path
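# Example usage (illustrative; requires LLAMACPP_MODEL_PATH and
# LLAMACPP_DOWNLOAD_LINK to be set, e.g. via .env):
#
#   model_path = ensure_model_is_downloaded("llamacpp")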
# BLEU and ROUGE metrics, loaded once at import time
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")


def calc_bleu_rouge_scores(predictions, references, debug=False):
    if debug:
        print("predictions:", predictions)
        print("references:", references)

    bleu_scores = bleu.compute(
        predictions=predictions, references=references, max_order=1
    )
    rouge_scores = rouge.compute(predictions=predictions, references=references)
    result = {"bleu_scores": bleu_scores, "rouge_scores": rouge_scores}

    if debug:
        print("result:", result)

    return result


def calc_metrics(df):
    predictions = df["answer"].tolist()
    references = df["ground_truth"].tolist()

    return calc_bleu_rouge_scores(predictions, references)
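# Example usage (illustrative; `evaluate` downloads the bleu/rouge metric
# scripts on first use):
#
#   scores = calc_bleu_rouge_scores(
#       predictions=["Paris is the capital of France."],
#       references=[["The capital of France is Paris."]],
#       debug=True,
#   )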
# abnormal newlines: runs of five or more consecutive "\n"
pattern_abnormal_newlines = re.compile(r"\n{5,}")
# a word-boundary-delimited phrase immediately repeated one or more times
pattern_text_repetitions = re.compile(r"\b(\w.+?)\b(\1+)", re.M | re.DOTALL)
# matches starting with a doubled "word." token are not counted
exception_pattern = re.compile(r"(\w+\.)\1")


# final version for repetition detection
def detect_repetitions(
    text, debug=False, pattern_text_repetitions=pattern_text_repetitions
):
    subtotals = [0, 0]

    if isinstance(text, str):
        patterns = [pattern_abnormal_newlines, pattern_text_repetitions]
        for i, pattern in enumerate(patterns):
            if debug:
                print(
                    f"----detect {'abnormal newlines' if i == 0 else 'text repetitions'}----"
                )
            matches = pattern.finditer(text)
            for match in matches:
                if debug:
                    print(match)
                    for group_num in range(1, len(match.groups()) + 1):
                        print(
                            f"Group {group_num} found at "
                            f"{match.start(group_num)}-{match.end(group_num)}: "
                            f"`{match.group(group_num)}`"
                        )

                if exception_pattern.match(match[0]):
                    if debug:
                        print("ignored: ", match[0])
                    continue

                # count the matched span's length towards the relevant subtotal
                start, end = match.span()
                subtotals[i] += end - start

    result = (subtotals[0], subtotals[1], subtotals[0] + subtotals[1])

    if debug:
        print(result)
    return result


def detect_abnormal_newlines(text, debug=False):
    return detect_repetitions(text, debug=debug)[0]


def detect_text_repetitions(text, debug=False):
    return detect_repetitions(text, debug=debug)[1]


def detect_repetition_scores(text, debug=False):
    newline_score, repetition_score, total_repetitions = detect_repetitions(
        text, debug=debug
    )
    return pd.Series([newline_score, repetition_score, total_repetitions])
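# Example usage (illustrative; the DataFrame column names follow the result
# CSVs used elsewhere in this repo):
#
#   newline_score, repetition_score, total = detect_repetitions(
#       "The answer is 42. " * 4, debug=True
#   )
#   df[["newline_score", "repetition_score", "total_repetitions"]] = df[
#       "answer"
#   ].apply(detect_repetition_scores)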
ms_macro.json
ADDED
The diff for this file is too large to render.
See raw diff
requirements.txt
CHANGED
@@ -1 +1,20 @@
- huggingface_hub==0.
+ huggingface_hub==0.26.0
+ nltk==3.8.1
+ langchain==0.1.16
+ langchain-openai==0.1.3
+ langchain_google_genai==1.0.2
+ transformers==4.40.1
+ accelerate==0.29.3
+ python-dotenv==1.0.1
+ gradio==4.44.1
+ black==24.4.0
+ InstructorEmbedding==1.0.1
+ sentence-transformers==2.2.2
+ chardet==5.2.0
+ sentencepiece==0.1.98
+ evaluate==0.4.3
+ rouge_score==0.1.2
+ pytest==8.2.1
+ seaborn==0.13.2
+ tenacity==8.3.0
+ bert_score==0.3.13