00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
#include "gnn_momentum.h"

#include <assert.h>
#include <stdlib.h>
00052
00053
00054
00055
00056
00057
00058
00059 typedef struct _gnn_momentum gnn_momentum;
00060
00061 struct _gnn_momentum
00062 {
00063 gnn_trainer trainer;
00064 gsl_vector *w;
00065 gsl_vector *mw;
00066 double mu;
00067 double eta;
00068 };
00069
00070
00071 static int
00072 gnn_momentum_train (gnn_trainer *trainer);
00073
00074 static void
00075 gnn_momentum_destroy (gnn_trainer *trainer);
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090 static int
00091 gnn_momentum_reset (gnn_trainer *trainer)
00092 {
00093 gnn_momentum *mtrainer;
00094
00095 assert (trainer != NULL);
00096
00097 mtrainer = (gnn_momentum *) trainer;
00098
00099
00100 gnn_node_param_get (trainer->node, mtrainer->w);
00101
00102
00103 gsl_vector_set_zero (mtrainer->mw);
00104
00105 return 0;
00106 }
00107
00108
00109
00110
00111
00112
00113
00114
00115 static int
00116 gnn_momentum_train (gnn_trainer *trainer)
00117 {
00118 gsl_vector *dw;
00119 gnn_momentum *mtrainer;
00120
00121
00122 mtrainer = (gnn_momentum *) trainer;
00123
00124
00125 gnn_trainer_batch_process (trainer);
00126
00127
00128 dw = gnn_trainer_batch_get_dw (trainer);
00129
00130
00131 gnn_trainer_batch_next (trainer);
00132
00133
00134 gsl_vector_scale (dw, - mtrainer->mu);
00135
00136
00137 gsl_vector_scale (mtrainer->mw, mtrainer->eta);
00138 gsl_vector_add (mtrainer->mw, dw);
00139
00140
00141 gsl_vector_add (mtrainer->w, mtrainer->mw);
00142
00143
00144 gnn_node_param_set (trainer->node, mtrainer->w);
00145
00146 return 0;
00147 }
00148
00149
00150
00151
00152
00153
00154
00155 static void
00156 gnn_momentum_destroy (gnn_trainer *trainer)
00157 {
00158 gnn_momentum *mtrainer;
00159
00160 assert (trainer != NULL);
00161
00162 mtrainer = (gnn_momentum *) trainer;
00163
00164 if (mtrainer->w != NULL)
00165 gsl_vector_free (mtrainer->w);
00166 if (mtrainer->mw != NULL)
00167 gsl_vector_free (mtrainer->mw);
00168
00169 return;
00170 }
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194 gnn_trainer *
00195 gnn_momentum_new (gnn_node *node,
00196 gnn_criterion *crit,
00197 gnn_dataset *data,
00198 double mu,
00199 double eta)
00200 {
00201 int status;
00202 gnn_trainer *trainer;
00203 gnn_momentum *mtrainer;
00204
00205
00206 if (mu <= 0.0)
00207 {
00208 GSL_ERROR_VAL ("learning factor should be stricly positive",
00209 GSL_EINVAL, NULL);
00210 }
00211
00212 if (eta < 0.0)
00213 {
00214 GSL_ERROR_VAL ("momentum factor should be positive",
00215 GSL_EINVAL, NULL);
00216 }
00217
00218
00219 mtrainer = (gnn_momentum *) malloc (sizeof (gnn_momentum));
00220 if (mtrainer == NULL)
00221 {
00222 GSL_ERROR_VAL ("couldn't allocate memory for gnn_momentum",
00223 GSL_ENOMEM, NULL);
00224 }
00225
00226
00227 trainer = (gnn_trainer *) mtrainer;
00228
00229
00230 status = gnn_trainer_init (trainer,
00231 "gnn_momentum",
00232 node,
00233 crit,
00234 data,
00235 gnn_momentum_reset,
00236 gnn_momentum_train,
00237 gnn_momentum_destroy);
00238 if (status)
00239 {
00240 GSL_ERROR_VAL ("couldn't initialize gnn_momentum",
00241 GSL_EFAILED, NULL);
00242 }
00243
00244
00245 mtrainer->mu = mu;
00246 mtrainer->eta = eta;
00247 mtrainer->w = gsl_vector_alloc (gnn_node_param_get_size (node));
00248 mtrainer->mw = gsl_vector_alloc (gnn_node_param_get_size (node));
00249 if (mtrainer->w == NULL || mtrainer->mw == NULL)
00250 {
00251 gnn_trainer_destroy (trainer);
00252 GSL_ERROR_VAL ("couldn't allocate memory for gnn_momentum",
00253 GSL_ENOMEM, NULL);
00254 }
00255
00256 return trainer;
00257 }
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269 double
00270 gnn_momentum_get_mu (gnn_trainer *trainer)
00271 {
00272 gnn_momentum *mtrainer;
00273
00274 assert (trainer != NULL);
00275
00276 mtrainer = (gnn_momentum *) trainer;
00277 return mtrainer->mu;
00278 }
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292 int
00293 gnn_momentum_set_mu (gnn_trainer *trainer, double mu)
00294 {
00295 gnn_momentum *mtrainer;
00296
00297 assert (trainer != NULL);
00298
00299
00300 if (mu <= 0.0)
00301 GSL_ERROR ("learning factor should be stricly positive", GSL_EINVAL);
00302
00303
00304 mtrainer = (gnn_momentum *) trainer;
00305
00306
00307 mtrainer->mu = mu;
00308
00309 return 0;
00310 }
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322 double
00323 gnn_momentum_get_eta (gnn_trainer *trainer)
00324 {
00325 gnn_momentum *mtrainer;
00326
00327 assert (trainer != NULL);
00328
00329 mtrainer = (gnn_momentum *) trainer;
00330 return mtrainer->eta;
00331 }
00332
00333
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345 int
00346 gnn_momentum_set_eta (gnn_trainer *trainer, double eta)
00347 {
00348 gnn_momentum *mtrainer;
00349
00350 assert (trainer != NULL);
00351
00352
00353 if (eta < 0.0)
00354 GSL_ERROR ("momentum factor should be positive", GSL_EINVAL);
00355
00356
00357 mtrainer = (gnn_momentum *) trainer;
00358
00359
00360 mtrainer->eta = eta;
00361
00362 return 0;
00363 }
00364
00365
00366
00367