00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056 #include "gnn_divergence.h"
00057
00058
00059
00060
00061
00062
00063
00064 static int
00065 gnn_divergence_f (gnn_node *node,
00066 const gsl_vector *x,
00067 const gsl_vector *w,
00068 gsl_vector *y);
00069
00070 static int
00071 gnn_divergence_dx (gnn_node *node,
00072 const gsl_vector *x,
00073 const gsl_vector *w,
00074 const gsl_vector *dy,
00075 gsl_vector *dx);
00076
00077 static int
00078 gnn_divergence_dw (gnn_node *node,
00079 const gsl_vector *x,
00080 const gsl_vector *w,
00081 const gsl_vector *dy,
00082 gsl_vector *dw);
00083
00084 static void
00085 gnn_divergence_destroy (gnn_node *node);
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103 static int
00104 gnn_divergence_f (gnn_node *node,
00105 const gsl_vector *x,
00106 const gsl_vector *w,
00107 gsl_vector *y)
00108 {
00109 size_t n;
00110 size_t m;
00111 size_t k;
00112 gnn_divergence *dnode;
00113
00114
00115 dnode = (gnn_divergence *) node;
00116
00117
00118 n = gnn_node_input_get_size (node);
00119 m = gnn_node_output_get_size (node);
00120
00121
00122 for (k=0; k<m; ++k)
00123 {
00124 size_t i;
00125 double yk;
00126
00127 i = k % n;
00128 yk = gsl_vector_get (x, i);
00129 gsl_vector_set (y, k, yk);
00130 }
00131
00132 return 0;
00133 }
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146 static int
00147 gnn_divergence_dx (gnn_node *node,
00148 const gsl_vector *x,
00149 const gsl_vector *w,
00150 const gsl_vector *dy,
00151 gsl_vector *dx)
00152 {
00153 size_t n;
00154 size_t m;
00155 size_t k;
00156 gnn_divergence *dnode;
00157
00158
00159 dnode = (gnn_divergence *) node;
00160
00161
00162 n = gnn_node_input_get_size (node);
00163 m = gnn_node_output_get_size (node);
00164
00165
00166 gsl_vector_set_zero (dx);
00167 for (k=0; k<m; ++k)
00168 {
00169 size_t i;
00170 double dxi;
00171
00172 i = k % n;
00173 dxi = gsl_vector_get (dx, i);
00174 dxi += gsl_vector_get (dy, k);
00175 gsl_vector_set (dx, i, dxi);
00176 }
00177
00178 return 0;
00179 }
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192 static int
00193 gnn_divergence_dw (gnn_node *node,
00194 const gsl_vector *x,
00195 const gsl_vector *w,
00196 const gsl_vector *dy,
00197 gsl_vector *dw)
00198 {
00199 return 0;
00200 }
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218 gnn_node *
00219 gnn_divergence_new (size_t input_size, size_t output_size)
00220 {
00221 int status;
00222 size_t k;
00223 gnn_node *node;
00224 gnn_divergence *dnode;
00225
00226
00227 if (input_size < 1)
00228 {
00229 GSL_ERROR_VAL ("input size should be strictly positive",
00230 GSL_EINVAL, NULL);
00231 }
00232 if (input_size > output_size)
00233 {
00234 GSL_ERROR_VAL ("output size should be greater than the input size",
00235 GSL_EINVAL, NULL);
00236 }
00237
00238
00239 dnode = (gnn_divergence *) malloc (sizeof (gnn_divergence));
00240 if (dnode == NULL)
00241 {
00242 GSL_ERROR_VAL ("could not allocate memory for gnn_divergence node",
00243 GSL_ENOMEM, NULL);
00244 }
00245
00246
00247 node = (gnn_node *) dnode;
00248
00249
00250 status = gnn_node_init (node,
00251 "gnn_divergence",
00252 gnn_divergence_f,
00253 gnn_divergence_dx,
00254 gnn_divergence_dw,
00255 NULL);
00256 if (status)
00257 {
00258 GSL_ERROR_VAL ("could not initialize gnn_divergence node",
00259 GSL_EFAILED, NULL);
00260 }
00261
00262 status = gnn_node_set_sizes (node, input_size, output_size, 0);
00263 if (status)
00264 {
00265 GSL_ERROR_VAL ("could not set sizes for gnn_divergence node",
00266 GSL_EFAILED, NULL);
00267 }
00268
00269
00270 dnode->rep = output_size / input_size;
00271 dnode->rep += (output_size % input_size == 0)? 0 : 1;
00272
00273 return node;
00274 }
00275
00276