Skip to content

Commit

Permalink
feat(transforms): add ability to add multiple transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
alesgenova committed Nov 19, 2024
1 parent d2fd39d commit bad9106
Show file tree
Hide file tree
Showing 13 changed files with 317 additions and 100 deletions.
128 changes: 85 additions & 43 deletions src/nrtk_explorer/app/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from trame.app import get_server

from nrtk_explorer.widgets.nrtk_explorer import ParamsWidget
from nrtk_explorer.widgets.nrtk_explorer import TransformsWidget

from nrtk_explorer.library.transforms import (
ImageTransform,
Expand All @@ -21,64 +21,111 @@ class ParametersApp(Applet):
def __init__(self, server):
super().__init__(server)

self.state.current_transform = "TestTransform"
self.context.setdefault("transforms", [])
self.state.setdefault("transforms", [])
self.state.setdefault("transform_descriptions", {})

self._transforms: Dict[str, ImageTransform] = {
"TestTransform": TestTransform(),
"GaussianBlurTransform": GaussianBlurTransform(),
"IdentityTransform": IdentityTransform(),
self._transform_classes: Dict[str, type[ImageTransform]] = {
"TestTransform": TestTransform,
"GaussianBlurTransform": GaussianBlurTransform,
"IdentityTransform": IdentityTransform,
}

self.state.transforms = [k for k in self._transforms.keys()]
self.state.current_transform = self.state.transforms[0]
self._default_transform = None

self.server.controller.add("on_server_ready")(self.on_server_ready)

self._ui = None

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.on_current_transform_change()
self.state.change("current_transform")(self.on_current_transform_change)
self.update_transforms_descriptions()

def on_current_transform_change(self, **kwargs):
transform = self._transforms[self.state.current_transform]
self.state.params_values = transform.get_parameters()
self.state.params_descriptions = transform.get_parameters_description()
def on_add_transform(self, *args, **kwargs):
if self._default_transform in self._transform_classes:
transform_name, transform_class = (
self._default_transform,
self._transform_classes[self._default_transform],
)
else:
transform_name, transform_class = next(iter(self._transform_classes.items()))

def on_transform_parameters_changed(self, parameters, **kwargs):
transform = self._transforms[self.state.current_transform]
transform.set_parameters(parameters)
self.state.params_values = transform.get_parameters()
self.context.transforms.append({"name": transform_name, "instance": transform_class()})

def transform_select_ui(self):
with html.Div(trame_server=self.server):
quasar.QSelect(
label="Transform",
v_model=("current_transform",),
options=(self.state.transforms,),
filled=True,
emit_value=True,
map_options=True,
)
self.update_transforms_values()

def transform_params_ui(self):
with html.Div(trame_server=self.server):
ParamsWidget(
values=("params_values",),
descriptions=("params_descriptions",),
valuesChanged=(self.on_transform_parameters_changed, "[$event]"),
)
def on_remove_transform(self, i, **kwargs):
if i >= len(self.context.transforms):
return

self.context.transforms.pop(i)

self.update_transforms_values()

def on_type_changed(self, event):
i = event["id"]
transform_name = event["type"]
transform_class = self._transform_classes.get(transform_name)

if i >= len(self.context.transforms) or transform_class is None:
return

self.context.transforms[i] = {"name": transform_name, "instance": transform_class()}

self.update_transforms_values()

def on_params_changed(self, event):
i = event["id"]
params = event["params"]

if i >= len(self.context.transforms):
return

transform: ImageTransform = self.context.transforms[i]["instance"]
transform.set_parameters(params)

self.update_transforms_values()

def update_transforms_descriptions(self):
transform_descriptions = {
transform_name: transform_class.get_parameters_description()
for transform_name, transform_class in self._transform_classes.items()
}

with self.state:
self.state.transform_descriptions = transform_descriptions

def update_transforms_values(self):
def serialize_transform(item):
name = item["name"]
transform = item["instance"]
return {"name": name, "parameters": transform.get_parameters()}

state_transforms = list(map(serialize_transform, self.context.transforms))

with self.state:
self.state.transforms = state_transforms

def transform_apply_ui(self):
with html.Div(trame_server=self.server):
quasar.QBtn(
"Apply",
"Apply Transforms",
click=(self.server.controller.apply_transform),
classes="full-width",
flat=True,
)

def transforms_ui(self):
with html.Div(trame_server=self.server):
TransformsWidget(
values=("transforms",),
descriptions=("transform_descriptions",),
add_transform=(self.on_add_transform, "[$event]"),
remove_transform=(self.on_remove_transform, "[$event]"),
type_changed=(self.on_type_changed, "[$event]"),
params_changed=(self.on_params_changed, "[$event]"),
)

@property
def ui(self):
if self._ui is None:
Expand All @@ -89,14 +136,9 @@ def ui(self):
v_model=("leftDrawerOpen", True),
side="left",
elevated=True,
width="500",
):
self.transform_select_ui()

with html.Div(
classes="q-pa-md q-ma-md",
style="border-style: solid; border-width: thin; border-radius: 0.5rem; border-color: lightgray;",
):
self.transform_params_ui()
self.transforms_ui()

self.transform_apply_ui()

Expand Down
31 changes: 17 additions & 14 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def __init__(
known_args, _ = self.server.cli.parse_known_args()
self.state.inference_models = known_args.models
self.state.object_detection_model = self.state.inference_models[0]
self.state.setdefault("image_list_ids", [])
self.state.setdefault("dataset_ids", [])
self.state.setdefault("user_selected_ids", [])

self.images = images or Images(server)

Expand Down Expand Up @@ -161,23 +164,24 @@ def delete_meta_state(old_ids, new_ids):

self._on_transform_fn = None

self._transforms: Dict[str, trans.ImageTransform] = {
"blur": trans.GaussianBlurTransform(),
"invert": trans.InvertTransform(),
"downsample": trans.DownSampleTransform(),
"identity": trans.IdentityTransform(),
self._transform_classes: Dict[str, type[trans.ImageTransform]] = {
"blur": trans.GaussianBlurTransform,
"invert": trans.InvertTransform,
"downsample": trans.DownSampleTransform,
"identity": trans.IdentityTransform,
}

if nrtk_trans.nrtk_transforms_available():
self._transforms["nrtk_pybsm"] = nrtk_trans.NrtkPybsmTransform()
self._transform_classes["nrtk_pybsm"] = nrtk_trans.NrtkPybsmTransform

# Add transform from YAML definition
self._transforms.update(nrtk_yaml.generate_transforms())
self._transform_classes.update(nrtk_yaml.generate_transforms())

self._parameters_app._transforms = self._transforms
self._parameters_app._transform_classes = self._transform_classes

self.state.transforms = [k for k in self._transforms.keys()]
self.state.current_transform = self.state.transforms[0]
# Initialize the transforms pipeline to the identity
self._parameters_app._default_transform = "blur"
self._parameters_app.on_add_transform()

init_visible_columns(self.state)

Expand Down Expand Up @@ -260,7 +264,8 @@ async def _update_transformed_images(self, dataset_ids):
if not self.state.transform_enabled:
return

transform = self._transforms[self.state.current_transform]
transforms = list(map(lambda t: t["instance"], self.context.transforms))
transform = trans.ChainedImageTransform(transforms)

id_to_matching_size_img = {}
for id in dataset_ids:
Expand Down Expand Up @@ -399,9 +404,7 @@ def on_hover(self, hover_event):

def settings_widget(self):
with html.Div(classes="col"):
self._parameters_app.transform_select_ui()
with html.Div(classes="q-pa-md q-ma-md"):
self._parameters_app.transform_params_ui()
self._parameters_app.transforms_ui()

def apply_ui(self):
with html.Div():
Expand Down
3 changes: 2 additions & 1 deletion src/nrtk_explorer/library/nrtk_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def set_parameters(self, params: Dict[str, Any]):
self._perturber.sensor.D = params["D"]
self._perturber.sensor.f = params["f"]

def get_parameters_description(self) -> Dict[str, ParameterDescription]:
@classmethod
def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
aperture_description: ParameterDescription = {
"type": "float",
"label": "Effective Aperture (m)",
Expand Down
48 changes: 41 additions & 7 deletions src/nrtk_explorer/library/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def get_parameters(self) -> Dict[str, Any]:
def set_parameters(self, params: Dict[str, Any]):
raise NotImplementedError()

@abc.abstractmethod
def get_parameters_description(self) -> Dict[str, ParameterDescription]:
@classmethod
def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
raise NotImplementedError()

@abc.abstractmethod
Expand All @@ -44,14 +44,44 @@ class ImageTransform(Transform[Image, Image]):
pass


class ChainedImageTransform(ImageTransform):
def __init__(self, transforms: list[ImageTransform]):
self.transforms = transforms

def execute(self, input: Image, *input_args: Any) -> Image:
output = input

for transform in self.transforms:
output = transform.execute(output, *input_args)

return output

def get_parameters(self) -> Dict[str, Any]:
raise NotImplementedError(
"Set/Get parameters on the individual transforms making up the ChainedImageTransform"
)

def set_parameters(self, params: Dict[str, Any]):
raise NotImplementedError(
"Set/Get parameters on the individual transforms making up the ChainedImageTransform"
)

@classmethod
def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
raise NotImplementedError(
"Set/Get parameters on the individual transforms making up the ChainedImageTransform"
)


class IdentityTransform(ImageTransform):
def get_parameters(self) -> Dict[str, Any]:
return {}

def set_parameters(self, params: Dict[str, Any]):
pass

def get_parameters_description(self) -> Dict[str, ParameterDescription]:
@classmethod
def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
return {}

def execute(self, input: Image, *input_args: Any) -> Image:
Expand All @@ -70,7 +100,8 @@ def get_parameters(self) -> Dict[str, Any]:
def set_parameters(self, params: Dict[str, Any]):
self._radius = params.get("radius", GaussianBlurTransform.default_radius)

def get_parameters_description(self) -> Dict[str, ParameterDescription]:
@classmethod
def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
radius_description: ParameterDescription = {
"type": "integer",
"default": GaussianBlurTransform.default_radius,
Expand All @@ -94,7 +125,8 @@ def get_parameters(self) -> dict[str, Any]:
def set_parameters(self, params: Dict[str, Any]):
pass

def get_parameters_description(self) -> Dict[str, ParameterDescription]:
@classmethod
def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
return {}

def execute(self, input: Image, *input_args: Any) -> Image:
Expand All @@ -108,7 +140,8 @@ def get_parameters(self) -> dict[str, Any]:
def set_parameters(self, params: Dict[str, Any]):
pass

def get_parameters_description(self) -> Dict[str, ParameterDescription]:
@classmethod
def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
return {}

def execute(self, input: Image, *input_args: Any) -> Image:
Expand Down Expand Up @@ -146,7 +179,8 @@ def set_parameters(self, params: Dict[str, Any]):
self._boolean_value = params.get("boolean_param", TestTransform.default_boolean)
self._select_value = params.get("select_param", TestTransform.default_select)

def get_parameters_description(self) -> Dict[str, ParameterDescription]:
@classmethod
def get_parameters_description(cls) -> Dict[str, ParameterDescription]:
string_description: ParameterDescription = {
"type": "string",
"default": TestTransform.default_string,
Expand Down
Loading

0 comments on commit bad9106

Please sign in to comment.