From 846e76c2254347ee6b1ba986198b7c3828b700b1 Mon Sep 17 00:00:00 2001
From: Jeremiah Liu
Date: Mon, 9 Aug 2021 13:07:37 -0700
Subject: [PATCH] Removes unnecessary ViT-GP hyper-parameters.

Following [pull #489](https://github.com/google/edward2/pull/489) to
`edward2.jax.nn.RandomFeatureGaussianProcess`, some of the special
hyper-parameter configs are no longer needed. We therefore remove them to
simplify the model API.

PiperOrigin-RevId: 389706866
---
 .../experiments/jft300m_vit_base16_sngp.py | 26 +++++-------
 baselines/jft/sngp.py                      | 42 +++++++------------
 baselines/jft/sngp_test.py                 |  8 ++--
 3 files changed, 31 insertions(+), 45 deletions(-)

diff --git a/baselines/jft/experiments/jft300m_vit_base16_sngp.py b/baselines/jft/experiments/jft300m_vit_base16_sngp.py
index 09789f607..728a3cf5c 100644
--- a/baselines/jft/experiments/jft300m_vit_base16_sngp.py
+++ b/baselines/jft/experiments/jft300m_vit_base16_sngp.py
@@ -40,8 +40,8 @@ def get_config():
   pp_common = '|value_range(-1, 1)'
   pp_common += f'|onehot({config.num_classes})'
-  # To use ancestor "smearing", use this line instead:
-  # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels")  # pylint: disable=line-too-long
+  # To use ancestor 'smearing', use this line instead:
+  # pp_common += f'|onehot({config.num_classes}, key="labels_extended", key_result="labels")'  # pylint: disable=line-too-long
   pp_common += '|keep("image", "labels")'
   config.pp_train = 'decode_jpeg_and_inception_crop(224)|flip_lr' + pp_common
   config.pp_eval = 'decode|resize_small(256)|central_crop(224)' + pp_common
@@ -49,8 +49,8 @@ def get_config():
   config.log_training_steps = 50
   config.log_eval_steps = 1000
-  # NOTE: eval is very fast O(seconds) so it's fine to run it often.
-  config.checkpoint_steps = 1000
+  # NOTE: For pre-training, save infrequently to avoid crowding disk space.
+  config.checkpoint_steps = 517790
   # Model section
   config.model = ml_collections.ConfigDict()
@@ -66,11 +66,11 @@ def get_config():
   config.model.classifier = 'token'  # Or 'gap'
   config.model.representation_size = 768
-  # GP layer parameters.
+  # Gaussian process layer parameters.
   config.gp_layer = ml_collections.ConfigDict()
-  config.gp_layer.normalize_input = True
-  config.gp_layer.random_feature_scale = 1.  # 1. or None
-  config.gp_layer.random_feature_stddev = 0.025  # 1. or 0.025
+  # Use momentum for pre-training to prevent numeric errors when inverting a
+  # precision matrix accumulated over 300M data points.
+  config.gp_layer.covmat_momentum = .999
   # Optimizer section
   config.optim_name = 'Adam'
@@ -82,7 +82,8 @@ def get_config():
   # TODO(lbeyer): make a mini-language like preprocessings.
   config.lr = ml_collections.ConfigDict()
-  config.lr.base = 8e-4  # LR has to be lower for larger models!
+  # LR has to be lower for the GP layer and for larger models.
+  config.lr.base = 4e-4
   config.lr.warmup_steps = 10_000
   config.lr.decay_type = 'linear'
   config.lr.linear_end = 1e-5
@@ -96,9 +97,4 @@ def get_config():
 def get_sweep(hyper):
-  # lr_grid = [3e-4, 4e-4, 5e-4, 6e-4]
-  # stddev_grid = [0.01, 0.02, 0.03, 0.04, 0.05]
-  return hyper.product([
-      # hyper.sweep('config.lr.base', lr_grid),
-      # hyper.sweep('config.gp_layer.random_feature_stddev', stddev_grid)
-  ])
+  return hyper.product([])
diff --git a/baselines/jft/sngp.py b/baselines/jft/sngp.py
index 32a588cc8..04afa35f5 100644
--- a/baselines/jft/sngp.py
+++ b/baselines/jft/sngp.py
@@ -27,7 +27,6 @@
 from clu import periodic_actions
 import flax
 import flax.jax_utils as flax_utils
-import flax.linen as nn
 import jax
 import jax.numpy as jnp
 import ml_collections
@@ -72,7 +71,7 @@ def accumulate_gradient_with_states(
     accum_steps):
   """Improved version of `u.accumulate_gradient()` that allows for states."""
   # This function handles the `loss_and_grad_fn` function which takes a state
-  # arguement and returns ((losses, states), grads).
+  # argument and returns ((losses, states), grads).
   if accum_steps and accum_steps > 1:
     assert images.shape[0] % accum_steps == 0, (
         f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}')
@@ -102,27 +101,16 @@ def acc_grad_and_loss(i, l_s_g):
 def get_gp_kwargs(gp_config):
-  """Extract keyword arguement parameters for the Gaussian process layer."""
-  normalize_input = gp_config.get('normalize_input', True)
-  kernel_stddev = gp_config.get('random_feature_stddev', 1.)
-  feature_scale = gp_config.get('random_feature_scale', -1.)
+  """Extract keyword argument parameters for the Gaussian process layer."""
   covmat_momentum = gp_config.get('covmat_momentum', 0.999)
-  logging.info('gp_config.normalize_input = %s', normalize_input)
-  logging.info('gp_config.random_feature_stddev = %s', kernel_stddev)
-  logging.info('gp_config.random_feature_scale = %s', feature_scale)
+  # Extracts model parameters.
   logging.info('gp_config.covmat_momentum = %s', covmat_momentum)
-
-  feature_scale = None if feature_scale < 0. else feature_scale
-  kernel_init = nn.initializers.normal(stddev=kernel_stddev)
-  hidden_kwargs = dict(feature_scale=feature_scale, kernel_init=kernel_init)
+  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
   covmat_kwargs = dict(momentum=covmat_momentum)
-  # Assemble into kwargs dictionary.
-  gp_layer_kwargs = dict(
-      normalize_input=normalize_input,
-      hidden_kwargs=hidden_kwargs,
-      covmat_kwargs=covmat_kwargs)
+  # Assembles into kwargs dictionary.
+  gp_layer_kwargs = dict(covmat_kwargs=covmat_kwargs)
   return gp_layer_kwargs
@@ -337,7 +325,7 @@ def representation_fn(params, images, labels, mask, states):
   @partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
   def update_fn(opt, states, lr, images, labels, rng):
     """Update step."""
-
+    # TODO(jereliu): Expand to allow precision matrix resetting.
     measurements = {}
     if config.get('mixup') and config.mixup.p:
@@ -423,17 +411,17 @@ def decay_fn(v, wd):
                         checkpoint['states'], checkpoint['extra'])
   elif config.get('model_init'):
-    write_note(f'Initialize model from {config.model_init}...')
-    raise ValueError(
-        'Load from `config.model_init` checkpoint is currently not supported.')
+    # Load trainable parameters from the checkpoint.
+    # This does not cause issues for SNGP since all non-trainable parameters
+    # (random features, precision matrix, etc.) are last-layer parameters that
+    # should be re-trained during fine-tuning.
+    write_note(f'Initialize trainable parameters from {config.model_init}...')
     # TODO(dusenberrymw): Replace and test load function.
-    # pylint:disable=unreachable
     loaded = resformer.load(params_cpu, config.model_init, config.get('model'))
     opt_cpu = opt_cpu.replace(target=loaded)
     if jax.host_id() == 0:
       logging.info('Restored parameter overview:')
       parameter_overview.log_parameter_overview(loaded)
-    # pylint:enable=unreachable
   write_note('Kicking off misc stuff...')
   first_step = int(opt_cpu.state.step)  # Might be a DeviceArray type.
@@ -482,6 +470,7 @@ def decay_fn(v, wd):
     mw.step_start(step)
     with jax.profiler.TraceContext('train_step', step_num=step, _r=1):
+      # TODO(jereliu): Expand to allow precision matrix resetting.
      (opt_repl, states_repl, loss_value, rngs_loop,
       extra_measurements) = update_fn(
          opt_repl,
@@ -505,8 +494,9 @@ def decay_fn(v, wd):
       # alive while they'll be updated in a future step, creating hard to debug
       # memory errors (see b/160593526). Also, takes device 0's params only.
       # We will also do the same for untrainable parameters (`states`). This is
-      # ok since both `random features` and `predictive covariance` are frozen
-      # or task-specific parameters that are not important for pre-training.
+      # ok since `random features` are frozen throughout pre-training, and
+      # `predictive covariance` is irrelevant for downstream fine-tuning and
+      # will be discarded anyway.
       opt_cpu = jax.tree_map(lambda x: np.array(x[0]), opt_repl)
       states_cpu = jax.tree_map(lambda x: np.array(x[0]), states_repl)
diff --git a/baselines/jft/sngp_test.py b/baselines/jft/sngp_test.py
index 016c05683..0da03f836 100644
--- a/baselines/jft/sngp_test.py
+++ b/baselines/jft/sngp_test.py
@@ -115,10 +115,10 @@ def get_config(classifier, representation_size):
 class SNGPTest(parameterized.TestCase, tf.test.TestCase):
   @parameterized.parameters(
-      ('token', 2, 1111.4404296875, 16258.519965277777, 0.16999999806284904),
-      ('token', None, 13992.8515625, 3621.3713107638887, 0.20999999344348907),
-      ('gap', 2, 8779.61328125, 3998.798285590278, 0.12999999895691872),
-      ('gap', None, 11279.3515625, 3212.2536892361113, 0.2199999988079071),
+      ('token', 2, 916.2851, 1954.3369140625, 0.16999999806284904),
+      ('token', None, 290.0307, 915.987548828125, 0.20999999344348907),
+      ('gap', 2, 695.6460, 600.8613823784722, 0.12999999895691872),
+      ('gap', None, 192.9434, 341.7078450520833, 0.2199999988079071),
   )
   def test_sngp_script(self, classifier, representation_size,
                        correct_train_loss, correct_val_loss,
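Note (not part of the patch): a minimal, self-contained sketch of how the slimmed-down `gp_layer` config is consumed after this change. The `get_gp_kwargs` body mirrors the patched `sngp.py` (logging omitted); the comment about `momentum=None` selecting the exact covariance update, and the forwarding of the returned kwargs to `ed.nn.RandomFeatureGaussianProcess`, are assumptions based on the commit message and edward2 defaults rather than something shown in this diff.

```python
# Sketch only: illustrates the simplified GP hyper-parameter surface.
import ml_collections


def get_gp_kwargs(gp_config):
  """Extracts keyword arguments for the Gaussian process output layer."""
  covmat_momentum = gp_config.get('covmat_momentum', 0.999)
  # A negative config value is mapped to None, which is assumed to select the
  # exact (non-moving-average) covariance update in edward2.
  covmat_momentum = None if covmat_momentum < 0. else covmat_momentum
  return dict(covmat_kwargs=dict(momentum=covmat_momentum))


gp_layer = ml_collections.ConfigDict()
gp_layer.covmat_momentum = .999  # The only GP hyper-parameter left in the config.

print(get_gp_kwargs(gp_layer))
# -> {'covmat_kwargs': {'momentum': 0.999}}
```

The removed `normalize_input`, `random_feature_scale`, and `random_feature_stddev` options now fall back to whatever `edward2.jax.nn.RandomFeatureGaussianProcess` uses by default.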