在这里对StyleGAN 2 ADA的源码进行一个粗略的解读,主要目的是获得ADA的使用方法,顺便学习一下代码风格。

EasyDict

# Container for all training options; readable both as args.key and args['key'].
args = dnnlib.EasyDict()

其中,dnnlib为自定义包,EasyDict是其中定义的一个类。

class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax.

    Every attribute operation is forwarded to the dict items, e.g. with
    ``d = EasyDict(a=1)``: ``d.a`` reads ``d['a']``, ``d.b = 2`` writes
    ``d['b']`` and ``del d.a`` removes ``d['a']``.
    """

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails; fall back to the dict item.
        try:
            return self[name]
        except KeyError:
            # Translate KeyError into AttributeError so hasattr() and
            # getattr(obj, name, default) behave as for ordinary objects.
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        # Store as a dict item (not in __dict__): d.x = v is the same as d['x'] = v.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # del d.x removes the item; note a missing key raises KeyError here.
        del self[name]

这里__getattr__是根据属性名(即字典的key)返回对应值的函数(key不存在时抛出AttributeError),__setattr__是新增或修改键值对的函数,__delattr__是删除键值对的函数。

class Example:
    """Toy class showing how __getattr__ can expose dict entries as attributes."""

    def __init__(self):
        self.data = {'a': 1, 'b': 2, 'c': 3}

    def __getattr__(self, name):
        # __getattr__ only runs when normal attribute lookup fails. Read `data`
        # through __dict__ rather than `self.data`: on an instance created
        # without __init__ (e.g. by copy.copy, copy.deepcopy or pickle),
        # `self.data` would itself call __getattr__('data') and recurse until
        # RecursionError.
        data = self.__dict__.get('data', {})
        if name in data:
            return data[name]
        raise AttributeError(f"'Example' object has no attribute '{name}'")


# Create an Example object
obj = Example()

# Access attributes that exist as keys of obj.data
print(obj.a)  # prints: 1
print(obj.b)  # prints: 2

# Accessing a missing attribute raises AttributeError; catch it so the demo runs to the end
try:
    print(obj.x)
except AttributeError as err:
    print(err)  # prints: 'Example' object has no attribute 'x'

定义了__getattr__之后,就可以通过obj.xxx这样的属性语法来读取字典中的值了。注意,__getattr__只有在常规属性查找失败时才会被调用,所以像obj.data这样真实存在的属性仍按正常方式访问。

General options: gpus, snap, metrics, seed

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------
    # Excerpted from the body of setup_training_loop_kwargs(): each option falls
    # back to a default when None, is validated, then stored on the EasyDict `args`.
    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    # For gpus >= 1, gpus & (gpus - 1) == 0 holds exactly when gpus is a power of two.
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    # One interval (in ticks) drives both image and network snapshots.
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

这里是判断函数的输入是否满足要求(为None时使用默认值),并存入args中。因为args是dnnlib.EasyDict()的实例,而EasyDict中定义了__setattr__方法,所以通过args.num_gpus = gpus这样的属性赋值,实际上就是往字典中写入键值对args['num_gpus'] = gpus。

从下面这段代码可以看到,setup_training_loop_kwargs整个函数返回的是描述信息desc与存储参数的args:

# Abridged signature: the real parameter list and body are elided here.
# The function returns the description `desc` and the option dict `args`.
def setup_training_loop_kwargs(...)
return desc, args

main函数

最终main函数中调用了subprocess_fn:

def subprocess_fn(rank, args, temp_dir):
    """Per-process entry point: set up logging and distributed state, then train.

    Args:
        rank: index of this process; also used as the CUDA device index when
            more than one GPU is used.
        args: EasyDict of training options. Must contain ``run_dir`` and
            ``num_gpus``; all items are forwarded as keyword arguments to
            ``training_loop.training_loop``.
        temp_dir: directory holding the file used as the torch.distributed
            rendezvous point when ``args.num_gpus > 1``.
    """
    # NOTE(review): the Logger object is not kept; presumably its constructor
    # installs itself globally (e.g. tees stdout to log.txt) — confirm in dnnlib.util.
    dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Init torch.distributed.
    if args.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            # Windows: build a file:/// URL with forward slashes and use the gloo backend.
            init_method = 'file:///' + init_file.replace('\\', '/')
            torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
        else:
            init_method = f'file://{init_file}'
            torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        # Turn off custom_ops output on every process except rank 0.
        custom_ops.verbosity = 'none'

    # Execute training loop.
    training_loop.training_loop(rank=rank, **args)

可以看到,最后把储存参数的字典args通过**args解包成关键字参数,传入training_loop.training_loop中。

training_loop解读

下面解读training_loop的内容,首先是载入training set

Load training set

# Load training set.
if rank == 0:
    print('Loading training set...')
# Instantiate the dataset class named by training_set_kwargs['class_name'].
training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
# NOTE(review): misc.InfiniteSampler is not shown here; presumably it yields
# indices without end and gives each rank its own share of the dataset.
training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
# batch_size is the total over all GPUs; each process loads batch_size // num_gpus samples per batch.
training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
if rank == 0:
    print()
    print('Num images: ', len(training_set))
    print('Image shape:', training_set.image_shape)
    print('Label shape:', training_set.label_shape)
    print()

其中的dnnlib.util.construct_class_by_name方法定义如下:

def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments.

    Thin wrapper: calling a class constructs an instance, so this delegates to
    call_func_by_name() with ``func_name=class_name`` and forwards all other
    positional and keyword arguments unchanged.
    """
    return call_func_by_name(*args, func_name=class_name, **kwargs)

