00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <gsl/gsl_blas.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_eigen.h>
#include "gnn_utilities.h"
#include "gnn_weight.h"
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
/* Private state of a weight node: a weight matrix A and a bias vector B
   (plus their gradients dA/dB and per-parameter flags Af/Bf).  Functions
   in this file receive a gnn_node pointer and cast it to gnn_weight *.  */
typedef struct _gnn_weight gnn_weight;

struct _gnn_weight
{
  /* Base node.  It must stay the first member so that the
     (gnn_node *) <-> (gnn_weight *) casts used throughout are valid.  */
  gnn_node node;

  /* NOTE(review): not read or written anywhere in this file; purpose
     unclear -- confirm against the rest of the library.  */
  int linear;

  /* Views built by gnn_weight_new.  A_view / dA_view are
     output_size x input_size matrices over the first
     input_size * output_size entries of the node's local w / dw vectors;
     B_view / dB_view are the trailing output_size entries.  Af_view and
     Bf_view use the same layout over the local integer vector f.  */
  gsl_matrix_view A_view;
  gsl_vector_view B_view;
  gsl_matrix_view dA_view;
  gsl_vector_view dB_view;
  gsl_matrix_int_view Af_view;
  gsl_vector_int_view Bf_view;

  /* Shortcuts to the matrix/vector inside the *_view members above.
     Because they point into this very struct, copying a gnn_weight by
     value would leave the copy's pointers aimed at the original's views.
     Af / Bf entries are set to 1 by the freeze / disconnect functions and
     to 0 by the unfreeze / connect functions in this file.  */
  gsl_matrix *A;
  gsl_vector *B;
  gsl_matrix *dA;
  gsl_vector *dB;
  gsl_matrix_int *Af;
  gsl_vector_int *Bf;

};
00108
00109
/* Forward evaluation: writes the node's output for input x into y.  */
static int
gnn_weight_f (gnn_node *node,
              const gsl_vector *x,
              const gsl_vector *w,
              gsl_vector *y);


/* Gradient with respect to the input: writes dx from the output
   gradient dy.  */
static int
gnn_weight_dx (gnn_node *node,
               const gsl_vector *x,
               const gsl_vector *w,
               const gsl_vector *dy,
               gsl_vector *dx);

/* Gradient with respect to the parameters (A and B).  */
static int
gnn_weight_dw (gnn_node *node,
               const gsl_vector *x,
               const gsl_vector *w,
               const gsl_vector *dy,
               gsl_vector *dw);


/* Type-specific destructor passed to gnn_node_init.  */
static void
gnn_weight_destroy (gnn_node *node);



/* NOTE(review): declared here but no definition or use is visible in this
   file chunk; confirm it is defined elsewhere, or remove the prototype.  */
