# ------------------------------------------
# General options: gpus, snap, metrics, seed
# ------------------------------------------

# --gpus: defaults to 1 when not given; must be a positive power of two.
if gpus is None:
    gpus = 1
# Internal sanity check on the type (not user-facing validation; stripped under -O).
assert isinstance(gpus, int)
# For a positive int, (gpus & (gpus - 1)) == 0 holds exactly when only one
# bit is set, i.e. gpus is a power of two. The parentheses are redundant in
# Python (& binds tighter than ==) but make the intent explicit.
if not (gpus >= 1 and (gpus & (gpus - 1)) == 0):
    raise UserError('--gpus must be a power of two')
args.num_gpus = gpus
# --snap: snapshot interval in ticks; defaults to 50 and must be at least 1.
if snap is None:
    snap = 50
# Internal sanity check on the type (not user-facing validation; stripped under -O).
assert isinstance(snap, int)
if snap < 1:
    raise UserError('--snap must be at least 1')
# The same interval drives both image and network snapshots.
args.image_snapshot_ticks = snap
args.network_snapshot_ticks = snap
# --metrics: list of metric names; defaults to ['fid50k_full'], and every
# entry must be accepted by metric_main.is_valid_metric.
if metrics is None:
    metrics = ['fid50k_full']
# Internal sanity check on the type (not user-facing validation; stripped under -O).
assert isinstance(metrics, list)
if not all(metric_main.is_valid_metric(metric) for metric in metrics):
    # The error message lists every valid metric name, one per line.
    raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
args.metrics = metrics
# Load training set.
# Only the rank-0 process prints, so a multi-process run reports once.
if rank == 0:
    print('Loading training set...')
# Build the dataset object from its kwargs dict (see construct_class_by_name).
training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
# The sampler is told this process's rank and the total replica count, so each
# process presumably draws its own share of the data indefinitely — confirm in misc.InfiniteSampler.
training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
# Each DataLoader batch holds batch_size // num_gpus samples.
# NOTE(review): integer division silently drops any remainder; this relies on
# batch_size being divisible by num_gpus — confirm it is checked upstream.
training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
if rank == 0:
    print()
    print('Num images: ', len(training_set))
    print('Image shape:', training_set.image_shape)
    print('Label shape:', training_set.label_shape)
    print()
其中用到的 `dnnlib.util.construct_class_by_name` 函数定义如下：
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments.

    This is a thin wrapper around ``call_func_by_name``. Calling a class
    object constructs an instance, so the class is looked up and invoked the
    same way a function would be.

    Args:
        *args: Positional arguments forwarded unchanged to ``call_func_by_name``
            (presumably on to the class constructor).
        class_name: Name of the class to construct, forwarded as ``func_name``.
            Presumably a dotted ``'module.ClassName'`` path — confirm against
            ``call_func_by_name``. Although annotated ``str``, it defaults to None.
        **kwargs: Keyword arguments forwarded unchanged to ``call_func_by_name``.

    Returns:
        Whatever ``call_func_by_name`` returns, i.e. the constructed object.
    """
    return call_func_by_name(*args, func_name=class_name, **kwargs)
它用于查找具有给定名称的 Python 类，并用给定的参数构造该类的实例。对应下面这行代码：
# Instantiate the dataset from training_set_kwargs. construct_class_by_name only
# receives these unpacked kwargs, so its keyword-only class_name argument
# presumably comes from a 'class_name' entry in that dict — verify.
training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
结合这里的注释可知，得到的 `training_set` 是 `training.dataset.Dataset` 某个子类的实例（注释中的 "subclass of training.dataset.Dataset" 指的是它的类型）。其中 `training_set_kwargs` 中包含：