Commit f7e48678 authored by Janne Hellsten's avatar Janne Hellsten
Browse files

Add --allow-tf32 perf tuning argument that can be used to enable tf32

Defaults to keeping tf32 disabled.  This is because we haven't fully
verified training results with tf32 enabled.
parent d3a616a9
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -65,5 +65,6 @@ Options:
  --fp32 BOOL                     Disable mixed-precision training
  --nhwc BOOL                     Use NHWC memory format with FP16
  --nobench BOOL                  Disable cuDNN benchmarking
  --allow-tf32 BOOL               Allow PyTorch to use TF32 internally
  --workers INT                   Override number of DataLoader workers
  --help                          Show this message and exit.
+8 −0
Original line number Diff line number Diff line
@@ -61,6 +61,7 @@ def setup_training_loop_kwargs(
    # Performance options (not included in desc).
    fp32       = None, # Disable mixed-precision training: <bool>, default = False
    nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
    nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
    workers    = None, # Override number of DataLoader workers: <int>, default = 3
):
@@ -343,6 +344,12 @@ def setup_training_loop_kwargs(
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
@@ -425,6 +432,7 @@ class CommaSeparatedList(click.ParamType):
@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')

def main(ctx, outdir, dry_run, **config_kwargs):
+3 −0
Original line number Diff line number Diff line
@@ -115,6 +115,7 @@ def training_loop(
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
    resume_pkl              = None,     # Network pickle to resume training from.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
):
@@ -124,6 +125,8 @@ def training_loop(
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.