static int
gnn_weight_hadamard_recurse (gsl_matrix *A, double fact);
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162 static int
00163 gnn_weight_f (gnn_node *node,
00164 const gsl_vector *x,
00165 const gsl_vector *w,
00166 gsl_vector *y)
00167 {
00168 gnn_weight *wnode;
00169
00170 wnode = (gnn_weight *) node;
00171
00172
00173 gsl_vector_add (y, wnode->B);
00174 gsl_blas_dgemv (CblasNoTrans, 1.0, wnode->A, x, 0.0, y);
00175
00176 return 0;
00177 }
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198 static int
00199 gnn_weight_dx (gnn_node *node,
00200 const gsl_vector *x,
00201 const gsl_vector *w,
00202 const gsl_vector *dy,
00203 gsl_vector *dx)
00204 {
00205 gnn_weight *wnode;
00206
00207 wnode = (gnn_weight *) node;
00208
00209
00210 gsl_blas_dgemv (CblasTrans, 1.0, wnode->A, dy, 0.0, dx);
00211
00212 return 0;
00213 }
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240 static int
00241 gnn_weight_dw (gnn_node *node,
00242 const gsl_vector *x,
00243 const gsl_vector *w,
00244 const gsl_vector *dy,
00245 gsl_vector *dw)
00246 {
00247 int input_size;
00248 int output_size;
00249 gsl_matrix_view X;
00250 gsl_matrix_view DY;
00251 gnn_weight *wnode;
00252
00253 wnode = (gnn_weight *) node;
00254
00255
00256 input_size = gnn_node_input_get_size (node);
00257 output_size = gnn_node_output_get_size (node);
00258
00259
00260 X = gsl_matrix_view_vector ((gsl_vector *) x, 1, input_size);
00261 DY = gsl_matrix_view_vector ((gsl_vector *) dy, output_size, 1);
00262
00263
00264 gsl_blas_dgemm (CblasNoTrans, CblasNoTrans, 1.0, &(DY.matrix),
00265 &(X.matrix), 1.0, wnode->dA);
00266 gsl_vector_add (wnode->dB, dy);
00267
00268 return 0;
00269 }
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280 static void
00281 gnn_weight_destroy (gnn_node *node)
00282 {
00283 return;
00284 }
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301 gnn_node *
00302 gnn_weight_new (int input_size, int output_size)
00303 {
00304 int status;
00305 int param_size;
00306 gnn_node *node;
00307 gnn_weight *wnode;
00308 gsl_vector *w;
00309 gsl_vector *dw;
00310 gsl_vector_int *f;
00311
00312
00313 if (input_size < 1 || output_size < 1)
00314 {
00315 GSL_ERROR_VAL ("gnn_weight's sizes should be stricly positive",
00316 GSL_EINVAL, NULL);
00317 }
00318
00319
00320 param_size = input_size * output_size + output_size;
00321
00322
00323 node = (gnn_node *) malloc (sizeof (gnn_weight));
00324 if (node == NULL)
00325 GSL_ERROR_VAL ("couldn't allocate memory for gnn_weight's node",
00326 GSL_ENOMEM, NULL);
00327
00328
00329 status = gnn_node_init (node,
00330 "gnn_weight",
00331 gnn_weight_f,
00332 gnn_weight_dx,
00333 gnn_weight_dw,
00334 gnn_weight_destroy);
00335 if (status)
00336 GSL_ERROR_VAL ("could not initialize weights node", GSL_EINVAL, NULL);
00337
00338 status = gnn_node_set_sizes (node, input_size, output_size, param_size);
00339 if (status)
00340 GSL_ERROR_VAL ("could not set sizes for weights node", GSL_EINVAL, NULL);
00341
00342
00343 wnode = (gnn_weight *) node;
00344 w = gnn_node_local_get_w (node);
00345 dw = gnn_node_local_get_dw (node);
00346 f = gnn_node_local_get_f (node);
00347
00348
00349 wnode->A_view = gsl_matrix_view_vector (w, output_size, input_size);
00350 wnode->B_view = gsl_vector_subvector (w, input_size * output_size,
00351 output_size);
00352
00353 wnode->dA_view = gsl_matrix_view_vector (dw, output_size, input_size);
00354 wnode->dB_view = gsl_vector_subvector (dw, input_size * output_size,
00355 output_size);
00356
00357 wnode->Af_view = gsl_matrix_int_view_vector (f, output_size, input_size);
00358 wnode->Bf_view = gsl_vector_int_subvector (f, input_size * output_size,
00359 output_size);
00360
00361
00362 wnode->A = &(wnode->A_view.matrix);
00363 wnode->B = &(wnode->B_view.vector);
00364 wnode->dA = &(wnode->dA_view.matrix);
00365 wnode->dB = &(wnode->dB_view.vector);
00366 wnode->Af = &(wnode->Af_view.matrix);
00367 wnode->Bf = &(wnode->Bf_view.vector);
00368
00369 return node;
00370 }
00371
00372
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388 int
00389 gnn_weight_init (gnn_node *node)
00390 {
00391 int i;
00392 int l;
00393 double sigma;
00394 double sqrt12;
00395 double amplitude;
00396 gsl_rng *r;
00397 gsl_vector *w;
00398
00399
00400 r = gnn_get_rng ();
00401
00402
00403 w = gnn_node_local_get_w (node);
00404 l = w->size;
00405
00406
00407 sigma = 1.0 / sqrt (l);
00408 sqrt12 = sqrt (12.0);
00409 amplitude = 1.0 / (sqrt12 * sigma);
00410
00411
00412 for (i=0; i<l; ++i)
00413 {
00414 double random_value;
00415 random_value = amplitude * (gsl_rng_uniform (r) - 0.5);
00416 gsl_vector_set (w, i, random_value);
00417 }
00418
00419
00420 gnn_node_local_update (node);
00421
00422 return 0;
00423 }
00424
00425
00426
00427
00428
00429
00430
00431
00432
00433
00434
00435
00436 gsl_matrix *
00437 gnn_weight_get_A (gnn_node *node)
00438 {
00439 gnn_weight *wnode;
00440
00441 assert (node != NULL);
00442
00443 gnn_node_local_get_w (node);
00444 wnode = (gnn_weight *) node;
00445 return wnode->A;
00446 }
00447
00448
00449
00450
00451
00452
00453
00454
00455
00456
00457
00458
00459 gsl_vector *
00460 gnn_weight_get_B (gnn_node *node)
00461 {
00462 gnn_weight *wnode;
00463
00464 assert (node != NULL);
00465
00466 gnn_node_local_get_w (node);
00467 wnode = (gnn_weight *) node;
00468 return wnode->B;
00469 }
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483 gsl_matrix *
00484 gnn_weight_get_dA (gnn_node *node)
00485 {
00486 gnn_weight *wnode;
00487
00488 assert (node != NULL);
00489
00490 gnn_node_local_get_dw (node);
00491 wnode = (gnn_weight *) node;
00492 return wnode->dA;
00493 }
00494
00495
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505
00506
00507 gsl_vector *
00508 gnn_weight_get_dB (gnn_node *node)
00509 {
00510 gnn_weight *wnode;
00511
00512 assert (node != NULL);
00513
00514 gnn_node_local_get_dw (node);
00515 wnode = (gnn_weight *) node;
00516 return wnode->dB;
00517 }
00518
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530
00531 gsl_matrix_int *
00532 gnn_weight_get_Af (gnn_node *node)
00533 {
00534 gnn_weight *wnode;
00535
00536 assert (node != NULL);
00537
00538 gnn_node_local_get_f (node);
00539 wnode = (gnn_weight *) node;
00540 return wnode->Af;
00541 }
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552
00553
00554
00555 gsl_vector_int *
00556 gnn_weight_get_Bf (gnn_node *node)
00557 {
00558 gnn_weight *wnode;
00559
00560 assert (node != NULL);
00561
00562 gnn_node_local_get_f (node);
00563 wnode = (gnn_weight *) node;
00564 return wnode->Bf;
00565 }
00566
00567
00568
00569
00570
00571
00572
00573
00574
00575
00576
00577
00578
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589
00590
00591
00592
00593
00594
00595
00596
00597
00598
00599 int
00600 gnn_weight_connect (gnn_node *node, size_t i, size_t j, double strength)
00601 {
00602 gnn_weight *wnode;
00603
00604 assert (node != NULL);
00605
00606 gnn_node_local_get_w (node);
00607 gnn_node_local_get_f (node);
00608 wnode = (gnn_weight *) node;
00609
00610
00611 gsl_matrix_set (wnode->A, j, i, strength);
00612 gsl_matrix_int_set (wnode->Af, j, i, 0);
00613 gnn_node_local_update (node);
00614
00615 return 0;
00616 }
00617
00618
00619
00620
00621
00622
00623
00624
00625
00626
00627
00628
00629
00630 int
00631 gnn_weight_disconnect (gnn_node *node, size_t i, size_t j)
00632 {
00633 gnn_weight *wnode;
00634
00635 assert (node != NULL);
00636
00637 gnn_node_local_get_w (node);
00638 gnn_node_local_get_f (node);
00639 wnode = (gnn_weight *) node;
00640
00641
00642 gsl_matrix_set (wnode->A, j, i, 0.0);
00643 gsl_matrix_int_set (wnode->Af, j, i, 1);
00644 gnn_node_local_update (node);
00645
00646 return 0;
00647 }
00648
00649
00650
00651
00652
00653
00654
00655
00656
00657
00658
00659
00660 int
00661 gnn_weight_connect_all (gnn_node *node, double strength)
00662 {
00663 gnn_weight *wnode;
00664
00665 assert (node != NULL);
00666
00667 gnn_node_local_get_w (node);
00668 gnn_node_local_get_f (node);
00669 wnode = (gnn_weight *) node;
00670
00671
00672 gsl_matrix_set_all (wnode->A, strength);
00673 gsl_matrix_int_set_all (wnode->Af, 0);
00674 gnn_node_local_update (node);
00675
00676 return 0;
00677 }
00678
00679
00680
00681
00682
00683
00684
00685
00686
00687
00688 int
00689 gnn_weight_disconnect_all (gnn_node *node)
00690 {
00691 gnn_weight *wnode;
00692
00693 assert (node != NULL);
00694
00695 gnn_node_local_get_w (node);
00696 gnn_node_local_get_f (node);
00697 wnode = (gnn_weight *) node;
00698
00699
00700 gsl_matrix_set_zero (wnode->A);
00701 gsl_matrix_int_set_all (wnode->Af, 1);
00702 gnn_node_local_update (node);
00703
00704 return 0;
00705 }
00706
00707
00708
00709
00710
00711
00712
00713
00714
00715
00716
00717
00718
00719 int
00720 gnn_weight_connect_input (gnn_node *node, size_t i, double strength)
00721 {
00722 gnn_weight *wnode;
00723 gsl_vector_view Acol;
00724 gsl_vector_int_view Afcol;
00725
00726 assert (node != NULL);
00727
00728
00729 if (i < 0 || wnode->A->size2 <= i)
00730 {
00731 GSL_ERROR ("the input's index is out of bounds", GSL_EINVAL);
00732 }
00733
00734 gnn_node_local_get_w (node);
00735 gnn_node_local_get_f (node);
00736 wnode = (gnn_weight *) node;
00737
00738
00739 Acol = gsl_matrix_column (wnode->A, i);
00740 Afcol = gsl_matrix_int_column (wnode->Af, i);
00741
00742
00743 gsl_vector_set_all (&(Acol.vector), strength);
00744 gsl_vector_int_set_all (&(Afcol.vector), 0);
00745 gnn_node_local_update (node);
00746
00747 return 0;
00748 }
00749
00750
00751
00752
00753
00754
00755
00756
00757
00758
00759
00760 int
00761 gnn_weight_disconnect_input (gnn_node *node, size_t i)
00762 {
00763 gnn_weight *wnode;
00764 gsl_vector_view Acol;
00765 gsl_vector_int_view Afcol;
00766
00767 assert (node != NULL);
00768
00769
00770 if (i < 0 || wnode->A->size2 <= i)
00771 {
00772 GSL_ERROR ("the input's index is out of bounds", GSL_EINVAL);
00773 }
00774
00775 gnn_node_local_get_w (node);
00776 gnn_node_local_get_f (node);
00777 wnode = (gnn_weight *) node;
00778
00779
00780 Acol = gsl_matrix_column (wnode->A, i);
00781 Afcol = gsl_matrix_int_column (wnode->Af, i);
00782
00783
00784 gsl_vector_set_zero (&(Acol.vector));
00785 gsl_vector_int_set_all (&(Afcol.vector), 1);
00786 gnn_node_local_update (node);
00787
00788 return 0;
00789 }
00790
00791
00792
00793
00794
00795
00796
00797
00798
00799
00800
00801
00802 int
00803 gnn_weight_freeze_connection (gnn_node *node, size_t i, size_t j)
00804 {
00805 gnn_weight *wnode;
00806
00807 assert (node != NULL);
00808
00809 gnn_node_local_get_f (node);
00810 wnode = (gnn_weight *) node;
00811
00812
00813 gsl_matrix_int_set (wnode->Af, j, i, 1);
00814 gnn_node_local_update (node);
00815
00816 return 0;
00817 }
00818
00819
00820
00821
00822
00823
00824
00825
00826
00827
00828
00829
00830 int
00831 gnn_weight_unfreeze_connection (gnn_node *node, size_t i, size_t j)
00832 {
00833 gnn_weight *wnode;
00834
00835 assert (node != NULL);
00836
00837 gnn_node_local_get_f (node);
00838 wnode = (gnn_weight *) node;
00839
00840
00841 gsl_matrix_int_set (wnode->Af, j, i, 0);
00842 gnn_node_local_update (node);
00843
00844 return 0;
00845 }
00846
00847
00848
00849
00850
00851
00852
00853
00854
00855
00856 int
00857 gnn_weight_freeze_all_connections (gnn_node *node)
00858 {
00859 gnn_weight *wnode;
00860
00861 assert (node != NULL);
00862
00863 gnn_node_local_get_f (node);
00864 wnode = (gnn_weight *) node;
00865
00866
00867 gsl_matrix_int_set_all (wnode->Af, 1);
00868 gnn_node_local_update (node);
00869
00870 return 0;
00871 }
00872
00873
00874
00875
00876
00877
00878
00879
00880
00881
00882 int
00883 gnn_weight_unfreeze_all_connections (gnn_node *node)
00884 {
00885 gnn_weight *wnode;
00886
00887 assert (node != NULL);
00888
00889 gnn_node_local_get_f (node);
00890 wnode = (gnn_weight *) node;
00891
00892
00893 gsl_matrix_int_set_all (wnode->Af, 0);
00894 gnn_node_local_update (node);
00895
00896 return 0;
00897 }
00898
00899
00900
00901
00902
00903
00904
00905
00906
00907
00908
00909
00910
00911 int
00912 gnn_weight_connect_bias (gnn_node *node, size_t j, double strength)
00913 {
00914 gnn_weight *wnode;
00915
00916 assert (node != NULL);
00917
00918 gnn_node_local_get_w (node);
00919 gnn_node_local_get_f (node);
00920 wnode = (gnn_weight *) node;
00921
00922
00923 gsl_vector_set (wnode->B, j, strength);
00924 gsl_vector_int_set (wnode->Bf, j, 0);
00925 gnn_node_local_update (node);
00926
00927 return 0;
00928 }
00929
00930
00931
00932
00933
00934
00935
00936
00937
00938
00939
00940
00941 int
00942 gnn_weight_disconnect_bias (gnn_node *node, size_t j)
00943 {
00944 gnn_weight *wnode;
00945
00946 assert (node != NULL);
00947
00948 gnn_node_local_get_w (node);
00949 gnn_node_local_get_f (node);
00950 wnode = (gnn_weight *) node;
00951
00952
00953 gsl_vector_set (wnode->B, j, 0.0);
00954 gsl_vector_int_set (wnode->Bf, j, 1);
00955 gnn_node_local_update (node);
00956
00957 return 0;
00958 }
00959
00960
00961
00962
00963
00964
00965
00966
00967
00968
00969
00970
00971 int
00972 gnn_weight_connect_all_bias (gnn_node *node, double strength)
00973 {
00974 gnn_weight *wnode;
00975
00976 assert (node != NULL);
00977
00978 gnn_node_local_get_w (node);
00979 gnn_node_local_get_f (node);
00980 wnode = (gnn_weight *) node;
00981
00982
00983 gsl_vector_set_all (wnode->B, strength);
00984 gsl_vector_int_set_zero (wnode->Bf);
00985 gnn_node_local_update (node);
00986
00987 return 0;
00988 }
00989
00990
00991
00992
00993
00994
00995
00996
00997
00998
00999 int
01000 gnn_weight_disconnect_all_bias (gnn_node *node)
01001 {
01002 gnn_weight *wnode;
01003
01004 assert (node != NULL);
01005
01006 gnn_node_local_get_w (node);
01007 gnn_node_local_get_f (node);
01008 wnode = (gnn_weight *) node;
01009
01010
01011 gsl_vector_set_all (wnode->B, 0.0);
01012 gsl_vector_int_set_all (wnode->Bf, 1);
01013 gnn_node_local_update (node);
01014
01015 return 0;
01016 }
01017
01018
01019
01020
01021
01022
01023
01024
01025
01026
01027
01028 int
01029 gnn_weight_freeze_bias (gnn_node *node, size_t j)
01030 {
01031 gnn_weight *wnode;
01032
01033 assert (node != NULL);
01034
01035 gnn_node_local_get_f (node);
01036 wnode = (gnn_weight *) node;
01037
01038
01039 gsl_vector_int_set (wnode->Bf, j, 1);
01040 gnn_node_local_update (node);
01041
01042 return 0;
01043 }
01044
01045
01046
01047
01048
01049
01050
01051
01052
01053
01054
01055 int
01056 gnn_weight_unfreeze_bias (gnn_node *node, size_t j)
01057 {
01058 gnn_weight *wnode;
01059
01060 assert (node != NULL);
01061
01062 gnn_node_local_get_f (node);
01063 wnode = (gnn_weight *) node;
01064
01065
01066 gsl_vector_int_set (wnode->Bf, j, 0);
01067 gnn_node_local_update (node);
01068
01069 return 0;
01070 }
01071
01072
01073
01074
01075
01076
01077
01078
01079
01080
01081 int
01082 gnn_weight_freeze_all_biases (gnn_node *node)
01083 {
01084 gnn_weight *wnode;
01085
01086 assert (node != NULL);
01087
01088 gnn_node_local_get_f (node);
01089 wnode = (gnn_weight *) node;
01090
01091
01092 gsl_vector_int_set_all (wnode->Bf, 1);
01093 gnn_node_local_update (node);
01094
01095 return 0;
01096 }
01097
01098
01099
01100
01101
01102
01103
01104
01105
01106
01107 int
01108 gnn_weight_unfreeze_all_biases (gnn_node *node)
01109 {
01110 gnn_weight *wnode;
01111
01112 assert (node != NULL);
01113
01114 gnn_node_local_get_f (node);
01115 wnode = (gnn_weight *) node;
01116
01117
01118 gsl_vector_int_set_all (wnode->Bf, 0);
01119 gnn_node_local_update (node);
01120
01121 return 0;
01122 }
01123
01124
01125
01126