Skip to content

Commit

Permalink
refactor(coco_utils): make normalized annotation struct
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Nov 25, 2024
1 parent 620567e commit 2431259
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 186 deletions.
32 changes: 3 additions & 29 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,15 +274,7 @@ async def _update_transformed_images(self, dataset_ids):
# depends on original images predictions
if self.state.predictions_original_images_enabled:
scores = compute_score(
self.context.dataset,
{
"annotations": self.predictions_original_images,
"type": "predictions",
},
{
"annotations": annotations,
"type": "predictions",
},
self.context.dataset, self.predictions_original_images, annotations
)
for id, score in scores:
update_image_meta(
Expand All @@ -292,17 +284,7 @@ async def _update_transformed_images(self, dataset_ids):
)

ground_truth_annotations = self.ground_truth_annotations.get_annotations(dataset_ids)
scores = compute_score(
self.context.dataset,
{
"annotations": ground_truth_annotations,
"type": "truth",
},
{
"annotations": annotations,
"type": "predictions",
},
)
scores = compute_score(self.context.dataset, ground_truth_annotations, annotations)
for id, score in scores:
update_image_meta(
self.state, id, {"ground_truth_to_transformed_detection_score": score}
Expand Down Expand Up @@ -334,15 +316,7 @@ def compute_predictions_original_images(self, dataset_ids):
ground_truth_annotations = self.ground_truth_annotations.get_annotations(dataset_ids)

scores = compute_score(
self.context.dataset,
{
"annotations": ground_truth_annotations,
"type": "truth",
},
{
"annotations": self.predictions_original_images,
"type": "predictions",
},
self.context.dataset, ground_truth_annotations, self.predictions_original_images
)
for dataset_id, score in scores:
update_image_meta(
Expand Down
220 changes: 81 additions & 139 deletions src/nrtk_explorer/library/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,100 +4,9 @@
ClassAgnosticPixelwiseIoUScorer,
)

# This module contains functions to convert ground truth annotations and predictions to COCOScorer format
# COCOScorer is a library that computes the COCO metrics for object detection tasks.


def convert_from_ground_truth_to_first_arg(dataset_annotations):
"""Convert ground truth annotations to COCOScorer format"""
annotations = list()
for dataset_image_annotations in dataset_annotations:
image_annotations = list()
for annotation in dataset_image_annotations:
image_annotations.append(
(
AxisAlignedBoundingBox(
annotation["bbox"][0:2],
[
annotation["bbox"][0] + annotation["bbox"][2],
annotation["bbox"][1] + annotation["bbox"][3],
],
),
{
"category_id": annotation["category_id"],
"image_id": annotation["image_id"],
},
)
)
annotations.append(image_annotations)
return annotations


def convert_from_ground_truth_to_second_arg(dataset_annotations, dataset):
"""Convert ground truth annotations to COCOScorer format"""
categories = {cat["id"]: cat["name"] for cat in dataset.cats.values()}
annotations = list()
for dataset_image_annotations in dataset_annotations:
image_annotations = list()
for annotation in dataset_image_annotations:
image_annotations.append(
(
AxisAlignedBoundingBox(
annotation["bbox"][0:2],
[
annotation["bbox"][0] + annotation["bbox"][2],
annotation["bbox"][1] + annotation["bbox"][3],
],
),
{categories[annotation["category_id"]]: 1},
)
)
annotations.append(image_annotations)
return annotations


def convert_from_predictions_to_first_arg(predictions, dataset, ids):
"""Convert predictions to COCOScorer format"""
predictions = convert_from_predictions_to_second_arg(predictions)
categories = {cat["name"]: cat["id"] for cat in dataset.cats.values()}
real_ids = [id_.split("_")[-1] for id_ in ids]

for id_, img_predictions in zip(real_ids, predictions):
for prediction in img_predictions:
class_name = list(prediction[1].keys())[0]

prediction[1].clear()
prediction[1]["image_id"] = int(id_)
if class_name in categories:
prediction[1]["category_id"] = categories[class_name]
else:
prediction[1]["category_id"] = 0

return predictions


def convert_from_predictions_to_second_arg(predictions):
"""Convert predictions to COCOScorer format"""
annotations_predictions = list()
for img_predictions in predictions:
current_annotations = list()
for prediction in img_predictions:
if prediction:
current_annotations.append(
(
AxisAlignedBoundingBox(
[prediction["box"]["xmin"], prediction["box"]["ymin"]],
[prediction["box"]["xmax"], prediction["box"]["ymax"]],
),
{prediction["label"]: prediction["score"]},
)
)

annotations_predictions.append(current_annotations)

return annotations_predictions


# Example usage:
# def is_odd(x):
# return x % 2 == 1
Expand All @@ -115,40 +24,61 @@ def partition(pred, iterable):


def image_id_to_dataset_id(image_id):
return image_id.split("_")[-1]
return str(image_id).split("_")[-1]


def keys_to_dataset_ids(image_dict):
"""Convert keys to dataset ids."""
return {image_id_to_dataset_id(key): value for key, value in image_dict.items()}


def make_get_cat_ids(dataset):
"""Get category ids from annotations."""
label_to_id = {cat["name"]: cat["id"] for cat in dataset.cats.values()}
def get_cat_id(dataset, annotation):
if "category_id" in annotation:
return annotation["category_id"]
cat = dataset.name_to_cat.get(annotation["label"], None)
if not cat:
return None
return cat["id"]


def get_cat_id(annotation):
if "category_id" in annotation:
return annotation["category_id"]
return label_to_id.get(annotation["label"], None)
def get_label(dataset, annotation):
if annotation["category_id"] is not None:
return dataset.cats[annotation["category_id"]]["name"]
return annotation["label"]

return get_cat_id

def get_score(annotation):
return 1.0 if "score" not in annotation else annotation["score"]

def make_ensure_cat_ids(dataset):
"""Ensure category_id exists in annotations."""
get_cat_id = make_get_cat_ids(dataset)

def ensure_cat_id(annotations_set):
return [
[{**annotation, "category_id": get_cat_id(annotation)} for annotation in annotations]
for annotations in annotations_set
]
def normalize_annotation(dataset, image_id, annotation):
"""Normalize a single annotation."""
category_id = get_cat_id(dataset, annotation)
annotation = {
**annotation,
"category_id": category_id,
"image_id": image_id,
"score": get_score(annotation),
"bbox": annotation.get("bbox", annotation.get("box", None)),
}
annotation["label"] = get_label(dataset, annotation)
annotation[annotation["label"]] = annotation["score"]
return annotation

return ensure_cat_id

def normalize_annotations(dataset, image_id, annotations):
"""Ensure category_id, bbox, label, score, [label]:score."""
return [normalize_annotation(dataset, image_id, annotation) for annotation in annotations]

def calculate_category_match_score(actual: List[Dict], predicted: List[Dict]) -> float:

def predictions_to_annotations(dataset, ids, predictions):
return [
normalize_annotations(dataset, id, image_predictions)
for id, image_predictions in zip(ids, predictions)
]


def get_category_similarity_score(actual: List[Dict], predicted: List[Dict]) -> float:
"""
Calculate matching score between actual and predicted category annotations.
Expand All @@ -166,40 +96,56 @@ def calculate_category_match_score(actual: List[Dict], predicted: List[Dict]) ->
1 for cat_id in predicted_cat_ids if cat_id in actual_cat_ids and cat_id is not None
)

total_cat_ids = len(actual_cat_ids) + len(predicted_cat_ids)
total_cat_ids = len(actual_cat_ids.union(predicted_cat_ids))
return matching_cat_ids / total_cat_ids if total_cat_ids > 0 else 0.0


def calculate_category_matching_scores(
dataset: Dict, actual: List[Dict], predicted: List[Dict], ids: List[str]
def compute_category_similarity_scores(
actual: List[Dict], predicted: List[Dict], ids: List[str]
) -> List[Tuple[str, float]]:
"""
Calculate matching scores between actual and predicted category annotations.
Args:
dataset: The COCO dataset dictionary
actual: List of actual annotations
predicted: List of predicted annotations
ids: List of image IDs
Returns:
List of tuples containing (image_id, matching_score)
"""
ensure_cat_id = make_ensure_cat_ids(dataset)
actual_with_cat_ids = ensure_cat_id(actual)
predicted_with_cat_ids = ensure_cat_id(predicted)

return [
(id, calculate_category_match_score(actual, predicted))
for actual, predicted, id in zip(actual_with_cat_ids, predicted_with_cat_ids, ids)
(id, get_category_similarity_score(actual, predicted))
for actual, predicted, id in zip(actual, predicted, ids)
]


def compute_score(dataset, actual_info, predicted_info):
def get_aabb(annotation):
bbox = annotation["bbox"]
if "xmin" in bbox:
return AxisAlignedBoundingBox(
[bbox["xmin"], bbox["ymin"]],
[bbox["xmax"], bbox["ymax"]],
)
return AxisAlignedBoundingBox(
[bbox[0], bbox[1]],
[bbox[0] + bbox[2], bbox[1] + bbox[3]],
)


def to_nrtk_score_shape(annotation):
return (
get_aabb(annotation),
annotation,
)


def compute_score(dataset, actual, predicted):
"""Compute score for image ids."""

actual = keys_to_dataset_ids(actual_info["annotations"])
predicted = keys_to_dataset_ids(predicted_info["annotations"])
actual = keys_to_dataset_ids(actual)
predicted = keys_to_dataset_ids(predicted)
ids = list(actual.keys())

pairs = [(actual[key], predicted[key], key) for key in ids]
Expand All @@ -226,31 +172,27 @@ def is_empty(prediction_pair):

actual, predicted, ids = zip(*has_annotations)

actual = predictions_to_annotations(dataset, ids, actual)
predicted = predictions_to_annotations(dataset, ids, predicted)

all_annotations_have_bbox = all(
"bbox" in annotation
annotation["bbox"] is not None
for annotation_list in actual + predicted
for annotation in annotation_list
)

if not all_annotations_have_bbox:
s = calculate_category_matching_scores(dataset, actual, predicted, ids)
s = compute_category_similarity_scores(actual, predicted, ids)
return scores + s

if actual_info["type"] == "predictions":
actual_converted = convert_from_predictions_to_first_arg(
actual,
dataset,
ids,
)
elif actual_info["type"] == "truth":
actual_converted = convert_from_ground_truth_to_first_arg(actual)

if predicted_info["type"] == "predictions":
predicted_converted = convert_from_predictions_to_second_arg(
predicted,
)
elif predicted_info["type"] == "truth":
predicted_converted = convert_from_ground_truth_to_second_arg(predicted, dataset)
actual_converted = [
[to_nrtk_score_shape(annotation) for annotation in img_annotations]
for img_annotations in actual
]
predicted_converted = [
[to_nrtk_score_shape(annotation) for annotation in img_annotations]
for img_annotations in predicted
]

score_output = ClassAgnosticPixelwiseIoUScorer().score(actual_converted, predicted_converted)
for id, score in zip(ids, score_output):
Expand Down
12 changes: 10 additions & 2 deletions src/nrtk_explorer/library/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ def get_image(self, id: int):
pass


class JsonDataset(BaseDataset):
class CategoryIndex:
def build_cat_index(self):
# follows kwcoco.name_to_cat
self.name_to_cat = {cat["name"]: cat for cat in self.cats.values()}


class JsonDataset(BaseDataset, CategoryIndex):
"""JSON-based COCO datasets."""

def __init__(self, path: str):
Expand All @@ -40,6 +46,7 @@ def __init__(self, path: str):
self.cats = {cat["id"]: cat for cat in self.data["categories"]}
self.anns = {ann["id"]: ann for ann in self.data["annotations"]}
self.imgs = {img["id"]: img for img in self.data["images"]}
self.build_cat_index()

def _get_image_fpath(self, selected_id: int):
dataset_dir = Path(self.fpath).parent
Expand Down Expand Up @@ -94,7 +101,7 @@ def find_column_name(features, column_names):
return next((key for key in column_names if key in features), None)


class HuggingFaceDataset(BaseDataset):
class HuggingFaceDataset(BaseDataset, CategoryIndex):
"""Interface for Hugging Face datasets with a similar API to JsonDataset."""

def __init__(self, identifier: str):
Expand All @@ -109,6 +116,7 @@ def __init__(self, identifier: str):
if self._streaming:
self._dataset = self._dataset.take(HF_ROWS_TO_TAKE_STREAMING)
self._load_data()
self.build_cat_index()

def _load_data(self):
image_key = find_column_name(self._dataset.features, ["image", "img"])
Expand Down
Loading

0 comments on commit 2431259

Please sign in to comment.