Loading metrics/metric_utils.py +6 −1 Original line number Diff line number Diff line Loading @@ -213,6 +213,8 @@ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_l # Main loop. item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): if images.shape[1] == 1: images = images.repeat([1, 3, 1, 1]) features = detector(images.to(opts.device), **detector_kwargs) stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) progress.update(stats.num_items) Loading Loading @@ -262,7 +264,10 @@ def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)] c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) images.append(run_generator(z, c)) features = detector(torch.cat(images), **detector_kwargs) images = torch.cat(images) if images.shape[1] == 1: images = images.repeat([1, 3, 1, 1]) features = detector(images, **detector_kwargs) stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) progress.update(stats.num_items) return stats Loading Loading
metrics/metric_utils.py +6 −1 Original line number Diff line number Diff line Loading @@ -213,6 +213,8 @@ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_l # Main loop. item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs): if images.shape[1] == 1: images = images.repeat([1, 3, 1, 1]) features = detector(images.to(opts.device), **detector_kwargs) stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) progress.update(stats.num_items) Loading Loading @@ -262,7 +264,10 @@ def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)] c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) images.append(run_generator(z, c)) features = detector(torch.cat(images), **detector_kwargs) images = torch.cat(images) if images.shape[1] == 1: images = images.repeat([1, 3, 1, 1]) features = detector(images, **detector_kwargs) stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank) progress.update(stats.num_items) return stats Loading