Removes unnecessary ViT-GP hyper-parameters.
Due to [pull #489](google/edward2#489) to `edward2.jax.nn.RandomFeatureGaussianProcess`, some of the special hyper-parameter configs are no longer needed. Therefore we remove them to simplify the model API.

PiperOrigin-RevId: 388484029
jereliu authored and copybara-github committed Aug 9, 2021
1 parent 6fb3245 commit c5b3205
Showing 3 changed files with 31 additions and 45 deletions.
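In config terms, the change collapses the GP-layer section of the ViT-SNGP experiment config to a single knob. A minimal before/after sketch summarizing the diff below (the dropped settings now fall back to the edward2 defaults; this is an illustration, not part of the commit):

```python
import ml_collections

# Before: several random-feature hyper-parameters were exposed.
gp_layer = ml_collections.ConfigDict()
gp_layer.normalize_input = True
gp_layer.random_feature_scale = 1.
gp_layer.random_feature_stddev = 0.025

# After: only the covariance momentum remains configurable.
gp_layer = ml_collections.ConfigDict()
gp_layer.covmat_momentum = .999
```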
26 changes: 11 additions & 15 deletions baselines/jft/experiments/jft300m_vit_base16_sngp.py
@@ -40,17 +40,17 @@ def get_config():
 
   pp_common = '|value_range(-1, 1)'
   pp_common += f'|onehot({config.num_classes})'
-  # To use ancestor "smearing", use this line instead:
-  # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels") # pylint: disable=line-too-long
+  # To use ancestor 'smearing', use this line instead:
+  # pp_common += f'|onehot({config.num_classes}, key='labels_extended', key_result='labels') # pylint: disable=line-too-long
   pp_common += '|keep("image", "labels")'
   config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
   config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
   config.shuffle_buffer_size = 250_000 # Per host, so small-ish is ok.
 
   config.log_training_steps = 50
   config.log_eval_steps = 1000
-  # NOTE: eval is very fast O(seconds) so it's fine to run it often.
-  config.checkpoint_steps = 1000
+  # NOTE: For pretraining, save infrequently to prevent crowding diskspace.
+  config.checkpoint_steps = 517790
 
   # Model section
   config.model = ml_collections.ConfigDict()
@@ -66,11 +66,11 @@ def get_config():
   config.model.classifier = 'token' # Or 'gap'
   config.model.representation_size = 768
 
-  # GP layer parameters.
+  # Gaussian process layer parameters.
   config.gp_layer = ml_collections.ConfigDict()
-  config.gp_layer.normalize_input = True
-  config.gp_layer.random_feature_scale = 1. # 1. or None
-  config.gp_layer.random_feature_stddev = 0.025 # 1. or 0.025
+  # Use momentum for pre-training to prevent numeric error when inverting a
+  # precision matrix accumulated over 300M data.
+  config.gp_layer.covmat_momentum = .999
 
   # Optimizer section
   config.optim_name = 'Adam'
@@ -82,7 +82,8 @@
 
   # TODO(lbeyer): make a mini-language like preprocessings.
   config.lr = ml_collections.ConfigDict()
-  config.lr.base = 8e-4 # LR has to be lower for larger models!
+  # LR has to be lower for GP layer and on larger models.
+  config.lr.base = 4e-4
   config.lr.warmup_steps = 10_000
   config.lr.decay_type = 'linear'
   config.lr.linear_end = 1e-5
@@ -96,9 +97,4 @@
 
 
 def get_sweep(hyper):
-  # lr_grid = [3e-4, 4e-4, 5e-4, 6e-4]
-  # stddev_grid = [0.01, 0.02, 0.03, 0.04, 0.05]
-  return hyper.product([
-      # hyper.sweep('config.lr.base', lr_grid),
-      # hyper.sweep('config.gp_layer.random_feature_stddev', stddev_grid)
-  ])
+  return hyper.product([])
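The new `covmat_momentum` comment above is easiest to see numerically. The toy sketch below is an illustration only, not the edward2 update rule, and all shapes and batch counts are made up: it contrasts exact accumulation of a random-feature precision matrix with a momentum-style moving average. The exact sum keeps growing with the number of batches, so inverting it after roughly 300M examples is numerically delicate, while the moving average stays on a bounded scale.

```python
import numpy as np

rng = np.random.default_rng(0)
num_features, momentum = 16, 0.999

precision_exact = np.eye(num_features)  # grows with every batch
precision_ema = np.eye(num_features)    # stays on a bounded scale

for _ in range(10_000):  # stand-in for many pre-training batches
  phi = rng.normal(size=(32, num_features))  # random features for one batch
  batch_precision = phi.T @ phi
  precision_exact += batch_precision
  precision_ema = momentum * precision_ema + (1. - momentum) * batch_precision

# The exact accumulator is several orders of magnitude larger after only
# 10k batches, and would grow further over 300M examples.
print(np.abs(precision_exact).max(), np.abs(precision_ema).max())
```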
42 changes: 16 additions & 26 deletions baselines/jft/sngp.py
@@ -27,7 +27,6 @@
 from clu import periodic_actions
 import flax
 import flax.jax_utils as flax_utils
-import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import ml_collections
@@ -72,7 +71,7 @@ def accumulate_gradient_with_states(
     accum_steps):
   """Improved version of `u.accumulate_gradient()` that allows for states."""
   # This function handles the `loss_and_grad_fn` function which takes a state
-  # arguement and returns ((losses, states), grads).
+  # argument and returns ((losses, states), grads).
   if accum_steps and accum_steps > 1:
     assert images.shape[0] % accum_steps == 0, (
         f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
@@ -102,27 +101,16 @@ def acc_grad_and_loss(i, l_s_g):
 
 
 def get_gp_kwargs(gp_config):
-  """Extract keyword arguement parameters for the Gaussian process layer."""
-  normalize_input = gp_config.get('normalize_input', True)
-  kernel_stddev = gp_config.get('random_feature_stddev', 1.)
-  feature_scale = gp_config.get('random_feature_scale', -1.)
+  """Extract keyword argument parameters for the Gaussian process layer."""
   covmat_momentum = gp_config.get('covmat_momentum', 0.999)
 
-  logging.info('gp_config.normalize_input = %s', normalize_input)
-  logging.info('gp_config.random_feature_stddev = %s', kernel_stddev)
-  logging.info('gp_config.random_feature_scale = %s', feature_scale)
+  # Extracts model parameter.
   logging.info('gp_config.covmat_momentum = %s', covmat_momentum)
 
-  feature_scale = None if feature_scale < 0. else feature_scale
-  kernel_init = nn.initializers.normal(stddev=kernel_stddev)
-  hidden_kwargs = dict(feature_scale=feature_scale, kernel_init=kernel_init)
   covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
   covmat_kwargs = dict(momentum=covmat_momentum)
 
-  # Assemble into kwargs dictionary.
-  gp_layer_kwargs = dict(
-      normalize_input=normalize_input,
-      hidden_kwargs=hidden_kwargs,
-      covmat_kwargs=covmat_kwargs)
+  # Assembles into kwargs dictionary.
+  gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)
 
   return gp_layer_kwargs
 
@@ -337,7 +325,7 @@ def representation_fn(params, images, labels, mask, states):
 @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
 def update_fn(opt, states, lr, images, labels, rng):
   """Update step."""
-
+  # TODO(jereliu): Expand to allow precision matrix resetting.
   measurements = {}
 
   if config.get('mixup') and config.mixup.p:
@@ -423,17 +411,17 @@ def decay_fn(v, wd):
         checkpoint['states'],
         checkpoint['extra'])
   elif config.get('model_init'):
-    write_note(f'Initialize model from {config.model_init}...')
-    raise ValueError(
-        'Load from `config.model_init` checkpoint is currently not supported.')
+    # Load trainable parameters from the checkpoint.
+    # This does not cause issue for SNGP since all non-trainable parameters
+    # (random feature, precision matrix, etc) are last-layer parameters that
+    # should be re-trained during fine-tuning.
+    write_note(f'Initialize trainable parameters from {config.model_init}...')
     # TODO(dusenberrymw): Replace and test load function.
-    # pylint:disable=unreachable
     loaded = resformer.load(params_cpu, config.model_init, config.get('model'))
     opt_cpu = opt_cpu.replace(target=loaded)
     if jax.host_id() == 0:
       logging.info('Restored parameter overview:')
       parameter_overview.log_parameter_overview(loaded)
-    # pylint:enable=unreachable
 
   write_note('Kicking off misc stuff...')
   first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
@@ -482,6 +470,7 @@ def decay_fn(v, wd):
     mw.step_start(step)
 
     with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
+      # TODO(jereliu): Expand to allow precision matrix resetting.
       (opt_repl, states_repl, loss_value, rngs_loop,
        extra_measurements) = update_fn(
            opt_repl,
@@ -505,8 +494,9 @@ def decay_fn(v, wd):
       # alive while they'll be updated in a future step, creating hard to debug
       # memory errors (see b/160593526). Also, takes device 0's params only.
       # We will also do the same for untrainable parameters (`states`). This is
-      # ok since both `random features` and `predictive covariance` are frozen
-      # or task-specific parameters that are not important for pre-training.
+      # ok since `random features` are frozen throughout pre-training, and
+      # `predictive covariance` are irrelevant for downstream finetuning and
+      # will be discarded anyway.
       opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
       states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
 
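For readability, here is the simplified helper pieced together from the added lines above, with a small usage illustration appended (the `absl` logging import and the trailing call are assumptions, not part of the commit). The negative-to-`None` mapping is kept from the original code; per the edward2 documentation, `momentum=None` selects the exact, momentum-free covariance update.

```python
from absl import logging


def get_gp_kwargs(gp_config):
  """Extract keyword argument parameters for the Gaussian process layer."""
  covmat_momentum = gp_config.get('covmat_momentum', 0.999)

  # Extracts model parameter.
  logging.info('gp_config.covmat_momentum = %s', covmat_momentum)

  # A negative momentum is mapped to None, i.e. exact accumulation of the
  # precision matrix instead of a moving average.
  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
  covmat_kwargs = dict(momentum=covmat_momentum)

  # Assembles into kwargs dictionary.
  gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)
  return gp_layer_kwargs


# Usage sketch with the pre-training config above:
print(get_gp_kwargs({'covmat_momentum': .999}))
# -> {'covmat_kwargs': {'momentum': 0.999}}
```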
8 changes: 4 additions & 4 deletions baselines/jft/sngp_test.py
@@ -115,10 +115,10 @@ def get_config(classifier, representation_size):
 class SNGPTest(parameterized.TestCase, tf.test.TestCase):
 
   @parameterized.parameters(
-      ('token', 2, 1111.4404296875, 16258.519965277777, 0.16999999806284904),
-      ('token', None, 13992.8515625, 3621.3713107638887, 0.20999999344348907),
-      ('gap', 2, 8779.61328125, 3998.798285590278, 0.12999999895691872),
-      ('gap', None, 11279.3515625, 3212.2536892361113, 0.2199999988079071),
+      ('token', 2, 916.2851, 1954.3369140625, 0.16999999806284904),
+      ('token', None, 290.0307, 915.987548828125, 0.20999999344348907),
+      ('gap', 2, 695.6460, 600.8613823784722, 0.12999999895691872),
+      ('gap', None, 192.9434, 341.7078450520833, 0.2199999988079071),
   )
   def test_sngp_script(self, classifier, representation_size,
                        correct_train_loss, correct_val_loss,
