feat: support classification datasets
PaulHax committed Nov 22, 2024
1 parent be817b9 commit 5428c1c
Showing 2 changed files with 56 additions and 12 deletions.
13 changes: 12 additions & 1 deletion src/nrtk_explorer/app/transforms.py
@@ -337,8 +337,19 @@ def compute_predictions_original_images(self, dataset_ids):
         ground_truth_annotations = self.ground_truth_annotations.get_annotations(
             dataset_ids
         ).values()
+
+        def ensure_bbox(annotation, image):
+            if "bbox" not in annotation:
+                annotation["bbox"] = [0, 0, image.width, image.height]
+            return annotation
+
+        ground_truth_with_bbox = [
+            [ensure_bbox(annotation, image) for annotation in annotations]
+            for annotations, image in zip(ground_truth_annotations, image_id_to_image.values())
+        ]
+
         ground_truth_predictions = convert_from_ground_truth_to_second_arg(
-            ground_truth_annotations, self.context.dataset
+            ground_truth_with_bbox, self.context.dataset
         )
         scores = compute_score(
             dataset_ids,
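Context for the hunk above: classification annotations carry no bounding box, so ensure_bbox backfills one spanning the whole image before the ground truth is handed to the detection-oriented scorer. A minimal runnable sketch of that behavior, with a hypothetical Image dataclass standing in for anything that exposes width/height (e.g. PIL.Image.Image):

# Sketch only, not part of the commit: demonstrates the ensure_bbox backfill.
from dataclasses import dataclass

@dataclass
class Image:  # hypothetical stand-in for PIL.Image.Image
    width: int
    height: int

def ensure_bbox(annotation, image):
    # Box-less (classification) annotations get a full-image bbox so they
    # can flow through detection-style score computation unchanged.
    if "bbox" not in annotation:
        annotation["bbox"] = [0, 0, image.width, image.height]
    return annotation

print(ensure_bbox({"category_id": 3}, Image(width=640, height=480)))
# -> {'category_id': 3, 'bbox': [0, 0, 640, 480]}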
55 changes: 44 additions & 11 deletions src/nrtk_explorer/library/dataset.py
@@ -90,24 +90,32 @@ def expand_hugging_face_datasets(dataset_identifiers: SequenceType[str], streami
     return expanded_identifiers


+def find_column_name(features, column_names):
+    return next((key for key in column_names if key in features), None)
+
+
 class HuggingFaceDataset(BaseDataset):
     """Interface for Hugging Face datasets with a similar API to JsonDataset."""

     def __init__(self, identifier: str):
+        self.imgs: dict[str, dict] = {}
+        self.anns: dict[str, dict] = {}
+        self.cats: dict[str, dict] = {}
+        self._id_to_row_idx: dict[str, int] = {}
+
         repo, config, split, streaming = identifier.split("@")
         self._streaming = streaming == "streaming"
         self._dataset = load_dataset(repo, config, split=split, streaming=self._streaming)
-        # transforms and base64 encoding require RGB mode
-        self._dataset.cast_column("image", DatasetImage(mode="RGB"))
         if self._streaming:
             self._dataset = self._dataset.take(HF_ROWS_TO_TAKE_STREAMING)
-        self.imgs: dict[str, dict] = {}
-        self.anns: dict[str, dict] = {}
-        self.cats: dict[str, dict] = {}
-        self._id_to_row_idx: dict[str, int] = {}
         self._load_data()

     def _load_data(self):
+        image_key = find_column_name(self._dataset.features, ["image", "img"])
+        self._image_key = image_key
+        # transforms and base64 encoding require RGB mode
+        self._dataset.cast_column(image_key, DatasetImage(mode="RGB"))
+
         counter = 0

         def make_id():
Expand Down Expand Up @@ -136,21 +144,32 @@ def extract_labels(feature):
         if labels:
             self.cats = {i: {"id": i, "name": str(name)} for i, name in enumerate(labels)}

+        objects_key = find_column_name(self._dataset.features, ["objects"])
+
+        classifications_key = find_column_name(
+            self._dataset.features,
+            [
+                "labels",
+                "label",
+                "classifications",
+            ],
+        )
+
         new_cats = set()
         # speed initial metadata process by not loading images if we can random access rows (not streaming)
         maybe_no_image = (
-            self._dataset if self._streaming else self._dataset.remove_columns(["image"])
+            self._dataset if self._streaming else self._dataset.remove_columns([image_key])
         )
         for idx, example in enumerate(maybe_no_image):
             id = example.get("id", example.get("image_id", idx))
             if self._streaming:
-                self.imgs[id] = {"id": id, "image": example["image"]}
+                self.imgs[id] = {"id": id, "image": example[image_key]}
             else:
                 self.imgs[id] = {"id": id}
             self._id_to_row_idx[id] = idx

-            if "objects" in example:
-                objects = example["objects"]
+            if objects_key:
+                objects = example[objects_key]
                 if isinstance(objects, list):
                     # Convert list of dicts to dict of lists. We want columns, not rows.
                     cat_keys = ["category", "category_id", "label"]
@@ -175,6 +194,20 @@ def extract_labels(feature):
"bbox": bbox,
}

if classifications_key:
classes = example[classifications_key]
if not isinstance(classes, list):
classes = [classes]
for cat_id in classes:
if cat_id not in self.cats:
new_cats.add(cat_id)
ann_id = make_id()
self.anns[ann_id] = {
"id": ann_id,
"image_id": id,
"category_id": cat_id,
}

if new_cats:
max_existing_id = max(self.cats.keys(), default=0)
for new_cat in new_cats:
@@ -189,7 +222,7 @@ def get_image(self, id):
             return self.imgs[id]["image"]
         else:
             row_idx = self._id_to_row_idx[id]
-            return self._dataset[row_idx]["image"]
+            return self._dataset[row_idx][self._image_key]


 @lru_cache
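Two small usage sketches may help orient readers of the diff above. First, the dataset identifier parsed in __init__ is four "@"-separated fields, and find_column_name returns the first candidate column present in the dataset's features. The repo/config/split values below are illustrative assumptions, not taken from the commit:

# Sketch: identifier format and column detection (values are hypothetical).
identifier = "some-org/some-dataset@default@train@streaming"
repo, config, split, streaming = identifier.split("@")
assert streaming == "streaming"  # fourth field toggles streaming mode

def find_column_name(features, column_names):
    # First candidate name present in the features mapping wins.
    return next((key for key in column_names if key in features), None)

features = {"img": None, "label": None}  # shape of an image-classification dataset
assert find_column_name(features, ["image", "img"]) == "img"
assert find_column_name(features, ["labels", "label", "classifications"]) == "label"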

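Second, the classification path in _load_data reduces to: each row's label becomes one synthetic annotation, and unseen label ids are collected so placeholder categories can be added after the loop. A condensed sketch under the same assumptions (make_id and the rows are simplified stand-ins):

# Sketch: one annotation per classification label, mirroring _load_data.
from itertools import count

_counter = count()

def make_id():  # simplified stand-in for the commit's id generator
    return f"ann-{next(_counter)}"

cats = {0: {"id": 0, "name": "cat"}}  # categories known from dataset features
anns, new_cats = {}, set()
for image_id, example in enumerate([{"label": 0}, {"label": 7}]):
    classes = example["label"]
    if not isinstance(classes, list):
        classes = [classes]  # single-label rows become one-element lists
    for cat_id in classes:
        if cat_id not in cats:
            new_cats.add(cat_id)  # unseen ids get category entries later
        ann_id = make_id()
        anns[ann_id] = {"id": ann_id, "image_id": image_id, "category_id": cat_id}

print(new_cats)  # -> {7}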