00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070 #include <math.h>
00071 #include <gsl/gsl_blas.h>
00072 #include <gsl/gsl_randist.h>
00073 #include "gnn_utilities.h"
00074 #include "gnn_weight.h"
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094 typedef struct _gnn_fir gnn_fir;
00095
00096 struct _gnn_fir
00097 {
00098 gnn_node node;
00099
00100 size_t delay;
00101 size_t offset;
00102
00103 gsl_matrix_view *W_view;
00104 gsl_matrix_view *dW_view;
00105 gsl_matrix_int_view *Wf_view;
00106
00107 gsl_matrix **W;
00108 gsl_matrix **dW;
00109 gsl_matrix_int **Wf;
00110
00111 gsl_matrix *xb;
00112 gsl_vector_view *X_view;
00113 gsl_vector **X;
00114
00115 };
00116
00117
00118 static int
00119 gnn_fir_f (gnn_node *node,
00120 const gsl_vector *x,
00121 const gsl_vector *w,
00122 gsl_vector *y);
00123
00124
00125 static int
00126 gnn_fir_dx (gnn_node *node,
00127 const gsl_vector *x,
00128 const gsl_vector *w,
00129 const gsl_vector *dy,
00130 gsl_vector *dx);
00131
00132 static int
00133 gnn_fir_dw (gnn_node *node,
00134 const gsl_vector *x,
00135 const gsl_vector *w,
00136 const gsl_vector *dy,
00137 gsl_vector *dw);
00138
00139
00140 static void
00141 gnn_fir_destroy (gnn_node *node);
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161 static int
00162 gnn_fir_f (gnn_node *node,
00163 const gsl_vector *x,
00164 const gsl_vector *w,
00165 gsl_vector *y)
00166 {
00167 size_t j;
00168 size_t k;
00169 gnn_fir *fnode;
00170
00171 fnode = (gnn_fir *) node;
00172
00173
00174 fnode->offset = (fnode->offset > 0)? fnode->offset - 1 : fnode->delay;
00175
00176
00177 gsl_vector_memcpy (fnode->X[fnode->offset], x);
00178
00179
00180 for (k=0; k<=fnode->delay; ++k)
00181 {
00182 j = (fnode->offset + k) % (fnode->delay + 1);
00183 gsl_blas_dgemv (CblasNoTrans, 1.0, fnode->W[k], fnode->X[j], 0.0, y);
00184 }
00185
00186 return 0;
00187 }
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201 static int
00202 gnn_fir_dx (gnn_node *node,
00203 const gsl_vector *x,
00204 const gsl_vector *w,
00205 const gsl_vector *dy,
00206 gsl_vector *dx)
00207 {
00208 size_t j;
00209 size_t k;
00210 gnn_fir *fnode;
00211
00212 fnode = (gnn_fir *) node;
00213
00214
00215 gsl_blas_dgemv (CblasTrans, 1.0, fnode->W[0], dy, 0.0, dx);
00216
00217 return 0;
00218 }
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233 static int
00234 gnn_fir_dw (gnn_node *node,
00235 const gsl_vector *x,
00236 const gsl_vector *w,
00237 const gsl_vector *dy,
00238 gsl_vector *dw)
00239 {
00240 size_t j;
00241 size_t k;
00242 size_t input_size;
00243 size_t output_size;
00244 gsl_matrix_view X;
00245 gsl_matrix_view DY;
00246 gnn_fir *fnode;
00247
00248 fnode = (gnn_fir *) node;
00249
00250
00251 input_size = gnn_node_input_get_size (node);
00252 output_size = gnn_node_output_get_size (node);
00253
00254
00255 DY = gsl_matrix_view_vector ((gsl_vector *) dy, output_size, 1);
00256
00257
00258 for (k=0; k<=fnode->delay; ++k)
00259 {
00260 j = (fnode->offset + k) % (fnode->delay + 1);
00261
00262
00263 X = gsl_matrix_view_vector ((gsl_vector *) fnode->X[j], 1, input_size);
00264
00265
00266 gsl_blas_dgemm (CblasNoTrans, CblasNoTrans,
00267 1.0, &(DY.matrix), &(X.matrix), 1.0, fnode->dW[k]);
00268 }
00269
00270 return 0;
00271 }
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282 static void
00283 gnn_fir_destroy (gnn_node *node)
00284 {
00285 gnn_fir *fnode;
00286 assert (node != NULL);
00287
00288 fnode = (gnn_fir *) node;
00289
00290 if (fnode->W_view != NULL) free (fnode->W_view);
00291 if (fnode->dW_view != NULL) free (fnode->dW_view);
00292 if (fnode->Wf_view != NULL) free (fnode->Wf_view);
00293 if (fnode->W != NULL) free (fnode->W);
00294 if (fnode->dW != NULL) free (fnode->dW);
00295 if (fnode->Wf != NULL) free (fnode->Wf);
00296
00297 if (fnode->X_view != NULL) free (fnode->X_view);
00298 if (fnode->X != NULL) free (fnode->X);
00299 if (fnode->xb != NULL) gsl_matrix_free (fnode->xb);
00300
00301 return;
00302 }
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320 gnn_node *
00321 gnn_fir_new (size_t input_size, size_t output_size, size_t delay)
00322 {
00323 size_t k;
00324 int status;
00325 size_t pblock_size;
00326 size_t param_size;
00327 gnn_node *node;
00328 gnn_fir *fnode;
00329 gsl_vector *w;
00330 gsl_vector *dw;
00331 gsl_vector_int *f;
00332
00333
00334 if (input_size < 1 || output_size < 1)
00335 {
00336 GSL_ERROR_VAL ("gnn_fir's sizes should be stricly positive",
00337 GSL_EINVAL, NULL);
00338 }
00339 if (delay < 1)
00340 {
00341 GSL_ERROR_VAL ("gnn_fir's dealy should be stricly positive",
00342 GSL_EINVAL, NULL);
00343 }
00344
00345
00346 pblock_size = input_size * output_size;
00347 param_size = (delay + 1) * pblock_size;
00348
00349
00350 node = (gnn_node *) malloc (sizeof (gnn_fir));
00351 if (node == NULL)
00352 {
00353 GSL_ERROR_VAL ("couldn't allocate memory for gnn_fir's node",
00354 GSL_ENOMEM, NULL);
00355 }
00356
00357
00358 status = gnn_node_init (node,
00359 "gnn_fir",
00360 gnn_fir_f,
00361 gnn_fir_dx,
00362 gnn_fir_dw,
00363 gnn_fir_destroy);
00364 if (status)
00365 {
00366 GSL_ERROR_VAL ("could not initialize gnn_fir node", GSL_EINVAL, NULL);
00367 }
00368
00369 status = gnn_node_set_sizes (node, input_size, output_size, param_size);
00370 if (status)
00371 {
00372 GSL_ERROR_VAL ("could not set sizes for gnn_fir node",
00373 GSL_EINVAL, NULL);
00374 }
00375
00376
00377 fnode = (gnn_fir *) node;
00378 w = gnn_node_local_get_w (node);
00379 dw = gnn_node_local_get_dw (node);
00380 f = gnn_node_local_get_f (node);
00381
00382
00383 fnode->xb = gsl_matrix_calloc (delay + 1, input_size);
00384 if (fnode->xb == NULL)
00385 {
00386 gnn_node_destroy (node);
00387 GSL_ERROR_VAL ("could not allocate FIR buffers",
00388 GSL_EINVAL, NULL);
00389 }
00390
00391
00392 fnode->W_view
00393 = (gsl_matrix_view *) malloc (sizeof (gsl_matrix_view) * (delay + 1));
00394 fnode->dW_view
00395 = (gsl_matrix_view *) malloc (sizeof (gsl_matrix_view) * (delay + 1));
00396 fnode->Wf_view
00397 = (gsl_matrix_int_view *)
00398 malloc (sizeof (gsl_matrix_int_view) * (delay + 1));
00399 fnode->W
00400 = (gsl_matrix **) malloc (sizeof (gsl_matrix *) * (delay + 1));
00401 fnode->dW
00402 = (gsl_matrix **) malloc (sizeof (gsl_matrix *) * (delay + 1));
00403 fnode->Wf
00404 = (gsl_matrix_int **) malloc (sizeof (gsl_matrix_int *) * (delay + 1));
00405
00406 fnode->X_view
00407 = (gsl_vector_view *) malloc (sizeof (gsl_vector_view) * (delay + 1));
00408 fnode->X
00409 = (gsl_vector **) malloc (sizeof (gsl_vector *) * (delay + 1));
00410
00411
00412 if ( (fnode->W_view == NULL)
00413 || (fnode->dW_view == NULL)
00414 || (fnode->Wf_view == NULL)
00415 || (fnode->W == NULL)
00416 || (fnode->dW == NULL)
00417 || (fnode->Wf == NULL)
00418 || (fnode->X_view == NULL)
00419 || (fnode->X == NULL) )
00420 {
00421 gnn_node_destroy (node);
00422 GSL_ERROR_VAL ("could not set sizes for gnn_fir node",
00423 GSL_EINVAL, NULL);
00424 }
00425
00426
00427 for (k=0; k<=delay; ++k)
00428 {
00429 gsl_vector_view v;
00430 gsl_vector_view dv;
00431 gsl_vector_int_view vf;
00432
00433 v = gsl_vector_subvector (w, k * pblock_size, pblock_size);
00434 dv = gsl_vector_subvector (dw, k * pblock_size, pblock_size);
00435 vf = gsl_vector_int_subvector (f, k * pblock_size, pblock_size);
00436
00437 fnode->W_view[k]
00438 = gsl_matrix_view_vector (&(v.vector), output_size, input_size);
00439 fnode->dW_view[k]
00440 = gsl_matrix_view_vector (&(dv.vector), output_size, input_size);
00441 fnode->Wf_view[k]
00442 = gsl_matrix_int_view_vector (&(vf.vector), output_size, input_size);
00443 fnode->W[k] = &(fnode->W_view[k].matrix);
00444 fnode->dW[k] = &(fnode->dW_view[k].matrix);
00445 fnode->Wf[k] = &(fnode->Wf_view[k].matrix);
00446
00447 fnode->X_view[k]
00448 = gsl_matrix_row (fnode->xb, k);
00449 fnode->X[k]
00450 = &(fnode->X_view[k].vector);
00451 }
00452
00453
00454 fnode->delay = delay;
00455 fnode->offset = delay;
00456
00457 return node;
00458 }
00459
00460
00461
00462
00463
00464
00465
00466
00467
00468
00469
00470 int
00471 gnn_fir_init (gnn_node *node)
00472 {
00473 int i;
00474 int l;
00475 gsl_rng *r;
00476 gsl_vector *w;
00477
00478
00479 r = gnn_get_rng ();
00480
00481
00482 w = gnn_node_local_get_w (node);
00483
00484
00485 for (i=0; i<w->size; ++i)
00486 {
00487 double rnd;
00488 rnd = gsl_rng_uniform (r) - 0.5;
00489 gsl_vector_set (w, i, rnd);
00490 }
00491
00492
00493 gnn_node_local_update (node);
00494
00495 return 0;
00496 }
00497