Spaces:
AIR-Bench
/
Running on CPU Upgrade

nan commited on
Commit
9fcf267
·
1 Parent(s): a50e211

test: add unit tests for models

Browse files
src/models.py CHANGED
@@ -85,16 +85,17 @@ class FullEvalResult:
85
  is_anonymous=config.get("is_anonymous", False),
86
  )
87
  result_list.append(eval_result)
 
88
  return cls(
89
- eval_name=f"{result_list[0].retrieval_model}_{result_list[0].reranking_model}",
90
- retrieval_model=result_list[0].retrieval_model,
91
- reranking_model=result_list[0].reranking_model,
92
  retrieval_model_link=retrieval_model_link,
93
  reranking_model_link=reranking_model_link,
94
  results=result_list,
95
- timestamp=result_list[0].timestamp,
96
- revision=result_list[0].revision,
97
- is_anonymous=result_list[0].is_anonymous,
98
  )
99
 
100
  def to_dict(self, task="qa", metric="ndcg_at_3") -> List:
 
85
  is_anonymous=config.get("is_anonymous", False),
86
  )
87
  result_list.append(eval_result)
88
+ eval_result = result_list[0]
89
  return cls(
90
+ eval_name=f"{eval_result.retrieval_model}_{eval_result.reranking_model}",
91
+ retrieval_model=eval_result.retrieval_model,
92
+ reranking_model=eval_result.reranking_model,
93
  retrieval_model_link=retrieval_model_link,
94
  reranking_model_link=reranking_model_link,
95
  results=result_list,
96
+ timestamp=eval_result.timestamp,
97
+ revision=eval_result.revision,
98
+ is_anonymous=eval_result.is_anonymous,
99
  )
100
 
101
  def to_dict(self, task="qa", metric="ndcg_at_3") -> List:
tests/src/test_models.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from pathlib import Path
3
+
4
+ from src.models import EvalResult, FullEvalResult
5
+
6
+ cur_fp = Path(__file__)
7
+
8
+
9
+ def test_eval_result():
10
+ eval_result = EvalResult(
11
+ eval_name="eval_name",
12
+ retrieval_model="bge-m3",
13
+ reranking_model="NoReranking",
14
+ results=[
15
+ {
16
+ "domain": "law",
17
+ "lang": "en",
18
+ "dataset": "lex_files_500K-600K",
19
+ "value": 0.45723
20
+ }
21
+ ],
22
+ task="qa",
23
+ metric="ndcg_at_3",
24
+ timestamp="2024-05-14T03:09:08Z",
25
+ revision="1e243f14bd295ccdea7a118fe847399d",
26
+ is_anonymous=True,
27
+ )
28
+
29
+
30
+ @pytest.mark.parametrize(
31
+ 'file_path',
32
+ [
33
+ "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
34
+ "AIR-Bench_24.05/bge-m3/NoReranker/results.json"
35
+ ])
36
+ def test_full_eval_result_init_from_json_file(file_path):
37
+ json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
38
+ full_eval_result = FullEvalResult.init_from_json_file(json_fp)
39
+ assert json_fp.parents[0].stem == full_eval_result.reranking_model
40
+ assert json_fp.parents[1].stem == full_eval_result.retrieval_model
41
+ assert len(full_eval_result.results) == 70
42
+
43
+
44
+ def test_full_eval_result_to_dict():
45
+ json_fp = cur_fp.parents[1] / "toydata/eval_results/" / "AIR-Bench_24.05/bge-m3/NoReranker/results.json"
46
+ full_eval_result = FullEvalResult.init_from_json_file(json_fp)
47
+ result_dict_list = full_eval_result.to_dict()
48
+ assert len(result_dict_list) == 1
49
+ print(len(result_dict_list[0]))
tests/toydata/eval_results/AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json ADDED
The diff for this file is too large to render. See raw diff
 
tests/toydata/eval_results/AIR-Bench_24.05/bge-m3/NoReranker/results.json ADDED
The diff for this file is too large to render. See raw diff