Skip to content

Commit

Permalink
Changes external references to `keras.utils.serialize_keras_object/de…
Browse files Browse the repository at this point in the history
…serialize_keras_object` to legacy serialization API in preparation for switching all of Keras to new serialization format.

PiperOrigin-RevId: 502975348
  • Loading branch information
nkovela1 authored and edward-bot committed Jan 19, 2023
1 parent 4cd7215 commit 9df18a8
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
7 changes: 4 additions & 3 deletions edward2/tensorflow/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,16 @@ def get_config(self):


def serialize(initializer):
return tf.keras.utils.serialize_keras_object(initializer)
return tf.keras.utils.legacy.serialize_keras_object(initializer)


def deserialize(config, custom_objects=None):
return tf.keras.utils.deserialize_keras_object(
return tf.keras.utils.legacy.deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='constraints')
printable_module_name='constraints',
)


def get(identifier, value=None):
Expand Down
7 changes: 4 additions & 3 deletions edward2/tensorflow/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,15 +851,16 @@ def get_config(self):


def serialize(initializer):
return tf.keras.utils.serialize_keras_object(initializer)
return tf.keras.utils.legacy.serialize_keras_object(initializer)


def deserialize(config, custom_objects=None):
return tf.keras.utils.deserialize_keras_object(
return tf.keras.utils.legacy.deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='initializers')
printable_module_name='initializers',
)


def get(identifier, value=None):
Expand Down
9 changes: 5 additions & 4 deletions edward2/tensorflow/layers/gaussian_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_config(self):
return {
'variance': self.variance,
'bias': self.bias,
'encoder': tf.keras.utils.serialize_keras_object(self.encoder),
'encoder': tf.keras.utils.legacy.serialize_keras_object(self.encoder),
}


Expand Down Expand Up @@ -250,9 +250,10 @@ def compute_output_shape(self, input_shape):
def get_config(self):
config = {
'units': self.units,
'mean_fn': tf.keras.utils.serialize_keras_object(self.mean_fn),
'covariance_fn': tf.keras.utils.serialize_keras_object(
self.covariance_fn),
'mean_fn': tf.keras.utils.legacy.serialize_keras_object(self.mean_fn),
'covariance_fn': tf.keras.utils.legacy.serialize_keras_object(
self.covariance_fn
),
'conditional_inputs': None, # don't serialize as it can be large
'conditional_outputs': None, # don't serialize as it can be large
}
Expand Down
7 changes: 4 additions & 3 deletions edward2/tensorflow/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,15 +388,16 @@ def get_config(self):


def serialize(initializer):
return tf.keras.utils.serialize_keras_object(initializer)
return tf.keras.utils.legacy.serialize_keras_object(initializer)


def deserialize(config, custom_objects=None):
return tf.keras.utils.deserialize_keras_object(
return tf.keras.utils.legacy.deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='regularizers')
printable_module_name='regularizers',
)


def get(identifier, value=None):
Expand Down

0 comments on commit 9df18a8

Please sign in to comment.