fix: evaluate missing splits #1268

Merged
110 changes: 96 additions & 14 deletions mteb/evaluation/MTEB.py
@@ -5,7 +5,7 @@
import os
import traceback
from collections.abc import Iterable
from copy import copy
from copy import copy, deepcopy
from datetime import datetime
from itertools import chain
from pathlib import Path
@@ -15,6 +15,7 @@
import datasets
from sentence_transformers import SentenceTransformer

from mteb.abstasks.AbsTask import ScoresDict
from mteb.encoder_interface import Encoder
from mteb.model_meta import ModelMeta
from mteb.models import model_meta_from_sentence_transformers
@@ -85,6 +86,8 @@ def __init__(
self._version = version
self.err_logs_path = err_logs_path

self.last_evaluated_splits = {}

self.select_tasks(**kwargs)

def deprecation_warning(
@@ -308,6 +311,59 @@ def _run_eval(
tock = time()
return results, tick, tock

@staticmethod
def _get_missing_splits(
existing_results: TaskResult | None, task_eval_splits: list[str], task: AbsTask
) -> list[str]:
if existing_results is None:
return task_eval_splits

missing_splits = []
for split in task_eval_splits:
if split not in existing_results.scores:
missing_splits.append(split)
elif not existing_results.scores[
split
]: # Check if the split has any scores
missing_splits.append(split)

return missing_splits

@staticmethod
def _merge_results(
existing_results: TaskResult, new_results: TaskResult
) -> TaskResult:
merged_scores = existing_results.scores.copy()

for split, scores in new_results.scores.items():
if split in merged_scores:
merged_scores[split] = MTEB._merge_split_scores(
merged_scores[split], scores
)
else:
merged_scores[split] = scores

merged_results = TaskResult(
dataset_revision=existing_results.dataset_revision,
task_name=existing_results.task_name,
mteb_version=existing_results.mteb_version,
scores=merged_scores,
evaluation_time=existing_results.evaluation_time
+ new_results.evaluation_time,
kg_co2_emissions=existing_results.kg_co2_emissions,
)

return merged_results

@staticmethod
def _merge_split_scores(
existing_scores: list[ScoresDict], new_scores: list[ScoresDict]
) -> list[ScoresDict]:
merged = {score["hf_subset"]: score for score in existing_scores}
for score in new_scores:
merged[score["hf_subset"]] = score
return list(merged.values())

def run(
self,
model: SentenceTransformer | Encoder,
@@ -379,15 +435,16 @@ def run(
original_tasks = (
self.tasks.copy()
) # save them in case we re-use the object (e.g. for reranking)
self.last_evaluated_splits = {}
while len(self.tasks) > 0:
task = self.tasks[0]
logger.info(
f"\n\n********************** Evaluating {task.metadata.name} **********************"
)

# skip evaluation if results folder exists and overwrite_results is False
if output_path:
save_path = output_path / f"{task.metadata.name}{task.save_suffix}.json"
existing_results = None
if save_path.exists() and not overwrite_results:
logger.info(
f"{task.metadata.name} results already exists. Loading results from disk. Set overwrite_results=True to overwrite."
@@ -396,21 +453,38 @@ def run(
evaluation_results.append(mteb_results)
del self.tasks[0] # empty memory
continue
try:

task_eval_splits = (
eval_splits if eval_splits is not None else task.eval_splits
)
missing_splits = self._get_missing_splits(
existing_results, task_eval_splits, task
)

# load data
logger.info(f"Loading dataset for {task.metadata_dict['name']}")
if not missing_splits and existing_results:
logger.info(
f"{task.metadata.name} results already exist. Loading results from disk."
)
evaluation_results.append(existing_results)
self.last_evaluated_splits[task.metadata.name] = [] # Add this line
Contributor:
Seems like an error?

del self.tasks[0]
continue

if missing_splits:
logger.info(
f"Running evaluation for missing splits: {missing_splits}"
)

try:
task.check_if_dataset_is_superseeded()
task.load_data(eval_splits=task_eval_splits, **kwargs)

# run evaluation
task_results = {}
evaluation_time = 0
kg_co2_emissions: int | None = 0 if co2_tracker else None
for split in task_eval_splits:

for split in missing_splits:
if co2_tracker:
try:
from codecarbon import EmissionsTracker
@@ -453,21 +527,22 @@ def run(
if verbosity >= 1:
logger.info(f"Scores: {results}")

mteb_task_result = TaskResult.from_task_results(
new_results = TaskResult.from_task_results(
task,
task_results,
evaluation_time=evaluation_time,
kg_co2_emissions=kg_co2_emissions,
)

# save results
if existing_results:
merged_results = self._merge_results(existing_results, new_results)
Comment on lines +537 to +538
Contributor:
probably worth merging using the MTEBResult object, instead of directly on the dict.

new_results = MTEBResults(...)
if existing_results:
  new_results.update(existing_results)
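
For illustration, the merge could instead live on the result object itself. The `update` method below is only a hypothetical sketch of that idea (it is not part of the current TaskResult API) and simply reuses the per-split/per-subset logic from `_merge_results` and `_merge_split_scores` above:

# Hypothetical sketch only -- an `update` method on TaskResult, not part of the current API.
def update(self, existing: TaskResult) -> None:
    """Fold scores from a previously saved result into this (newer) result."""
    for split, scores in existing.scores.items():
        if split in self.scores:
            # newer scores win per hf_subset, mirroring MTEB._merge_split_scores
            self.scores[split] = MTEB._merge_split_scores(scores, self.scores[split])
        else:
            self.scores[split] = scores
    self.evaluation_time += existing.evaluation_time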

else:
merged_results = new_results

if output_path:
with open(save_path, "w") as f_out:
json.dump(
mteb_task_result.to_dict(), f_out, indent=2, sort_keys=True
)
merged_results.to_disk(save_path)

evaluation_results.append(mteb_task_result)
evaluation_results.append(merged_results)

except Exception as e:
logger.error(
@@ -486,7 +561,6 @@
# empty memory
del self.tasks[0]

# restore original tasks
self.tasks = original_tasks
return evaluation_results

@@ -537,3 +611,11 @@ def _save_model_metadata(model_meta: ModelMeta, output_folder: Path) -> None:

with save_path.open("w") as f:
json.dump(model_meta.to_dict(), f)

def get_last_evaluated_splits(self):
Contributor:
Is this needed? (I would just add some logging messages instead)

"""Returns a dictionary of tasks and their evaluated splits from the most recent run.
Tasks with empty lists indicate that results already existed and no splits were evaluated.
"""
return deepcopy(
{task: list(splits) for task, splits in self.last_evaluated_splits.items()}
)
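
For context, a minimal usage sketch of the resume-missing-splits behaviour and the new get_last_evaluated_splits helper, mirroring the tests added below (the output folder and the printed result are illustrative):

from sentence_transformers import SentenceTransformer

import mteb
from mteb import MTEB

model = SentenceTransformer("all-MiniLM-L6-v2")
tasks = mteb.get_tasks(tasks=["NFCorpus"], languages=["eng"])

# First run: only the "train" split is evaluated and written to disk.
MTEB(tasks=tasks).run(model, eval_splits=["train"], output_folder="results")

# Second run against the same output folder: only the missing "test" split is
# evaluated, and its scores are merged into the existing result file.
tasks = mteb.get_tasks(tasks=["NFCorpus"], languages=["eng"])  # re-fetched, as in the tests below
evaluation = MTEB(tasks=tasks)
evaluation.run(model, eval_splits=["train", "test"], output_folder="results")

# Splits evaluated in the most recent run; an empty list means the results
# already existed on disk and nothing was re-run.
print(evaluation.get_last_evaluated_splits())  # illustrative output: {"NFCorpus": ["test"]}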
92 changes: 92 additions & 0 deletions tests/test_evaluation/test_split_evaluation.py
@@ -0,0 +1,92 @@
from __future__ import annotations

import pytest
from sentence_transformers import SentenceTransformer

import mteb
from mteb import MTEB


@pytest.fixture
def model():
return SentenceTransformer("all-MiniLM-L6-v2")
isaac-chung marked this conversation as resolved.


@pytest.fixture
def nfcorpus_tasks():
return mteb.get_tasks(tasks=["NFCorpus"], languages=["eng"])
isaac-chung marked this conversation as resolved.


@pytest.mark.skip(reason="WIP")
def test_all_splits_evaluated(model, nfcorpus_tasks, tmp_path):
evaluation = MTEB(tasks=nfcorpus_tasks)
evaluation.run(
model,
eval_splits=["train", "test"],
save_predictions=True,
output_folder=str(tmp_path / "testcase1"),
verbosity=2,
)
last_evaluated_splits = evaluation.get_last_evaluated_splits()
print(last_evaluated_splits)
assert "NFCorpus" in last_evaluated_splits
assert set(last_evaluated_splits["NFCorpus"]) == {"train", "test"}
assert len(last_evaluated_splits["NFCorpus"]) == 2
Comment on lines +23 to +34
Contributor:
We can simplify tests a bit (general across tests):

Suggested change
-evaluation.run(
-    model,
-    eval_splits=["train", "test"],
-    save_predictions=True,
-    output_folder=str(tmp_path / "testcase1"),
-    verbosity=2,
-)
-last_evaluated_splits = evaluation.get_last_evaluated_splits()
-print(last_evaluated_splits)
-assert "NFCorpus" in last_evaluated_splits
-assert set(last_evaluated_splits["NFCorpus"]) == {"train", "test"}
-assert len(last_evaluated_splits["NFCorpus"]) == 2
+result_obj = evaluation.run(
+    model,
+    eval_splits=["train", "test"],
+    save_predictions=True,
+    output_folder=str(tmp_path / "testcase1"),
+    verbosity=2,
+)
+# check splits here based on the returned object - no need for last_evaluated_splits

Contributor:
any reason why save_predictions is true?



@pytest.mark.skip(reason="WIP")
def test_one_missing_split(model, nfcorpus_tasks, tmp_path):
evaluation = MTEB(tasks=nfcorpus_tasks)
evaluation.run(
model,
eval_splits=["train"],
save_predictions=True,
output_folder=str(tmp_path / "testcase2"),
verbosity=2,
)

# Get model and tasks again
model = SentenceTransformer("all-MiniLM-L6-v2")
nfcorpus_tasks = mteb.get_tasks(tasks=["NFCorpus"], languages=["eng"])

evaluation_2 = MTEB(tasks=nfcorpus_tasks)
evaluation_2.run(
model,
eval_splits=["train", "test"],
save_predictions=True,
output_folder=str(tmp_path / "testcase2"),
verbosity=2,
)

last_evaluated_splits = evaluation_2.get_last_evaluated_splits()

print(last_evaluated_splits)
assert "NFCorpus" in last_evaluated_splits
assert set(last_evaluated_splits["NFCorpus"]) == {"test"}
assert len(last_evaluated_splits["NFCorpus"]) == 1


@pytest.mark.skip(reason="WIP")
def test_no_missing_splits(model, nfcorpus_tasks, tmp_path):
evaluation_1 = MTEB(tasks=nfcorpus_tasks)
evaluation_1.run(
model,
eval_splits=["train", "test"],
save_predictions=True,
output_folder=str(tmp_path / "testcase3"),
verbosity=2,
)

evaluation_2 = MTEB(tasks=nfcorpus_tasks)
evaluation_2.run(
model,
eval_splits=["train", "test"],
save_predictions=True,
output_folder=str(tmp_path / "testcase3"),
verbosity=2,
)

last_evaluated_splits = evaluation_2.get_last_evaluated_splits()

assert "NFCorpus" in last_evaluated_splits
assert len(last_evaluated_splits["NFCorpus"]) == 0