00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049 #include "gnn_parallel.h"
00050 #include <assert.h>
00051
00052
00053
00054
00055
00056
00057
00058 static int
00059 gnn_parallel_f (gnn_node *node,
00060 const gsl_vector *x,
00061 const gsl_vector *w,
00062 gsl_vector *y);
00063
00064 static int
00065 gnn_parallel_dx (gnn_node *node,
00066 const gsl_vector *x,
00067 const gsl_vector *w,
00068 const gsl_vector *dy,
00069 gsl_vector *dx);
00070
00071 static int
00072 gnn_parallel_dw (gnn_node *node,
00073 const gsl_vector *x,
00074 const gsl_vector *w,
00075 const gsl_vector *dy,
00076 gsl_vector *dw);
00077
00078 static void
00079 gnn_parallel_destroy (gnn_node *node);
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099 static int
00100 gnn_parallel_f (gnn_node *node,
00101 const gsl_vector *x,
00102 const gsl_vector *w,
00103 gsl_vector *y)
00104 {
00105 gnn_parallel *pnode;
00106 int size;
00107 int i;
00108
00109
00110 pnode = (gnn_parallel *) node;
00111
00112
00113 size = gnn_node_sub_get_number (node);
00114
00115
00116 for (i=0; i<size; ++i)
00117 {
00118 int in_offset;
00119 int in_length;
00120 int out_offset;
00121 int out_length;
00122 gsl_vector_view sx;
00123 gsl_vector_view sy;
00124 gnn_node *snode;
00125
00126
00127 snode = gnn_node_sub_get_node_at (node, i);
00128
00129
00130 in_offset = pnode->in_off[i];
00131 in_length = pnode->in_size[i];
00132 out_offset = pnode->out_off[i];
00133 out_length = pnode->out_size[i];
00134
00135
00136 sx = gsl_vector_subvector ((gsl_vector *) x, in_offset, in_length);
00137 sy = gsl_vector_subvector (y, out_offset, out_length);
00138
00139
00140 gnn_node_eval_f (snode, &(sx.vector), &(sy.vector));
00141 }
00142
00143 return 0;
00144 }
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175 static int
00176 gnn_parallel_dx (gnn_node *node,
00177 const gsl_vector *x,
00178 const gsl_vector *w,
00179 const gsl_vector *dy,
00180 gsl_vector *dx)
00181 {
00182 gnn_parallel *pnode;
00183 int size;
00184 int i;
00185
00186
00187 pnode = (gnn_parallel *) node;
00188
00189
00190 size = gnn_node_sub_get_number (node);
00191
00192
00193 for (i=0; i<size; ++i)
00194 {
00195 int in_offset;
00196 int in_length;
00197 int out_offset;
00198 int out_length;
00199 gsl_vector_view sdx;
00200 gsl_vector_view sdy;
00201 gnn_node *snode;
00202
00203
00204 snode = gnn_node_sub_get_node_at (node, i);
00205
00206
00207 in_offset = pnode->in_off[i];
00208 in_length = pnode->in_size[i];
00209 out_offset = pnode->out_off[i];
00210 out_length = pnode->out_size[i];
00211
00212
00213 sdy = gsl_vector_subvector ((gsl_vector *) dy, out_offset, out_length);
00214 sdx = gsl_vector_subvector (dx, in_offset, in_length);
00215
00216
00217 gnn_node_eval_dx (snode, &(sdy.vector), &(sdx.vector));
00218 }
00219
00220 return 0;
00221 }
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250 static int
00251 gnn_parallel_dw (gnn_node *node,
00252 const gsl_vector *x,
00253 const gsl_vector *w,
00254 const gsl_vector *dy,
00255 gsl_vector *dw)
00256 {
00257 int size;
00258 int i;
00259
00260
00261 size = gnn_node_sub_get_number (node);
00262
00263
00264 for (i=0; i<size; ++i)
00265 {
00266 gnn_node *snode;
00267
00268
00269 snode = gnn_node_sub_get_node_at (node, i);
00270
00271
00272 gnn_node_eval_dw (snode);
00273 }
00274
00275 return 0;
00276 }
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287 static void
00288 gnn_parallel_destroy (gnn_node *node)
00289 {
00290 gnn_parallel *pnode;
00291
00292 assert (node != NULL);
00293
00294
00295 pnode = (gnn_parallel *) node;
00296
00297
00298 if (pnode->in_off != NULL)
00299 free (pnode->in_off);
00300 if (pnode->in_size != NULL)
00301 free (pnode->in_size);
00302 if (pnode->out_off != NULL)
00303 free (pnode->out_off);
00304 if (pnode->out_size != NULL)
00305 free (pnode->out_size);
00306 }
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334
00335
00336
00337 gnn_node *
00338 gnn_parallel_new (size_t size, ...)
00339 {
00340 size_t i;
00341 va_list args;
00342 gnn_node *p;
00343 gnn_node_vector *v;
00344
00345 assert (size > 0);
00346
00347
00348 v = gnn_node_vector_new (size);
00349 if (v == NULL)
00350 {
00351 GSL_ERROR_VAL ("couldn't allocate node vector for parallel node",
00352 GSL_ENOMEM, NULL);
00353 }
00354
00355
00356 va_start (args, size);
00357 for (i=0; i<size; ++i)
00358 {
00359 gnn_node *n;
00360 n = (gnn_node *) va_arg (args, gnn_node *);
00361 gnn_node_vector_set (v, i, n);
00362 }
00363 va_end (args);
00364
00365
00366 p = gnn_parallel_new_with_node_vector (v);
00367
00368
00369 gnn_node_vector_free (v);
00370
00371 return p;
00372 }
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393
00394
00395
00396
00397
00398
00399 gnn_node *
00400 gnn_parallel_new_with_node_vector (gnn_node_vector *v)
00401 {
00402 int i;
00403 int status;
00404 size_t size;
00405 size_t input_size;
00406 size_t output_size;
00407 gnn_node *node;
00408 gnn_parallel *pnode;
00409
00410 assert (v != NULL);
00411
00412
00413 size = gnn_node_vector_count_nodes (v);
00414
00415
00416 node = (gnn_node *) malloc (sizeof (gnn_parallel));
00417 if (node == NULL)
00418 {
00419 GSL_ERROR_VAL ("could not allocate memory for gnn_parallel",
00420 GSL_ENOMEM, NULL);
00421 }
00422
00423
00424 status = gnn_node_init (node,
00425 "gnn_parallel",
00426 gnn_parallel_f,
00427 gnn_parallel_dx,
00428 gnn_parallel_dw,
00429 gnn_parallel_destroy);
00430 if (status)
00431 {
00432 free (node);
00433 GSL_ERROR_VAL ("could not initialize gnn_parallel node",
00434 GSL_EFAILED, NULL);
00435 }
00436
00437
00438 status = gnn_node_sub_install_node_vector (node, v);
00439 if (status)
00440 {
00441 gnn_node_destroy (node);
00442 GSL_ERROR_VAL ("couldn't install subnodes for gnn_parallel",
00443 GSL_EINVAL, NULL);
00444 }
00445
00446
00447 pnode = (gnn_parallel *) node;
00448
00449
00450 pnode->in_off = (size_t *) malloc (sizeof (size_t) * size);
00451 pnode->in_size = (size_t *) malloc (sizeof (size_t) * size);
00452 pnode->out_off = (size_t *) malloc (sizeof (size_t) * size);
00453 pnode->out_size = (size_t *) malloc (sizeof (size_t) * size);
00454
00455 if ( pnode->in_off == NULL
00456 || pnode->in_size == NULL
00457 || pnode->out_off == NULL
00458 || pnode->out_size == NULL )
00459 {
00460 gnn_node_destroy (node);
00461 GSL_ERROR_VAL ("could not allocate memory for gnn_parallel info arrays",
00462 GSL_ENOMEM, NULL);
00463 }
00464
00465
00466 input_size = 0;
00467 output_size = 0;
00468 for (i=0; i<size; ++i)
00469 {
00470 size_t inlen;
00471 size_t outlen;
00472
00473 gnn_node *subnode;
00474
00475 subnode = gnn_node_sub_get_node_at (node, i);
00476
00477 inlen = gnn_node_input_get_size (subnode);
00478 outlen = gnn_node_output_get_size (subnode);
00479
00480 pnode->in_off[i] = input_size;
00481 pnode->in_size[i] = inlen;
00482 pnode->out_off[i] = output_size;
00483 pnode->out_size[i] = outlen;
00484
00485 input_size += inlen;
00486 output_size += outlen;
00487 }
00488
00489
00490 gnn_node_set_sizes (node, input_size, output_size, 0);
00491
00492 return node;
00493 }
00494
00495
00496
00497
00498
00499
00500
00501
00502
00503
00504 int
00505 gnn_parallel_get_input_offset (gnn_node *node, int i)
00506 {
00507 int sub_size;
00508 gnn_parallel *pnode;
00509
00510 pnode = (gnn_parallel *) node;
00511 sub_size = gnn_node_sub_get_number (node);
00512 if ( (0 <= i) && (i < sub_size) )
00513 return pnode->in_off[i];
00514 else
00515 GSL_ERROR ("index of parallel layer out of bounds", GSL_EINVAL);
00516 }
00517
00518
00519
00520
00521
00522
00523
00524
00525
00526 int
00527 gnn_parallel_get_input_length (gnn_node *node, int i)
00528 {
00529 int sub_size;
00530 gnn_parallel *pnode;
00531
00532 pnode = (gnn_parallel *) node;
00533 sub_size = gnn_node_sub_get_number (node);
00534 if ( (0 <= i) && (i < sub_size) )
00535 return pnode->in_size[i];
00536 else
00537 GSL_ERROR ("index of parallel layer out of bounds", GSL_EINVAL);
00538 }
00539
00540
00541
00542
00543
00544
00545
00546
00547
00548 int
00549 gnn_parallel_get_output_offset (gnn_node *node, int i)
00550 {
00551 int sub_size;
00552 gnn_parallel *pnode;
00553
00554 pnode = (gnn_parallel *) node;
00555 sub_size = gnn_node_sub_get_number (node);
00556 if ( (0 <= i) && (i < sub_size) )
00557 return pnode->out_off[i];
00558 else
00559 GSL_ERROR ("index of parallel layer out of bounds", GSL_EINVAL);
00560 }
00561
00562
00563
00564
00565
00566
00567
00568
00569
00570 int
00571 gnn_parallel_get_output_length (gnn_node *node, int i)
00572 {
00573 int sub_size;
00574 gnn_parallel *pnode;
00575
00576 pnode = (gnn_parallel *) node;
00577 sub_size = gnn_node_sub_get_number (node);
00578 if ( (0 <= i) && (i < sub_size) )
00579 return pnode->out_size[i];
00580 else
00581 GSL_ERROR ("index of parallel layer out of bounds", GSL_EINVAL);
00582 }
00583
00584
00585