00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043 #include <stdarg.h>
00044 #include "gnn_serial.h"
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059 typedef struct _stack_buf
00060 {
00061 gsl_vector *xy;
00062 gsl_vector *dydx;
00063
00064 } stack_buffer;
00065
00066
00067
00068
00069
00070
00071
00072
00073 typedef struct _gnn_serial
00074 {
00075 gnn_node node;
00076 stack_buffer *buffers;
00077 } gnn_serial;
00078
00079
00080
00081 static int
00082 stack_buffer_init (stack_buffer *sb, size_t size);
00083
00084 static int
00085 stack_buffer_clear (stack_buffer *sb);
00086
00087 static int
00088 stack_buffer_finalize (stack_buffer *sb);
00089
00090 static int
00091 gnn_serial_make_buffers (gnn_node *node);
00092
00093 static int
00094 gnn_serial_destroy_buffers (gnn_node *node);
00095
00096
00097
00098 static int
00099 gnn_serial_f (gnn_node *node,
00100 const gsl_vector *x,
00101 const gsl_vector *w,
00102 gsl_vector *y);
00103
00104 static int
00105 gnn_serial_dx (gnn_node *node,
00106 const gsl_vector *x,
00107 const gsl_vector *w,
00108 const gsl_vector *dy,
00109 gsl_vector *dx);
00110
00111 static int
00112 gnn_serial_dw (gnn_node *node,
00113 const gsl_vector *x,
00114 const gsl_vector *w,
00115 const gsl_vector *dy,
00116 gsl_vector *dw);
00117
00118 static void
00119 gnn_serial_destroy (gnn_node *node);
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136 static int
00137 stack_buffer_init (stack_buffer *sb, size_t size)
00138 {
00139 assert (sb != NULL);
00140 assert (size >= 0);
00141
00142
00143 if (size > 0)
00144 {
00145 sb->xy = gsl_vector_alloc (size);
00146 sb->dydx = gsl_vector_alloc (size);
00147 if (sb->xy == NULL || sb->dydx == NULL)
00148 {
00149 stack_buffer_finalize (sb);
00150 GSL_ERROR ("couldn't initialize stack buffer", GSL_ENOMEM);
00151 }
00152 }
00153 return 0;
00154 }
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165 static int
00166 stack_buffer_clear (stack_buffer *sb)
00167 {
00168 assert (sb != NULL);
00169
00170
00171 if (sb->xy != NULL && sb->dydx != NULL)
00172 {
00173 gsl_vector_set_zero (sb->xy);
00174 gsl_vector_set_zero (sb->dydx);
00175 }
00176 return 0;
00177 }
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189 static int
00190 stack_buffer_finalize (stack_buffer *sb)
00191 {
00192 assert (sb != NULL);
00193
00194
00195 if (sb->xy != NULL)
00196 gsl_vector_free (sb->xy);
00197 if (sb->dydx != NULL)
00198 gsl_vector_free (sb->dydx);
00199
00200 return 0;
00201 }
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213 static int
00214 gnn_serial_make_buffers (gnn_node *node)
00215 {
00216 int i;
00217 int size;
00218 gnn_serial *stack;
00219 gnn_node *pre;
00220 gnn_node *post;
00221
00222 assert (node != NULL);
00223
00224
00225 size = gnn_node_sub_get_number (node);
00226 if (size < 2)
00227 return 0;
00228
00229
00230 stack = (gnn_serial *) node;
00231
00232
00233 stack->buffers = (stack_buffer *) malloc (sizeof (stack_buffer) * (size-1));
00234 if (stack->buffers == NULL)
00235 GSL_ERROR ("couldn't allocate stack buffers", GSL_ENOMEM);
00236
00237
00238 post = gnn_node_sub_get_node_at (node, 0);
00239 for (i=0; i<size-1; ++i)
00240 {
00241 int m, n, status;
00242
00243 pre = post;
00244 post = gnn_node_sub_get_node_at (node, i + 1);
00245
00246
00247 m = gnn_node_output_get_size (pre);
00248 n = gnn_node_input_get_size (post);
00249 if (m != n)
00250 {
00251 gnn_serial_destroy_buffers (node);
00252 GSL_ERROR ("the size of two consecutive nodes should match",
00253 GSL_EINVAL);
00254 }
00255
00256
00257 status = stack_buffer_init (&(stack->buffers[i]), n);
00258 if (status)
00259 {
00260 gnn_serial_destroy_buffers (node);
00261 GSL_ERROR ("couldn't create stack buffers", GSL_ENOMEM);
00262 }
00263 }
00264
00265 return 0;
00266 }
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277 static int
00278 gnn_serial_destroy_buffers (gnn_node *node)
00279 {
00280 int i;
00281 int size;
00282 gnn_serial *stack;
00283
00284 assert (node != NULL);
00285
00286
00287 size = gnn_node_sub_get_number (node);
00288 if (size < 2)
00289 return 0;
00290
00291
00292 stack = (gnn_serial *) node;
00293
00294
00295 for (i=0; i<size-1; ++i)
00296 {
00297 stack_buffer_finalize (&(stack->buffers[i]));
00298 }
00299
00300 return 0;
00301 }
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318 static int
00319 gnn_serial_f (gnn_node *node,
00320 const gsl_vector *x,
00321 const gsl_vector *w,
00322 gsl_vector *y)
00323 {
00324 int i;
00325 int size;
00326 gnn_serial *stack;
00327
00328 const gsl_vector *xptr;
00329 gnn_node *subnode;
00330 gsl_vector *yptr;
00331
00332 assert (node != NULL);
00333
00334
00335 stack = (gnn_serial *) node;
00336 size = gnn_node_sub_get_number (node);
00337 xptr = x;
00338 subnode = gnn_node_sub_get_node_at (node, 0);
00339
00340
00341 for (i=0; i<size-1; ++i)
00342 {
00343
00344 stack_buffer_clear (&(stack->buffers[i]));
00345
00346
00347 yptr = stack->buffers[i].xy;
00348
00349
00350 gnn_node_eval_f (subnode, xptr, yptr);
00351
00352
00353 xptr = yptr;
00354 subnode = gnn_node_sub_get_node_at (node, i+1);
00355 }
00356
00357
00358 yptr = y;
00359 gnn_node_eval_f (subnode, xptr, yptr);
00360
00361 return 0;
00362 }
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388 static int
00389 gnn_serial_dx (gnn_node *node,
00390 const gsl_vector *x,
00391 const gsl_vector *w,
00392 const gsl_vector *dy,
00393 gsl_vector *dx)
00394 {
00395 int i;
00396 int size;
00397 gnn_serial *stack;
00398
00399 const gsl_vector *dyptr;
00400 gnn_node *subnode;
00401 gsl_vector *dxptr;
00402
00403 assert (node != NULL);
00404
00405
00406 stack = (gnn_serial *) node;
00407 size = gnn_node_sub_get_number (node);
00408 dyptr = dy;
00409 subnode = gnn_node_sub_get_node_at (node, size-1);
00410
00411
00412 for (i=size-1; i>0; --i)
00413 {
00414
00415 dxptr = stack->buffers[i-1].dydx;
00416
00417
00418 gnn_node_eval_dx (subnode, dyptr, dxptr);
00419
00420
00421 dyptr = dxptr;
00422 subnode = gnn_node_sub_get_node_at (node, i-1);
00423 }
00424
00425
00426 dxptr = dx;
00427 gnn_node_eval_dx (subnode, dyptr, dxptr);
00428
00429 return 0;
00430 }
00431
00432
00433
00434
00435
00436
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449 static int
00450 gnn_serial_dw (gnn_node *node,
00451 const gsl_vector *x,
00452 const gsl_vector *w,
00453 const gsl_vector *dy,
00454 gsl_vector *dw)
00455 {
00456 int i;
00457 int size;
00458
00459 assert (node != NULL);
00460
00461
00462 size = gnn_node_sub_get_number (node);
00463
00464
00465 for (i=0; i<size; ++i)
00466 {
00467 gnn_node *subnode;
00468
00469
00470 subnode = gnn_node_sub_get_node_at (node, i);
00471 gnn_node_eval_dw (subnode);
00472 }
00473
00474 return 0;
00475 }
00476
00477
00478
00479
00480
00481
00482
00483
00484
00485 static void
00486 gnn_serial_destroy (gnn_node *node)
00487 {
00488 assert (node != NULL);
00489
00490
00491 gnn_serial_destroy_buffers (node);
00492 }
00493
00494
00495
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512
00513
00514
00515
00516
00517
00518
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528 gnn_node *
00529 gnn_serial_new (int size, ...)
00530 {
00531 size_t i;
00532 va_list args;
00533 gnn_node *s;
00534 gnn_node_vector *v;
00535
00536 assert (size > 0);
00537
00538
00539 v = gnn_node_vector_new (size);
00540 if (v == NULL)
00541 {
00542 GSL_ERROR_VAL ("couldn't allocate node vector for gnn_serial node",
00543 GSL_ENOMEM, NULL);
00544 }
00545
00546
00547 va_start (args, size);
00548 for (i=0; i<size; ++i)
00549 {
00550 gnn_node *n;
00551 n = (gnn_node *) va_arg (args, gnn_node *);
00552 gnn_node_vector_set (v, i, n);
00553 }
00554 va_end (args);
00555
00556
00557 s = gnn_serial_new_with_node_vector (v);
00558
00559
00560 gnn_node_vector_free (v);
00561
00562 return s;
00563 }
00564
00565
00566
00567
00568
00569
00570
00571
00572
00573
00574
00575
00576
00577
00578
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589
00590
00591
00592
00593
00594
00595
00596
00597 gnn_node *
00598 gnn_serial_new_with_node_vector (gnn_node_vector *v)
00599 {
00600 int status;
00601 size_t n;
00602 size_t m;
00603 size_t size;
00604 gnn_node *node;
00605
00606 assert (v != NULL);
00607
00608
00609 size = gnn_node_vector_count_nodes (v);
00610
00611
00612 node = (gnn_node *) malloc (sizeof (gnn_serial));
00613 if (node == NULL)
00614 GSL_ERROR_VAL ("couldn't allocate memory for gnn_serial node",
00615 GSL_EINVAL, NULL);;
00616
00617
00618 status = gnn_node_init (node,
00619 "gnn_serial",
00620 gnn_serial_f,
00621 gnn_serial_dx,
00622 gnn_serial_dw,
00623 gnn_serial_destroy);
00624 if (status)
00625 {
00626 free (node);
00627 GSL_ERROR_VAL ("couldn't init gnn_serial", GSL_ENOMEM, NULL);
00628 }
00629
00630
00631 status = gnn_node_sub_install_node_vector (node, v);
00632 if (status)
00633 {
00634 gnn_node_destroy (node);
00635 GSL_ERROR_VAL ("couldn't install subnodes for gnn_serial",
00636 GSL_EINVAL, NULL);
00637 }
00638
00639
00640 n = gnn_node_input_get_size (gnn_node_sub_get_node_at (node, 0));
00641 m = gnn_node_output_get_size (gnn_node_sub_get_node_at (node, size-1));
00642
00643 status = gnn_node_set_sizes (node, n, m, 0);
00644 if (status)
00645 {
00646 gnn_node_destroy (node);
00647 GSL_ERROR_VAL ("couldn't set sizes for gnn_serial", GSL_EINVAL, NULL);
00648 }
00649
00650
00651 status = gnn_serial_make_buffers (node);
00652 if (status)
00653 {
00654 gnn_node_destroy (node);
00655 GSL_ERROR_VAL ("couldn't install buffers", GSL_EINVAL, NULL);
00656 }
00657
00658 return node;
00659 }
00660
00661
00662
00663
00664
00665
00666
00667
00668
00669
00670
00671 gnn_node *
00672 gnn_serial_get_node (gnn_node *node, int i)
00673 {
00674 return gnn_node_sub_get_node_at (node, i);
00675 }
00676
00677
00678
00679
00680
00681
00682
00683
00684 int
00685 gnn_serial_get_size (gnn_node *node)
00686 {
00687 return gnn_node_sub_get_number (node);
00688 }
00689
00690