00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #include "gnn_prototype.h"
00068 #include "gnn_utilities.h"
00069 #include <gsl/gsl_blas.h>
00070 #include <math.h>
00071
00072
00073
00074
00075
00076
00077
/* Forward declarations of the gnn_node callbacks that gnn_prototype_new ()
   registers with gnn_node_init (). */

/* Forward pass: y[j] is the Euclidean distance between x and row j of the
   prototype matrix. */
static int
gnn_prototype_f (gnn_node *node,
const gsl_vector *x,
const gsl_vector *w,
gsl_vector *y);

/* Backward pass: gradient of the cost with respect to the input x,
   given the gradient dy with respect to the output. */
static int
gnn_prototype_dx (gnn_node *node,
const gsl_vector *x,
const gsl_vector *w,
const gsl_vector *dy,
gsl_vector *dx);

/* Backward pass: gradient of the cost with respect to the prototype
   weights, accumulated into the node's dD matrix. */
static int
gnn_prototype_dw (gnn_node *node,
const gsl_vector *x,
const gsl_vector *w,
const gsl_vector *dy,
gsl_vector *dw);

/* Releases the scratch buffers owned by a gnn_prototype node. */
static void
gnn_prototype_destroy (gnn_node *node);
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119 static int
00120 gnn_prototype_f (gnn_node *node,
00121 const gsl_vector *x,
00122 const gsl_vector *w,
00123 gsl_vector *y)
00124 {
00125 size_t j;
00126 size_t m;
00127 gnn_prototype *pnode;
00128 gsl_vector_view wj;
00129
00130
00131 m = gnn_node_output_get_size (node);
00132
00133
00134 pnode = (gnn_prototype *) node;
00135
00136
00137 for (j=0; j<m; ++j)
00138 {
00139 double yj;
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149 wj = gsl_matrix_row (pnode->D, j);
00150 yj = gnn_vector_euclidian_dist (&(wj.vector), x);
00151
00152
00153 gsl_vector_set (pnode->y, j, yj);
00154 }
00155
00156
00157 gsl_vector_memcpy (y, pnode->y);
00158
00159 return 0;
00160 }
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179 static int
00180 gnn_prototype_dx (gnn_node *node,
00181 const gsl_vector *x,
00182 const gsl_vector *w,
00183 const gsl_vector *dy,
00184 gsl_vector *dx)
00185 {
00186 size_t i;
00187 size_t n;
00188 gnn_prototype *pnode;
00189 gsl_vector_view wi;
00190
00191
00192 n = gnn_node_input_get_size (node);
00193
00194
00195 pnode = (gnn_prototype *) node;
00196
00197
00198 for (i=0; i<n; ++i)
00199 {
00200 double xi;
00201 double dxi;
00202
00203
00204 xi = gsl_vector_get (x, i);
00205
00206
00207
00208 gsl_vector_set_all (pnode->col, xi);
00209 wi = gsl_matrix_column (pnode->D, i);
00210 gsl_vector_sub (pnode->col, &(wi.vector));
00211 gsl_vector_div (pnode->col, pnode->y);
00212
00213
00214 gsl_vector_mul (pnode->col, dy);
00215
00216 dxi = gnn_vector_sum_elements (pnode->col);
00217
00218
00219 gsl_vector_set (dx, i, dxi);
00220 }
00221
00222 return 0;
00223 }
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242 static int
00243 gnn_prototype_dw (gnn_node *node,
00244 const gsl_vector *x,
00245 const gsl_vector *w,
00246 const gsl_vector *dy,
00247 gsl_vector *dw)
00248 {
00249 size_t j;
00250 size_t m;
00251 gnn_prototype *pnode;
00252 gsl_vector_view wj;
00253 gsl_vector_view dwj;
00254
00255
00256 m = gnn_node_output_get_size (node);
00257
00258
00259 pnode = (gnn_prototype *) node;
00260
00261
00262 for (j=0; j<m; ++j)
00263 {
00264 double yj;
00265 double dyj;
00266
00267
00268 yj = gsl_vector_get (pnode->y, j);
00269 dyj = gsl_vector_get (dy, j);
00270
00271
00272 wj = gsl_matrix_row (pnode->D, j);
00273 gsl_vector_memcpy (pnode->row, &(wj.vector));
00274 gsl_vector_sub (pnode->row, x);
00275 gsl_vector_scale (pnode->row, dyj / yj);
00276
00277
00278 dwj = gsl_matrix_row (pnode->dD, j);
00279 gsl_vector_add (&(dwj.vector), pnode->row);
00280 }
00281
00282 return 0;
00283 }
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293 static void
00294 gnn_prototype_destroy (gnn_node *node)
00295 {
00296 gnn_prototype *pnode;
00297
00298 assert (node != NULL);
00299
00300 pnode = (gnn_prototype *) node;
00301
00302 if (pnode->y != NULL)
00303 gsl_vector_free (pnode->y);
00304 if (pnode->row != NULL)
00305 gsl_vector_free (pnode->row);
00306 if (pnode->col != NULL)
00307 gsl_vector_free (pnode->col);
00308
00309 return;
00310 }
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332 gnn_node *
00333 gnn_prototype_new (size_t input_size, size_t output_size,
00334 double min, double max)
00335 {
00336 int status;
00337 size_t i, j;
00338 gsl_rng *rng;
00339 gnn_node *node;
00340 gnn_prototype *pnode;
00341 gsl_vector *w;
00342 gsl_vector *dw;
00343 gsl_vector_int *f;
00344
00345
00346 if (input_size < 1)
00347 {
00348 GSL_ERROR_VAL ("input size should be strictly positive",
00349 GSL_EINVAL, NULL);
00350 }
00351 if (output_size < 1)
00352 {
00353 GSL_ERROR_VAL ("output size should be strictly positive",
00354 GSL_EINVAL, NULL);
00355 }
00356
00357
00358 if (min >= max)
00359 {
00360 GSL_ERROR_VAL ("min should be smaller than max",
00361 GSL_EINVAL, NULL);
00362 }
00363
00364
00365 pnode = (gnn_prototype *) malloc (sizeof (gnn_prototype));
00366 if (pnode == NULL)
00367 {
00368 GSL_ERROR_VAL ("couldn't allocate memory for gnn_prototype",
00369 GSL_ENOMEM, NULL);
00370 }
00371
00372
00373 node = (gnn_node *) pnode;
00374
00375
00376 status = gnn_node_init (node,
00377 "gnn_prototype",
00378 gnn_prototype_f,
00379 gnn_prototype_dx,
00380 gnn_prototype_dw,
00381 gnn_prototype_destroy);
00382 if (status)
00383 {
00384 free (node);
00385 GSL_ERROR_VAL ("could not initialize gnn_prototype node",
00386 GSL_EINVAL, NULL);
00387 }
00388
00389
00390 status = gnn_node_set_sizes (node, input_size,
00391 output_size, input_size * output_size);
00392 if (status)
00393 {
00394 gnn_node_destroy (node);
00395 GSL_ERROR_VAL ("could not set sizes for gnn_prototype node",
00396 GSL_EFAILED, NULL);
00397 }
00398
00399
00400 pnode->y = gsl_vector_alloc (output_size);
00401 pnode->row = gsl_vector_alloc (input_size);
00402 pnode->col = gsl_vector_alloc (output_size);
00403 if (pnode->y == NULL || pnode->row == NULL || pnode->col == NULL)
00404 {
00405 gnn_node_destroy (node);
00406 GSL_ERROR_VAL ("could not allocate buffers for gnn_prototype",
00407 GSL_ENOMEM, NULL);
00408 }
00409
00410
00411 w = gnn_node_local_get_w (node);
00412 dw = gnn_node_local_get_dw (node);
00413 f = gnn_node_local_get_f (node);
00414
00415 pnode->D_view = gsl_matrix_view_vector (w, output_size, input_size);
00416 pnode->dD_view = gsl_matrix_view_vector (dw, output_size, input_size);
00417 pnode->Df_view = gsl_matrix_int_view_vector (f, output_size, input_size);
00418
00419 pnode->D = &(pnode->D_view.matrix);
00420 pnode->dD = &(pnode->dD_view.matrix);
00421 pnode->Df = &(pnode->Df_view.matrix);
00422
00423
00424 rng = gnn_get_rng ();
00425 for (j=0; j<output_size; ++j)
00426 for (i=0; i<input_size; ++i)
00427 {
00428 double w;
00429 w = gsl_rng_uniform (rng);
00430 w = (max - min) * w + min;
00431 gsl_matrix_set (pnode->D, j, i, w);
00432 }
00433
00434
00435 gnn_node_local_update (node);
00436
00437 return node;
00438 }
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451
00452 gnn_node *
00453 gnn_prototype_standard_new (size_t input_size, size_t output_size)
00454 {
00455 return gnn_prototype_new (input_size, output_size, -1.0, 1.0);
00456 }
00457
00458
00459
00460
00461
00462
00463
00464
00465
00466
00467
00468 int
00469 gnn_prototype_vector_freeze (gnn_node *node, size_t j)
00470 {
00471 gnn_prototype *pnode;
00472 gsl_vector_int_view v;
00473
00474 assert (node != NULL);
00475
00476
00477 pnode = (gnn_prototype *) node;
00478
00479
00480 if (j < 0 || j >= pnode->Df->size1)
00481 {
00482 GSL_ERROR ("index out of bounds", GSL_EINVAL);
00483 }
00484
00485
00486 gnn_node_local_get_f (node);
00487
00488
00489 v = gsl_matrix_int_row (pnode->Df, j);
00490
00491
00492 gsl_vector_int_set_all (&(v.vector), 1);
00493
00494
00495 gnn_node_local_update (node);
00496
00497 return 0;
00498 }
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510 int
00511 gnn_prototype_vector_unfreeze (gnn_node *node, size_t j)
00512 {
00513 gnn_prototype *pnode;
00514 gsl_vector_int_view v;
00515
00516 assert (node != NULL);
00517
00518
00519 pnode = (gnn_prototype *) node;
00520
00521
00522 if (j < 0 || j >= pnode->Df->size1)
00523 {
00524 GSL_ERROR ("index out of bounds", GSL_EINVAL);
00525 }
00526
00527
00528 gnn_node_local_get_f (node);
00529
00530
00531 v = gsl_matrix_int_row (pnode->Df, j);
00532
00533
00534 gsl_vector_int_set_zero (&(v.vector));
00535
00536
00537 gnn_node_local_update (node);
00538
00539 return 0;
00540 }
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553 int
00554 gnn_prototype_vector_set (gnn_node *node, size_t j, const gsl_vector *v)
00555 {
00556 gsl_vector_view u;
00557 gnn_prototype *pnode;
00558
00559 assert (node != NULL);
00560
00561
00562 pnode = (gnn_prototype *) node;
00563
00564
00565 if (j < 0 || j >= pnode->D->size1)
00566 {
00567 GSL_ERROR ("index out of bounds", GSL_EINVAL);
00568 }
00569
00570
00571 if (gnn_node_input_get_size (node) != v->size)
00572 {
00573 GSL_ERROR ("vector insn't of the correct size", GSL_EINVAL);
00574 }
00575
00576
00577 gnn_node_local_get_w (node);
00578
00579
00580 u = gsl_matrix_row (pnode->D, j);
00581
00582
00583 gsl_vector_memcpy (&(u.vector), v);
00584
00585
00586 gnn_node_local_update (node);
00587
00588 return 0;
00589 }
00590
00591
00592
00593
00594
00595
00596
00597
00598
00599
00600
00601
00602
00603
00604 int
00605 gnn_prototype_vector_get (gnn_node *node, size_t j, gsl_vector *v)
00606 {
00607 gsl_vector_view u;
00608 gnn_prototype *pnode;
00609
00610 assert (node != NULL);
00611
00612
00613 pnode = (gnn_prototype *) node;
00614
00615
00616 if (j < 0 || j >= pnode->D->size1)
00617 {
00618 GSL_ERROR ("index out of bounds", GSL_EINVAL);
00619 }
00620
00621
00622 if (gnn_node_input_get_size (node) != v->size)
00623 {
00624 GSL_ERROR ("vector insn't of the correct size", GSL_EINVAL);
00625 }
00626
00627
00628 gnn_node_local_get_w (node);
00629
00630
00631 u = gsl_matrix_row (pnode->D, j);
00632
00633
00634 gsl_vector_memcpy (v, &(u.vector));
00635
00636
00637 gnn_node_local_update (node);
00638
00639 return 0;
00640 }
00641
00642
00643