from torch import nn


class Generator(nn.Module):
    def __init__(self, z_dim):
        super().__init__()
        self.z_dim = z_dim
        net = []

        # Five ConvTranspose2d blocks: the first maps the (z_dim, 1, 1)
        # latent vector to a 4x4 feature map; each stride-2 block that
        # follows doubles the spatial size, ending at a 3 x 64 x 64 image.
        channels_in = [self.z_dim, 512, 256, 128, 64]
        channels_out = [512, 256, 128, 64, 3]
        active = ["R", "R", "R", "R", "tanh"]
        stride = [1, 2, 2, 2, 2]
        padding = [0, 1, 1, 1, 1]

        for i in range(len(channels_in)):
            net.append(nn.ConvTranspose2d(in_channels=channels_in[i],
                                          out_channels=channels_out[i],
                                          kernel_size=4, stride=stride[i],
                                          padding=padding[i], bias=False))
            if active[i] == "R":
                net.append(nn.BatchNorm2d(num_features=channels_out[i]))
                net.append(nn.ReLU())
            elif active[i] == "tanh":
                # Tanh on the last layer keeps output pixels in [-1, 1].
                net.append(nn.Tanh())

        self.generator = nn.Sequential(*net)
        self.weight_init()

    def weight_init(self):
        # DCGAN-style initialization: weights drawn from N(0, 0.02).
        for m in self.generator.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight.data, 0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 0, 0.02)
                nn.init.constant_(m.bias.data, 0)

    def forward(self, x):
        out = self.generator(x)
        return out
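
# A small helper sketch (not part of the original listing): each
# ConvTranspose2d above with kernel_size=4 computes
#     H_out = (H_in - 1) * stride - 2 * padding + kernel_size,
# so the latent 1x1 map grows 1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels
# across the five generator blocks.
def conv_transpose_out_size(h_in, stride, padding, kernel_size=4):
    return (h_in - 1) * stride - 2 * padding + kernel_size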

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        net = []

        # Five stride-2 Conv2d blocks halve the spatial size each time,
        # taking a 3 x 64 x 64 image down to a single 1x1 score.
        channels_in = [3, 64, 128, 256, 512]
        channels_out = [64, 128, 256, 512, 1]
        padding = [1, 1, 1, 1, 0]
        active = ["LR", "LR", "LR", "LR", "sigmoid"]

        for i in range(len(channels_in)):
            net.append(nn.Conv2d(in_channels=channels_in[i],
                                 out_channels=channels_out[i],
                                 kernel_size=4, stride=2,
                                 padding=padding[i], bias=False))
            if i == 0:
                # No batch norm on the first layer, as in DCGAN.
                net.append(nn.LeakyReLU(0.2))
            elif active[i] == "LR":
                net.append(nn.BatchNorm2d(num_features=channels_out[i]))
                net.append(nn.LeakyReLU(0.2))
            elif active[i] == "sigmoid":
                net.append(nn.Sigmoid())

        self.discriminator = nn.Sequential(*net)
        self.weight_init()

    def weight_init(self):
        # Initialize the Conv2d and BatchNorm2d weights with N(0, 0.02).
        for m in self.discriminator.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, 0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 0, 0.02)
                nn.init.constant_(m.bias.data, 0)

    def forward(self, x):
        out = self.discriminator(x)
        # Flatten the (N, 1, 1, 1) score map to (N, 1).
        out = out.view(x.size(0), -1)
        return out
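
# A minimal sanity check (illustrative, not from the original listing;
# z_dim=100 follows the DCGAN paper and the batch size of 16 is arbitrary).
# The generator maps (N, z_dim, 1, 1) noise to 64x64 RGB images, and the
# discriminator maps images to one probability per sample.
if __name__ == "__main__":
    import torch

    G = Generator(z_dim=100)
    D = Discriminator()

    z = torch.randn(16, 100, 1, 1)   # batch of 16 latent vectors
    fake = G(z)                      # -> torch.Size([16, 3, 64, 64])
    scores = D(fake)                 # -> torch.Size([16, 1]), values in (0, 1)
    print(fake.shape, scores.shape)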