# ---- Helper functions for the generator / discriminator ----
def extract(v):
    """Return the values stored in tensor/Variable *v* as a flat Python list."""
    # .data.storage() exposes the underlying flat buffer regardless of shape.
    return v.data.storage().tolist()
83
def stats(d):
    """Return ``[mean, std]`` of the sequence *d* (population std, numpy default)."""
    return [np.mean(d), np.std(d)]
86
def get_moments(d):
    """Return the first 4 moments of the data provided.

    Given a 1-D tensor *d*, computes mean, standard deviation, skewness and
    excess kurtosis, concatenated into a 4-element tensor.
    """
    mean = torch.mean(d)  # mean of the (generated) distribution
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)  # standard deviation of the distribution
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    # One vector with four elements: [mean, std, skew, excess kurtosis].
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final
98
def decorate_with_diffs(data, exponent, remove_raw_data=False):
    """Augment *data* with per-element deviations from the mean, raised to *exponent*.

    Args:
        data: 2-D tensor/Variable of samples (rows are minibatch entries).
        exponent: power applied to each (value - mean) deviation.
        remove_raw_data: if True, return only the deviation features;
            otherwise concatenate them after the raw data along dim 1.
    """
    mean = torch.mean(data.data, 1, keepdim=True)
    # NOTE(review): broadcasts only the FIRST row's mean over the whole batch
    # (mean.tolist()[0][0]) — preserved from the original; confirm intent.
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)
107
108def train():
109 # Model parameters
110 g_input_size = 1 # Random noise dimension coming into generator, per output vector
111 g_hidden_size = 5 # Generator complexity
112 g_output_size = 1 # Size of generated output vector
113 d_input_size = 500 # Minibatch size - cardinality of distributions
114 d_hidden_size = 10 # Discriminator complexity
115 d_output_size = 1 # Single dimension for 'real' vs. 'fake' classification
116 minibatch_size = d_input_size
117
118 d_learning_rate = 1e-3
119 g_learning_rate = 1e-3
120 sgd_momentum = 0.9
121
122 num_epochs = 5000
123 print_interval = 100
124 d_steps = 20
125 g_steps = 20
126
127 dfe, dre, ge = 0, 0, 0
128 d_real_data, d_fake_data, g_fake_data = None, None, None
129
130 discriminator_activation_function = torch.sigmoid
131 generator_activation_function = torch.tanh
132
133 d_sampler = get_distribution_sampler(data_mean, data_stddev)
134 gi_sampler = get_generator_input_sampler()
135 G = Generator(input_size=g_input_size,