Skip to content

Commit

Permalink
Switches Keras object serialization to new logic and changes public A…
Browse files Browse the repository at this point in the history
…PI for deserialize_keras_object/serialize_keras_object to the new functions.

PiperOrigin-RevId: 480676373
  • Loading branch information
nkovela1 authored and edward-bot committed Oct 13, 2022
1 parent 5338ae3 commit 73a6da5
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 6 deletions.
1 change: 0 additions & 1 deletion edward2/tensorflow/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def serialize(initializer):
def deserialize(config, custom_objects=None):
return tf.keras.utils.deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='constraints')

Expand Down
8 changes: 4 additions & 4 deletions edward2/tensorflow/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def __init__(self,
distribution: Random distribution to use. One of "truncated_normal", or
"untruncated_normal".
seed: A Python integer. Used to create random seeds. See
`tf.set_random_seed`
for behavior.
`tf.set_random_seed` for behavior.
Raises:
ValueError: In case of an invalid value for the "scale", mode" or
Expand Down Expand Up @@ -529,7 +528,9 @@ def get_config(self):


class TrainableHeNormal(TrainableNormal):
"""Trainable normal initialized per He et al. 2015, given a ReLU nonlinearity.
"""Trainable normal initialized per He et al.
2015, given a ReLU nonlinearity.
The distribution is initialized to a Normal scaled by `sqrt(2 / fan_in)`,
where `fan_in` is the number of input units. A ReLU nonlinearity is assumed
Expand Down Expand Up @@ -857,7 +858,6 @@ def serialize(initializer):
def deserialize(config, custom_objects=None):
return tf.keras.utils.deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='initializers')

Expand Down
1 change: 0 additions & 1 deletion edward2/tensorflow/regularizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,6 @@ def serialize(initializer):
def deserialize(config, custom_objects=None):
return tf.keras.utils.deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='regularizers')

Expand Down

0 comments on commit 73a6da5

Please sign in to comment.