From bad9106668b321249731747b56f1d9254b88b9eb Mon Sep 17 00:00:00 2001 From: Alessandro Genova Date: Fri, 15 Nov 2024 14:12:18 -0500 Subject: [PATCH] feat(transforms): add ability to add multiple transforms --- src/nrtk_explorer/app/parameters.py | 128 ++++++++++++------ src/nrtk_explorer/app/transforms.py | 31 +++-- src/nrtk_explorer/library/nrtk_transforms.py | 3 +- src/nrtk_explorer/library/transforms.py | 48 ++++++- src/nrtk_explorer/library/yaml_transforms.py | 57 ++++++-- src/nrtk_explorer/widgets/nrtk_explorer.py | 18 +++ tests/test_transforms.py | 4 +- vue-components/src/components/ParamWidget.vue | 18 +-- .../src/components/ParamsWidget.vue | 12 +- .../src/components/TransformWidget.vue | 55 ++++++++ .../src/components/TransformsWidget.vue | 34 +++++ vue-components/src/components/index.js | 2 + vue-components/src/types/index.ts | 7 + 13 files changed, 317 insertions(+), 100 deletions(-) create mode 100644 vue-components/src/components/TransformWidget.vue create mode 100644 vue-components/src/components/TransformsWidget.vue diff --git a/src/nrtk_explorer/app/parameters.py b/src/nrtk_explorer/app/parameters.py index 52a178bd..3e666b43 100644 --- a/src/nrtk_explorer/app/parameters.py +++ b/src/nrtk_explorer/app/parameters.py @@ -7,7 +7,7 @@ from trame.app import get_server -from nrtk_explorer.widgets.nrtk_explorer import ParamsWidget +from nrtk_explorer.widgets.nrtk_explorer import TransformsWidget from nrtk_explorer.library.transforms import ( ImageTransform, @@ -21,16 +21,17 @@ class ParametersApp(Applet): def __init__(self, server): super().__init__(server) - self.state.current_transform = "TestTransform" + self.context.setdefault("transforms", []) + self.state.setdefault("transforms", []) + self.state.setdefault("transform_descriptions", {}) - self._transforms: Dict[str, ImageTransform] = { - "TestTransform": TestTransform(), - "GaussianBlurTransform": GaussianBlurTransform(), - "IdentityTransform": IdentityTransform(), + self._transform_classes: Dict[str, type[ImageTransform]] = { + "TestTransform": TestTransform, + "GaussianBlurTransform": GaussianBlurTransform, + "IdentityTransform": IdentityTransform, } - self.state.transforms = [k for k in self._transforms.keys()] - self.state.current_transform = self.state.transforms[0] + self._default_transform = None self.server.controller.add("on_server_ready")(self.on_server_ready) @@ -38,47 +39,93 @@ def __init__(self, server): def on_server_ready(self, *args, **kwargs): # Bind instance methods to state change - self.on_current_transform_change() - self.state.change("current_transform")(self.on_current_transform_change) + self.update_transforms_descriptions() - def on_current_transform_change(self, **kwargs): - transform = self._transforms[self.state.current_transform] - self.state.params_values = transform.get_parameters() - self.state.params_descriptions = transform.get_parameters_description() + def on_add_transform(self, *args, **kwargs): + if self._default_transform in self._transform_classes: + transform_name, transform_class = ( + self._default_transform, + self._transform_classes[self._default_transform], + ) + else: + transform_name, transform_class = next(iter(self._transform_classes.items())) - def on_transform_parameters_changed(self, parameters, **kwargs): - transform = self._transforms[self.state.current_transform] - transform.set_parameters(parameters) - self.state.params_values = transform.get_parameters() + self.context.transforms.append({"name": transform_name, "instance": transform_class()}) - def transform_select_ui(self): - with html.Div(trame_server=self.server): - quasar.QSelect( - label="Transform", - v_model=("current_transform",), - options=(self.state.transforms,), - filled=True, - emit_value=True, - map_options=True, - ) + self.update_transforms_values() - def transform_params_ui(self): - with html.Div(trame_server=self.server): - ParamsWidget( - values=("params_values",), - descriptions=("params_descriptions",), - valuesChanged=(self.on_transform_parameters_changed, "[$event]"), - ) + def on_remove_transform(self, i, **kwargs): + if i >= len(self.context.transforms): + return + + self.context.transforms.pop(i) + + self.update_transforms_values() + + def on_type_changed(self, event): + i = event["id"] + transform_name = event["type"] + transform_class = self._transform_classes.get(transform_name) + + if i >= len(self.context.transforms) or transform_class is None: + return + + self.context.transforms[i] = {"name": transform_name, "instance": transform_class()} + + self.update_transforms_values() + + def on_params_changed(self, event): + i = event["id"] + params = event["params"] + + if i >= len(self.context.transforms): + return + + transform: ImageTransform = self.context.transforms[i]["instance"] + transform.set_parameters(params) + + self.update_transforms_values() + + def update_transforms_descriptions(self): + transform_descriptions = { + transform_name: transform_class.get_parameters_description() + for transform_name, transform_class in self._transform_classes.items() + } + + with self.state: + self.state.transform_descriptions = transform_descriptions + + def update_transforms_values(self): + def serialize_transform(item): + name = item["name"] + transform = item["instance"] + return {"name": name, "parameters": transform.get_parameters()} + + state_transforms = list(map(serialize_transform, self.context.transforms)) + + with self.state: + self.state.transforms = state_transforms def transform_apply_ui(self): with html.Div(trame_server=self.server): quasar.QBtn( - "Apply", + "Apply Transforms", click=(self.server.controller.apply_transform), classes="full-width", flat=True, ) + def transforms_ui(self): + with html.Div(trame_server=self.server): + TransformsWidget( + values=("transforms",), + descriptions=("transform_descriptions",), + add_transform=(self.on_add_transform, "[$event]"), + remove_transform=(self.on_remove_transform, "[$event]"), + type_changed=(self.on_type_changed, "[$event]"), + params_changed=(self.on_params_changed, "[$event]"), + ) + @property def ui(self): if self._ui is None: @@ -89,14 +136,9 @@ def ui(self): v_model=("leftDrawerOpen", True), side="left", elevated=True, + width="500", ): - self.transform_select_ui() - - with html.Div( - classes="q-pa-md q-ma-md", - style="border-style: solid; border-width: thin; border-radius: 0.5rem; border-color: lightgray;", - ): - self.transform_params_ui() + self.transforms_ui() self.transform_apply_ui() diff --git a/src/nrtk_explorer/app/transforms.py b/src/nrtk_explorer/app/transforms.py index 879b5d6f..14406329 100644 --- a/src/nrtk_explorer/app/transforms.py +++ b/src/nrtk_explorer/app/transforms.py @@ -112,6 +112,9 @@ def __init__( known_args, _ = self.server.cli.parse_known_args() self.state.inference_models = known_args.models self.state.object_detection_model = self.state.inference_models[0] + self.state.setdefault("image_list_ids", []) + self.state.setdefault("dataset_ids", []) + self.state.setdefault("user_selected_ids", []) self.images = images or Images(server) @@ -161,23 +164,24 @@ def delete_meta_state(old_ids, new_ids): self._on_transform_fn = None - self._transforms: Dict[str, trans.ImageTransform] = { - "blur": trans.GaussianBlurTransform(), - "invert": trans.InvertTransform(), - "downsample": trans.DownSampleTransform(), - "identity": trans.IdentityTransform(), + self._transform_classes: Dict[str, type[trans.ImageTransform]] = { + "blur": trans.GaussianBlurTransform, + "invert": trans.InvertTransform, + "downsample": trans.DownSampleTransform, + "identity": trans.IdentityTransform, } if nrtk_trans.nrtk_transforms_available(): - self._transforms["nrtk_pybsm"] = nrtk_trans.NrtkPybsmTransform() + self._transform_classes["nrtk_pybsm"] = nrtk_trans.NrtkPybsmTransform # Add transform from YAML definition - self._transforms.update(nrtk_yaml.generate_transforms()) + self._transform_classes.update(nrtk_yaml.generate_transforms()) - self._parameters_app._transforms = self._transforms + self._parameters_app._transform_classes = self._transform_classes - self.state.transforms = [k for k in self._transforms.keys()] - self.state.current_transform = self.state.transforms[0] + # Initialize the transforms pipeline to the identity + self._parameters_app._default_transform = "blur" + self._parameters_app.on_add_transform() init_visible_columns(self.state) @@ -260,7 +264,8 @@ async def _update_transformed_images(self, dataset_ids): if not self.state.transform_enabled: return - transform = self._transforms[self.state.current_transform] + transforms = list(map(lambda t: t["instance"], self.context.transforms)) + transform = trans.ChainedImageTransform(transforms) id_to_matching_size_img = {} for id in dataset_ids: @@ -399,9 +404,7 @@ def on_hover(self, hover_event): def settings_widget(self): with html.Div(classes="col"): - self._parameters_app.transform_select_ui() - with html.Div(classes="q-pa-md q-ma-md"): - self._parameters_app.transform_params_ui() + self._parameters_app.transforms_ui() def apply_ui(self): with html.Div(): diff --git a/src/nrtk_explorer/library/nrtk_transforms.py b/src/nrtk_explorer/library/nrtk_transforms.py index 2af6d4de..973734a6 100644 --- a/src/nrtk_explorer/library/nrtk_transforms.py +++ b/src/nrtk_explorer/library/nrtk_transforms.py @@ -164,7 +164,8 @@ def set_parameters(self, params: Dict[str, Any]): self._perturber.sensor.D = params["D"] self._perturber.sensor.f = params["f"] - def get_parameters_description(self) -> Dict[str, ParameterDescription]: + @classmethod + def get_parameters_description(cls) -> Dict[str, ParameterDescription]: aperture_description: ParameterDescription = { "type": "float", "label": "Effective Aperture (m)", diff --git a/src/nrtk_explorer/library/transforms.py b/src/nrtk_explorer/library/transforms.py index 1bd7af69..d272e462 100644 --- a/src/nrtk_explorer/library/transforms.py +++ b/src/nrtk_explorer/library/transforms.py @@ -31,8 +31,8 @@ def get_parameters(self) -> Dict[str, Any]: def set_parameters(self, params: Dict[str, Any]): raise NotImplementedError() - @abc.abstractmethod - def get_parameters_description(self) -> Dict[str, ParameterDescription]: + @classmethod + def get_parameters_description(cls) -> Dict[str, ParameterDescription]: raise NotImplementedError() @abc.abstractmethod @@ -44,6 +44,35 @@ class ImageTransform(Transform[Image, Image]): pass +class ChainedImageTransform(ImageTransform): + def __init__(self, transforms: list[ImageTransform]): + self.transforms = transforms + + def execute(self, input: Image, *input_args: Any) -> Image: + output = input + + for transform in self.transforms: + output = transform.execute(output, *input_args) + + return output + + def get_parameters(self) -> Dict[str, Any]: + raise NotImplementedError( + "Set/Get parameters on the individual transforms making up the ChainedImageTransform" + ) + + def set_parameters(self, params: Dict[str, Any]): + raise NotImplementedError( + "Set/Get parameters on the individual transforms making up the ChainedImageTransform" + ) + + @classmethod + def get_parameters_description(cls) -> Dict[str, ParameterDescription]: + raise NotImplementedError( + "Set/Get parameters on the individual transforms making up the ChainedImageTransform" + ) + + class IdentityTransform(ImageTransform): def get_parameters(self) -> Dict[str, Any]: return {} @@ -51,7 +80,8 @@ def get_parameters(self) -> Dict[str, Any]: def set_parameters(self, params: Dict[str, Any]): pass - def get_parameters_description(self) -> Dict[str, ParameterDescription]: + @classmethod + def get_parameters_description(cls) -> Dict[str, ParameterDescription]: return {} def execute(self, input: Image, *input_args: Any) -> Image: @@ -70,7 +100,8 @@ def get_parameters(self) -> Dict[str, Any]: def set_parameters(self, params: Dict[str, Any]): self._radius = params.get("radius", GaussianBlurTransform.default_radius) - def get_parameters_description(self) -> Dict[str, ParameterDescription]: + @classmethod + def get_parameters_description(cls) -> Dict[str, ParameterDescription]: radius_description: ParameterDescription = { "type": "integer", "default": GaussianBlurTransform.default_radius, @@ -94,7 +125,8 @@ def get_parameters(self) -> dict[str, Any]: def set_parameters(self, params: Dict[str, Any]): pass - def get_parameters_description(self) -> Dict[str, ParameterDescription]: + @classmethod + def get_parameters_description(cls) -> Dict[str, ParameterDescription]: return {} def execute(self, input: Image, *input_args: Any) -> Image: @@ -108,7 +140,8 @@ def get_parameters(self) -> dict[str, Any]: def set_parameters(self, params: Dict[str, Any]): pass - def get_parameters_description(self) -> Dict[str, ParameterDescription]: + @classmethod + def get_parameters_description(cls) -> Dict[str, ParameterDescription]: return {} def execute(self, input: Image, *input_args: Any) -> Image: @@ -146,7 +179,8 @@ def set_parameters(self, params: Dict[str, Any]): self._boolean_value = params.get("boolean_param", TestTransform.default_boolean) self._select_value = params.get("select_param", TestTransform.default_select) - def get_parameters_description(self) -> Dict[str, ParameterDescription]: + @classmethod + def get_parameters_description(cls) -> Dict[str, ParameterDescription]: string_description: ParameterDescription = { "type": "string", "default": TestTransform.default_string, diff --git a/src/nrtk_explorer/library/yaml_transforms.py b/src/nrtk_explorer/library/yaml_transforms.py index 011d823f..e1f33652 100644 --- a/src/nrtk_explorer/library/yaml_transforms.py +++ b/src/nrtk_explorer/library/yaml_transforms.py @@ -6,8 +6,6 @@ from pathlib import Path from yaml import load, Loader -from nrtk_explorer.library.transforms import ImageTransform - TRANSFORM_FILE = Path(__file__).with_name("nrtk_transforms.yaml").resolve() if "NRTK_TRANSFORM_DEFINITION" in os.environ: @@ -34,7 +32,7 @@ def generate_transforms(): transforms = {} for k, v in TRANSFORM_DEFINITIONS.items(): with contextlib.suppress(ModuleNotFoundError): - transforms[k] = GenericPerturber(v) + transforms[k] = MetaYamlPerturber(k, (), {}, v) return transforms @@ -77,7 +75,7 @@ def set_value(obj, attr_path, value): # ----------------------------------------------------------------------------- -def create_perturber_instance(klass, kwargs): +def get_perturber_constructor(klass, kwargs): klass = get(klass) if isinstance(kwargs, str): @@ -89,21 +87,50 @@ def create_perturber_instance(klass, kwargs): if not isinstance(kwargs, dict): raise ValueError(f"kwarg must lead to a dict but got {type(kwargs)}") - return klass(**kwargs) + return klass, kwargs # ----------------------------------------------------------------------------- -class GenericPerturber(ImageTransform): - def __init__(self, config): - self.description = config.get("description") - self.exec_args = config.get("exec_default_args", []) +class MetaYamlPerturber(type): + @classmethod + def __prepare__(metacls, name, bases, config=None): + return super().__prepare__(name, bases, config=config) + + def __new__(metacls, name, bases, namespace, config=None): + return super().__new__(metacls, name, bases, namespace) + + def __init__(cls, name, bases, namespace, config=None): + if config is None: + raise TypeError("MetaYamlPerturber: configuration is missing.") + + # add class variables + setattr(cls, "description", config.get("description")) + setattr(cls, "exec_args", config.get("exec_default_args", [])) + perturber_class, perturber_kwargs = get_perturber_constructor( + config.get("perturber"), config.get("perturber_kwargs", {}) + ) + setattr(cls, "perturber_class", perturber_class) + setattr(cls, "perturber_kwargs", perturber_kwargs) + + # class methods + setattr(cls, "__init__", MetaYamlPerturber.instance_init) + setattr(cls, MetaYamlPerturber.get_parameters.__name__, MetaYamlPerturber.get_parameters) + setattr(cls, MetaYamlPerturber.set_parameters.__name__, MetaYamlPerturber.set_parameters) + setattr(cls, MetaYamlPerturber.execute.__name__, MetaYamlPerturber.execute) + setattr( + cls, + MetaYamlPerturber.get_parameters_description.__name__, + classmethod(MetaYamlPerturber.get_parameters_description), + ) + + super().__init__(name, bases, namespace) + + # Methods that will be defined on the dynamic YamlPerturber classes - # klass - klass = config.get("perturber") - kwargs = config.get("perturber_kwargs", {}) - self._perturber = create_perturber_instance(klass, kwargs) + def instance_init(self): + self._perturber = self.perturber_class(**self.perturber_kwargs) def get_parameters(self): params = {} @@ -117,8 +144,8 @@ def set_parameters(self, params): attr_path = self.description.get(k).get("_path", [k]) set_value(self._perturber, attr_path, v) - def get_parameters_description(self): - return self.description + def get_parameters_description(cls): + return cls.description def execute(self, input, *input_args): if len(input_args) == 0: diff --git a/src/nrtk_explorer/widgets/nrtk_explorer.py b/src/nrtk_explorer/widgets/nrtk_explorer.py index 6df5a65f..86ad7e86 100644 --- a/src/nrtk_explorer/widgets/nrtk_explorer.py +++ b/src/nrtk_explorer/widgets/nrtk_explorer.py @@ -42,6 +42,24 @@ def __init__(self, **kwargs): self._event_names += ["valuesChanged"] +class TransformsWidget(HtmlElement): + def __init__(self, **kwargs): + super().__init__( + "transforms-widget", + **kwargs, + ) + self._attr_names += [ + "values", + "descriptions", + ] + self._event_names += [ + ("add_transform", "addTransform"), + ("remove_transform", "removeTransform"), + ("type_changed", "typeChanged"), + ("params_changed", "paramsChanged"), + ] + + class FilterOptionsWidget(HtmlElement): def __init__(self, **kwargs): super().__init__( diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 24d449b4..1c758c95 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -7,13 +7,13 @@ def test_gaussian_blur(): transforms = generate_transforms() - blur = transforms["nrtk_cv2_gauss_blur"] + blur = transforms["nrtk_cv2_gauss_blur"]() blur.set_parameters({"ksize": 3}) blur.execute(get_image()) def test_pybsm(): transforms = generate_transforms() - pybsm = transforms["nrtk_pybsm_2"] + pybsm = transforms["nrtk_pybsm_2"]() pybsm.set_parameters({"D": 0.25, "f": 4.0}) pybsm.execute(get_image()) diff --git a/vue-components/src/components/ParamWidget.vue b/vue-components/src/components/ParamWidget.vue index 4a20ea69..5567229b 100644 --- a/vue-components/src/components/ParamWidget.vue +++ b/vue-components/src/components/ParamWidget.vue @@ -1,6 +1,4 @@ + + diff --git a/vue-components/src/components/TransformsWidget.vue b/vue-components/src/components/TransformsWidget.vue new file mode 100644 index 00000000..9fb0730e --- /dev/null +++ b/vue-components/src/components/TransformsWidget.vue @@ -0,0 +1,34 @@ + + + diff --git a/vue-components/src/components/index.js b/vue-components/src/components/index.js index 5f1bbba1..444762d1 100644 --- a/vue-components/src/components/index.js +++ b/vue-components/src/components/index.js @@ -1,11 +1,13 @@ import ScatterPlot from './ScatterPlot.vue' import ParamsWidget from './ParamsWidget.vue' +import TransformsWidget from './TransformsWidget.vue' import FilterOptionsWidget from './FilterOptionsWidget.vue' import FilterOperatorWidget from './FilterOperatorWidget.vue' export default { scatterPlot: ScatterPlot, paramsWidget: ParamsWidget, + transformsWidget: TransformsWidget, filterOptionsWidget: FilterOptionsWidget, filterOperatorWidget: FilterOperatorWidget } diff --git a/vue-components/src/types/index.ts b/vue-components/src/types/index.ts index 02efe4ab..97566ced 100644 --- a/vue-components/src/types/index.ts +++ b/vue-components/src/types/index.ts @@ -24,3 +24,10 @@ export type ParameterDescription = { default?: ParameterValue options?: ParameterValue[] } + +export type TransformDescription = { [paramName: string]: ParameterDescription } + +export type TransformValue = { + name: string + parameters: { [name: string]: ParameterValue } +}