用于查找具有给定名称的python类,并用给定参数构造它。它对应于下面这行代码:

# Instantiate the class named by training_set_kwargs['class_name'] with the remaining kwargs.
training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) 
# subclass of training.dataset.Dataset (here training.dataset.ImageFolderDataset, per class_name)

这里的注释说明,得到的training_set是training.dataset.Dataset的子类的实例(subclass of training.dataset.Dataset)。
其中training_set_kwargs的定义如下:

# class_name selects which class construct_class_by_name() builds; the other
# keys (path, use_labels, max_size, xflip) are passed on as its constructor arguments.
args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)

class Dataset(torch.utils.data.Dataset):
    """Base dataset class of training.dataset (body elided in this excerpt)."""
    ...


class ImageFolderDataset(Dataset):
    """Dataset built for class_name='training.dataset.ImageFolderDataset' (body elided)."""
    ...

可以看出,由于class_name='training.dataset.ImageFolderDataset',training_set是ImageFolderDataset类的实例,而ImageFolderDataset继承自Dataset,符合注释中的“subclass of training.dataset.Dataset”。

Setup augmentation

# Setup augmentation.
if rank == 0:
    print('Setting up augmentation...')
augment_pipe = None
ada_stats = None
# Build the pipe only if augmentation is configured AND either a fixed
# probability (augment_p > 0) or an adaptive target (ada_target) is requested.
if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
    augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    # Write the initial probability in place into the existing tensor augment_pipe.p.
    augment_pipe.p.copy_(torch.as_tensor(augment_p))
    if ada_target is not None:
        # NOTE(review): presumably these collected 'Loss/signs/real' stats drive
        # the adaptive update of p toward ada_target later in training_loop — confirm.
        ada_stats = training_stats.Collector(regex='Loss/signs/real')
# class_name selects training.augment.AugmentPipe; the remaining keyword
# arguments are taken from augpipe_specs[augpipe].
args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])

这里说明augment_pipe是training.augment.AugmentPipe类的实例(由augment_kwargs中的class_name决定)。

class AugmentPipe(torch.nn.Module):
    """Augmentation pipeline module (excerpt: only the constructor signature is shown).

    Its keyword arguments come from ``augpipe_specs[augpipe]`` via ``args.augment_kwargs``.
    NOTE(review): the un-suffixed arguments (xflip, rotate90, xint, ..., noise, cutout)
    all default to 0 and presumably weight how often each augmentation is applied,
    while the *_std / *_max / *_size arguments presumably set their magnitudes —
    confirm in training/augment.py.
    """
    def __init__(self,
        xflip=0, rotate90=0, xint=0, xint_max=0.125,
        scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
        imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
        noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
    ):
        super().__init__()
        # NOTE(review): imgfilter_bands defaults to a list shared by all calls;
        # the (elided) body must copy it rather than mutate it in place.
        ...  # remainder of __init__ and of the class elided in this excerpt
import torch.nn as nn
import torch


class net(nn.Module):
    """Minimal module demonstrating register_buffer."""

    def __init__(self):
        super().__init__()
        # Register a (2, 3) tensor of ones as buffer "a": from now on self.a refers to it.
        # A buffer belongs to the module's state (it appears in state_dict()) but is
        # not a Parameter, so it is not returned by parameters() and not trained.
        self.register_buffer("a", torch.ones(2, 3))

    def forward(self, x):
        # Use the buffer like any other tensor attribute.
        return x + self.a