initial commit for cli #75

Open · wants to merge 2 commits into `main`
67 changes: 67 additions & 0 deletions kornia/cli.py
@@ -0,0 +1,67 @@
"""Command-line interface for Kornia."""

import argparse
import logging
from pathlib import Path

from kornia.contrib.face_detection import YuFaceDetectNet

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


def export_onnx_model_resolver(model_name: str, output_path: Path) -> None:
"""Resolve the configuration for exporting a model to ONNX format."""
if "yunet" in model_name:
onnx_model_path = output_path / "yunet.onnx"
res = YuFaceDetectNet("test", pretrained=True).to_onnx(
image_shape={"channels": 3, "height": 320, "width": 320},
onnx_model_path=onnx_model_path,
input_names=["images"],
output_names=["loc", "conf", "iou"],
dynamic_axes={"images": {0: "B"}},
)
if res:
logger.info("Model exported to %s", onnx_model_path)
else:
raise ValueError(f"Model {model_name} not supported")


def main() -> None:
"""Main function for the Kornia CLI."""
parser = argparse.ArgumentParser(description="Kornia CLI")

subparsers = parser.add_subparsers(dest="command")

# Create a subparser for the 'export' command
export_parser = subparsers.add_parser("export", help="Export a model to different formats")
export_parser.add_argument("--model", type=str, required=True, help="Model name to export")
export_parser.add_argument(
"--format",
type=str,
required=True,
choices=["onnx"],
help="Format to export the model",
)
export_parser.add_argument(
"--output-path",
type=Path,
required=True,
help="Path to save the exported model",
)

args = parser.parse_args()

# Handle the 'export' command
if args.command == "export":
if args.format == "onnx":
logger.info("Exporting model %s to ONNX format", args.model)
export_onnx_model_resolver(args.model, args.output_path)
else:
logger.error("Format %s not supported", args.format)
else:
parser.print_help()


if __name__ == "__main__":
main()
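
A quick way to smoke-test the new command is sketched below. It drives `main()` through `sys.argv` because this diff does not add any console-script packaging metadata; the `"kornia"` argv[0] and the output directory are illustrative assumptions.

```python
# Hypothetical smoke test for the CLI above; it calls main() directly
# since this PR does not wire up a console entry point.
import sys
from pathlib import Path

from kornia.cli import main

out_dir = Path("models")
out_dir.mkdir(parents=True, exist_ok=True)  # the export writes <output-path>/yunet.onnx

sys.argv = ["kornia", "export", "--model", "yunet", "--format", "onnx", "--output-path", str(out_dir)]
main()
```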
98 changes: 75 additions & 23 deletions kornia/contrib/face_detection.py
@@ -1,12 +1,15 @@
# based on: https://github.com/ShiqiYu/libfacedetection.train/blob/74f3aa77c63234dd954d21286e9a60703b8d0868/tasks/task1/yufacedetectnet.py # noqa
+from __future__ import annotations

import math
from enum import Enum
-from typing import Dict, List, Optional, Tuple
+from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

+from kornia.core.module import KorniaModule
from kornia.geometry.bbox import nms as nms_kornia
from kornia.utils.helpers import map_location_to_cpu

@@ -41,7 +44,7 @@ def __init__(self, data: torch.Tensor) -> None:
raise ValueError(f"Result must comes as vector of size(14). Got: {data.shape}.")
self._data = data

-    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> "FaceDetectorResult":
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> FaceDetectorResult:
"""Like :func:`torch.nn.Module.to()` method."""
self._data = self._data.to(device=device, dtype=dtype)
return self
@@ -147,7 +150,11 @@ class FaceDetector(nn.Module):
"""

def __init__(
-        self, top_k: int = 5000, confidence_threshold: float = 0.3, nms_threshold: float = 0.3, keep_top_k: int = 750
+        self,
+        top_k: int = 5000,
+        confidence_threshold: float = 0.3,
+        nms_threshold: float = 0.3,
+        keep_top_k: int = 750,
) -> None:
super().__init__()
self.top_k = top_k
@@ -171,19 +178,34 @@ def __init__(
def preprocess(self, image: torch.Tensor) -> torch.Tensor:
return image

-    def postprocess(self, data: Dict[str, torch.Tensor], height: int, width: int) -> List[torch.Tensor]:
+    def postprocess(self, data: dict[str, torch.Tensor], height: int, width: int) -> list[torch.Tensor]:
loc, conf, iou = data["loc"], data["conf"], data["iou"]

scale = torch.tensor(
-            [width, height, width, height, width, height, width, height, width, height, width, height, width, height],
+            [
+                width,
+                height,
+                width,
+                height,
+                width,
+                height,
+                width,
+                height,
+                width,
+                height,
+                width,
+                height,
+                width,
+                height,
+            ],
device=loc.device,
dtype=loc.dtype,
) # 14

priors = _PriorBox(self.min_sizes, self.steps, self.clip, image_size=(height, width))
priors = priors.to(loc.device, loc.dtype)

-        batched_dets: List[torch.Tensor] = []
+        batched_dets: list[torch.Tensor] = []
for batch_elem in range(loc.shape[0]):
boxes = _decode(loc[batch_elem], priors(), self.variance) # Nx14
boxes = boxes * scale
@@ -211,7 +233,7 @@ def postprocess(self, data: Dict[str, torch.Tensor], height: int, width: int) ->
batched_dets.append(dets[: self.keep_top_k])
return batched_dets

-    def forward(self, image: torch.Tensor) -> List[torch.Tensor]:
+    def forward(self, image: torch.Tensor) -> list[torch.Tensor]:
r"""Detect faces in a given batch of images.

Args:
@@ -232,7 +254,10 @@ class ConvDPUnit(nn.Sequential):
def __init__(self, in_channels: int, out_channels: int, withBNRelu: bool = True) -> None:
super().__init__()
self.add_module("conv1", nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=True, groups=1))
self.add_module("conv2", nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels))
self.add_module(
"conv2",
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels),
)
if withBNRelu:
self.add_module("bn", nn.BatchNorm2d(out_channels))
self.add_module("relu", nn.ReLU(inplace=True))
@@ -254,7 +279,8 @@ def __init__(self, in_channels: int, out_channels: int, withBNRelu: bool = True)
self.add_module("conv2", ConvDPUnit(in_channels, out_channels, withBNRelu))


-class YuFaceDetectNet(nn.Module):
+# class YuFaceDetectNet(nn.Module):
+class YuFaceDetectNet(KorniaModule):
def __init__(self, phase: str, pretrained: bool) -> None:
super().__init__()
self.phase = phase
@@ -293,7 +319,7 @@ def __init__(self, phase: str, pretrained: bool) -> None:
self.load_state_dict(pretrained_dict, strict=True)
self.eval()

-    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
detection_sources, head_list = [], []

x = self.model0(x)
@@ -339,7 +365,7 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:


# Adapted from https://github.com/Hakuyume/chainer-ssd
-def _decode(loc: torch.Tensor, priors: torch.Tensor, variances: List[float]) -> torch.Tensor:
+def _decode(loc: torch.Tensor, priors: torch.Tensor, variances: list[float]) -> torch.Tensor:
"""Decode locations from predictions using priors to undo the encoding we did for offset regression at train
time.

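
Reviewer note: the body of `_decode` is collapsed in this view. For context, it applies the usual SSD offset decoding (per the chainer-ssd attribution above); this is a sketch of the box part, not the exact collapsed implementation. With prior center $(p_{cx}, p_{cy})$, prior size $(p_w, p_h)$, predicted offsets $(l_x, l_y, l_w, l_h)$, and variances $(v_0, v_1)$, the landmark coordinates following the same center-offset rule:

$$
c_x = p_{cx} + l_x v_0 p_w, \qquad c_y = p_{cy} + l_y v_0 p_h, \qquad w = p_w e^{l_w v_1}, \qquad h = p_h e^{l_h v_1}
$$
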
@@ -369,7 +395,13 @@ def _decode(loc: torch.Tensor, priors: torch.Tensor, variances: List[float]) ->


class _PriorBox:
-    def __init__(self, min_sizes: List[List[int]], steps: List[int], clip: bool, image_size: Tuple[int, int]) -> None:
+    def __init__(
+        self,
+        min_sizes: list[list[int]],
+        steps: list[int],
+        clip: bool,
+        image_size: tuple[int, int],
+    ) -> None:
self.min_sizes = min_sizes
self.steps = steps
self.clip = clip
Expand All @@ -382,23 +414,43 @@ def __init__(self, min_sizes: List[List[int]], steps: List[int], clip: bool, ima
if self.steps[i] != math.pow(2, (i + 3)):
raise ValueError("steps must be [8,16,32,64]")

-        self.feature_map_2th = [int(int((self.image_size[0] + 1) / 2) / 2), int(int((self.image_size[1] + 1) / 2) / 2)]
-        self.feature_map_3th = [int(self.feature_map_2th[0] / 2), int(self.feature_map_2th[1] / 2)]
-        self.feature_map_4th = [int(self.feature_map_3th[0] / 2), int(self.feature_map_3th[1] / 2)]
-        self.feature_map_5th = [int(self.feature_map_4th[0] / 2), int(self.feature_map_4th[1] / 2)]
-        self.feature_map_6th = [int(self.feature_map_5th[0] / 2), int(self.feature_map_5th[1] / 2)]
-
-        self.feature_maps = [self.feature_map_3th, self.feature_map_4th, self.feature_map_5th, self.feature_map_6th]
-
-    def to(self, device: torch.device, dtype: torch.dtype) -> "_PriorBox":
+        self.feature_map_2th = [
+            int(int((self.image_size[0] + 1) / 2) / 2),
+            int(int((self.image_size[1] + 1) / 2) / 2),
+        ]
+        self.feature_map_3th = [
+            int(self.feature_map_2th[0] / 2),
+            int(self.feature_map_2th[1] / 2),
+        ]
+        self.feature_map_4th = [
+            int(self.feature_map_3th[0] / 2),
+            int(self.feature_map_3th[1] / 2),
+        ]
+        self.feature_map_5th = [
+            int(self.feature_map_4th[0] / 2),
+            int(self.feature_map_4th[1] / 2),
+        ]
+        self.feature_map_6th = [
+            int(self.feature_map_5th[0] / 2),
+            int(self.feature_map_5th[1] / 2),
+        ]
+
+        self.feature_maps = [
+            self.feature_map_3th,
+            self.feature_map_4th,
+            self.feature_map_5th,
+            self.feature_map_6th,
+        ]
+
+    def to(self, device: torch.device, dtype: torch.dtype) -> _PriorBox:
self.device = device
self.dtype = dtype
return self

def __call__(self) -> torch.Tensor:
-        anchors: List[float] = []
+        anchors: list[float] = []
for k, f in enumerate(self.feature_maps):
-            min_sizes: List[int] = self.min_sizes[k]
+            min_sizes: list[int] = self.min_sizes[k]
            # NOTE: the nested loop is to make torchscript happy
for i in range(f[0]):
for j in range(f[1]):
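
As a sanity check on the reformatted `_PriorBox` block above, the feature-map sizes for the 320×320 input used by the CLI export work out as follows (a worked example, not part of the diff):

```python
# Worked example of the _PriorBox feature-map arithmetic for a 320x320 input.
image_size = (320, 320)

fm_2 = [int(int((image_size[0] + 1) / 2) / 2), int(int((image_size[1] + 1) / 2) / 2)]  # [80, 80]
fm_3 = [fm_2[0] // 2, fm_2[1] // 2]  # [40, 40] -> stride 8
fm_4 = [fm_3[0] // 2, fm_3[1] // 2]  # [20, 20] -> stride 16
fm_5 = [fm_4[0] // 2, fm_4[1] // 2]  # [10, 10] -> stride 32
fm_6 = [fm_5[0] // 2, fm_5[1] // 2]  # [5, 5]   -> stride 64

# Priors are generated on the 3rd..6th maps, matching the required steps [8, 16, 32, 64].
print(fm_3, fm_4, fm_5, fm_6)
```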
64 changes: 62 additions & 2 deletions kornia/core/module.py
@@ -1,8 +1,13 @@
"""Module definition for Kornia."""

import datetime
import math
import os
from functools import wraps
-from typing import Any, Callable, List, Optional, Tuple, Union
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import torch

import kornia

@@ -11,6 +16,59 @@
from .external import numpy as np


+class KorniaModule(Module):
+    """Base class for all Kornia modules.
+
+    This class extends the PyTorch `Module` class and provides additional functionalities
+    to handle input and output types, and end-to-end visualization.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+    def to_onnx(
+        self,
+        image_shape: Dict[str, int],
+        onnx_model_path: Path,
+        input_names: List[str],
+        output_names: List[str],
+        dynamic_axes: Dict[str, Dict[int, str]],
+    ) -> bool:
+        """Export the model to ONNX format.
+
+        Args:
+            image_shape: The shape of the input image. It must contain the number of channels, height, and width.
+            onnx_model_path: The path to save the ONNX model.
+            input_names: The names of the input tensors.
+            output_names: The names of the output tensors.
+            dynamic_axes: The dynamic axes of the model.
+        """
+        if "channels" not in image_shape:
+            raise ValueError("The image shape must contain the number of channels.")
+
+        if "height" not in image_shape:
+            raise ValueError("The image shape must contain the height.")
+
+        if "width" not in image_shape:
+            raise ValueError("The image shape must contain the width.")
+
+        input_image = torch.rand(1, image_shape["channels"], image_shape["height"], image_shape["width"])
+
+        torch.onnx.export(
+            self,
+            input_image,
+            onnx_model_path,
+            input_names=input_names,
+            output_names=output_names,
+            dynamic_axes=dynamic_axes,
+        )
+
+        if not onnx_model_path.exists():
+            return False
+
+        return True


class ImageModuleMixIn:
"""A MixIn that handles image-based operations.

@@ -21,7 +79,9 @@ class ImageModuleMixIn:
_output_image: Any

def convert_input_output(
-        self, input_names_to_handle: Optional[List[Any]] = None, output_type: str = "tensor"
+        self,
+        input_names_to_handle: Optional[List[Any]] = None,
+        output_type: str = "tensor",
) -> Callable[[Any], Any]:
"""Decorator to convert input and output types for a function.

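
Putting the two halves of this PR together: calling the new `to_onnx` helper directly (outside the CLI) looks like the sketch below. The arguments mirror `export_onnx_model_resolver` in `kornia/cli.py` above; only the local output filename is chosen here for illustration.

```python
# Direct use of the KorniaModule.to_onnx API added in this PR, mirroring
# the CLI's yunet resolver.
from pathlib import Path

from kornia.contrib.face_detection import YuFaceDetectNet

model = YuFaceDetectNet("test", pretrained=True)  # "test" phase builds the inference graph
ok = model.to_onnx(
    image_shape={"channels": 3, "height": 320, "width": 320},  # all three keys are required
    onnx_model_path=Path("yunet.onnx"),
    input_names=["images"],
    output_names=["loc", "conf", "iou"],
    dynamic_axes={"images": {0: "B"}},  # keep the batch dimension dynamic
)
assert ok  # to_onnx returns False if the ONNX file was not written
```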