Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 389640780
  • Loading branch information
znado authored and copybara-github committed Aug 9, 2021
1 parent e11fa50 commit 6fb3245
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 14 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ stages:
jobs:
include:
- stage: lint
python: "3.6"
python: "3.7"
script:
- set -v # print commands as they're executed
- set -e # fail and exit on any command erroring
- rm -rf *.egg-info/
- pylint --jobs=2 --rcfile=pylintrc *.py
- pylint --jobs=2 --rcfile=pylintrc */
python:
- "3.6"
- "3.7"
install:
- set -v # print commands as they're executed
- set -e # fail and exit on any command erroring
Expand Down
19 changes: 18 additions & 1 deletion baselines/jft/batchensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@
import batchensemble_utils # local file import

# TODO(dusenberrymw): Open-source remaining imports.
u = None
ensemble = None
default_input_pipeline = None
jft_latest_pipeline = None
metric_writers = None
partitioning = None
train = None
experts_utils = None
xprof = None
core = None
metrics = None
ema = None
pp_builder = None
config_flags = None
xm = None
xm_api = None
BIG_VISION_DIR = None


config_flags.DEFINE_config_file(
Expand All @@ -58,7 +75,7 @@
def restore_model_and_put_to_devices(
config: ml_collections.ConfigDict,
output_dir: str,
partition_specs: Sequence[PartitionSpec],
partition_specs: Sequence[partitioning.PartitionSpec],
model: flax.nn.Module,
optimizer: flax.optim.Optimizer,
train_iter: Iterable[Any],
Expand Down
1 change: 1 addition & 0 deletions baselines/jft/batchensemble_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import jax.numpy as jnp

# TODO(dusenberrymw): Open-source remaining imports.
core = None


EvaluationOutput = Tuple[jnp.ndarray, ...]
Expand Down
10 changes: 10 additions & 0 deletions baselines/jft/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
from tensorflow.io import gfile
import uncertainty_baselines as ub

fewshot = None
input_pipeline = None
resformer = None
u = None
pp_builder = None
xm = None
xm_api = None
# TODO(dusenberrymw): Open-source remaining imports.


Expand Down Expand Up @@ -77,11 +84,14 @@ def main(argv):
# tf.data pipeline not being deterministic even if we would set TF seed.
rng = jax.random.PRNGKey(config.get('seed', 0))

xm_xp = None
xm_wu = None
def write_note(note):
if jax.host_id() == 0:
logging.info('NOTE: %s', note)
write_note('Initializing...')

fillin = lambda *_: None
# Verify settings to make sure no checkpoints are accidentally missed.
if config.get('keep_checkpoint_steps'):
assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
Expand Down
1 change: 0 additions & 1 deletion baselines/jft/deterministic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"""Tests for the deterministic ViT on JFT-300M model script."""
import os
import pathlib
import shutil
import tempfile

from absl import flags
Expand Down
48 changes: 48 additions & 0 deletions baselines/jft/experiments/common_fewshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# coding=utf-8
# Copyright 2021 The Uncertainty Baselines Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Most common few-shot eval configuration."""

import ml_collections


def get_fewshot(batch_size=None, target_resolution=224, resize_resolution=256,
runlocal=False):
"""Returns a standard-ish fewshot eval configuration."""
config = ml_collections.ConfigDict()
if batch_size:
config.batch_size = batch_size
config.representation_layer = 'pre_logits'
config.log_steps = 25_000
config.datasets = { # pylint: disable=g-long-ternary
'birds': ('caltech_birds2011', 'train', 'test'),
'caltech': ('caltech101', 'train', 'test'),
'cars': ('cars196:2.1.0', 'train', 'test'),
'cifar100': ('cifar100', 'train', 'test'),
'col_hist': ('colorectal_histology', 'train[:2000]', 'train[2000:]'),
'dtd': ('dtd', 'train', 'test'),
'imagenet': ('imagenet2012_subset/10pct', 'train', 'validation'),
'pets': ('oxford_iiit_pet', 'train', 'test'),
'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
} if not runlocal else {
'pets': ('oxford_iiit_pet', 'train', 'test'),
}
config.pp_train = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)'
config.pp_eval = f'decode|resize({resize_resolution})|central_crop({target_resolution})|value_range(-1,1)'
config.shots = [1, 5, 10, 25]
config.l2_regs = [2.0 ** i for i in range(-10, 20)]
config.walk_first = ('imagenet', 10) if not runlocal else ('pets', 10)

return config
2 changes: 1 addition & 1 deletion baselines/jft/experiments/jft300m_vit_base16.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# pylint: enable=line-too-long

import ml_collections
# TODO(dusenberrymw): Open-source remaining imports.
import get_fewshot # local file import


def get_config():
Expand Down
10 changes: 10 additions & 0 deletions baselines/jft/heteroscedastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
import uncertainty_baselines as ub

# TODO(dusenberrymw): Open-source remaining imports.
fewshot = None
input_pipeline = None
resformer = None
u = None
pp_builder = None
xm = None
xm_api = None


ml_collections.config_flags.DEFINE_config_file(
Expand Down Expand Up @@ -77,11 +84,14 @@ def main(argv):
# tf.data pipeline not being deterministic even if we would set TF seed.
rng = jax.random.PRNGKey(config.get('seed', 0))

xm_xp = None
xm_wu = None
def write_note(note):
if jax.host_id() == 0:
logging.info('NOTE: %s', note)
write_note('Initializing...')

fillin = lambda *_: None
# Verify settings to make sure no checkpoints are accidentally missed.
if config.get('keep_checkpoint_steps'):
assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
Expand Down
10 changes: 10 additions & 0 deletions baselines/jft/sngp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@
import uncertainty_baselines as ub

# TODO(dusenberrymw): Open-source remaining imports.
fewshot = None
input_pipeline = None
resformer = None
u = None
pp_builder = None
xm = None
xm_api = None


ml_collections.config_flags.DEFINE_config_file(
Expand Down Expand Up @@ -143,11 +150,14 @@ def main(argv):
# tf.data pipeline not being deterministic even if we would set TF seed.
rng = jax.random.PRNGKey(config.get('seed', 0))

xm_xp = None
xm_wu = None
def write_note(note):
if jax.host_id() == 0:
logging.info('NOTE: %s', note)
write_note('Initializing...')

fillin = lambda *_: None
# Verify settings to make sure no checkpoints are accidentally missed.
if config.get('keep_checkpoint_steps'):
assert config.get('checkpoint_steps'), 'Specify `checkpoint_steps`.'
Expand Down
1 change: 0 additions & 1 deletion baselines/jft/sngp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
"""Tests for the ViT-SNGP on JFT-300M model script."""
import os
import pathlib
import shutil
import tempfile

from absl import flags
Expand Down
10 changes: 2 additions & 8 deletions uncertainty_baselines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,5 @@ def _lazy_import(name):
return imported


for module_name in _IMPORTS:
try:
_lazy_import(module_name)
except ModuleNotFoundError:
logging.error(
'Skipped importing top level uncertainty_baselines module %s due to '
'ModuleNotFoundError:', module_name, exc_info=True)

# Lazily load any top level modules when accessed. Requires Python 3.7.
__getattr__ = _lazy_import

0 comments on commit 6fb3245

Please sign in to comment.