00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025 #ifndef _GNN_NODE_H_
00026 #define _GNN_NODE_H_
00027
00028
00029
00030
00031
00032 #include "gnn_pbundle.h"
00033
00034
00035
00036
00037
00038
00039
00040 #define GNN_NODE_SUB_INSTALL(_node_, _size_, _status_) \
00041 { \
00042 va_list args; \
00043 va_start(args, (_size_)); \
00044 (_status_) = gnn_node_sub_install ((_node_), (_size_), args); \
00045 va_end(args); \
00046 }
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060 typedef struct _gnn_node gnn_node;
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079 typedef int
00080 (*gnn_node_f) (gnn_node *node,
00081 const gsl_vector *x,
00082 const gsl_vector *w,
00083 gsl_vector *y);
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103 typedef int
00104 (*gnn_node_df) (gnn_node *node,
00105 const gsl_vector *x,
00106 const gsl_vector *w,
00107 const gsl_vector *dy,
00108 gsl_vector *dxw);
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123
00124
00125 typedef void
00126 (*gnn_node_destructor) (gnn_node *layer);
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160 struct _gnn_node
00161 {
00162 const char *type;
00163 int n;
00164 int m;
00165
00166 gnn_phandle *ph;
00167 gnn_pbundle *pb;
00168
00169 gsl_vector x;
00170 gsl_vector dy;
00171
00172 gnn_node_destructor destroy;
00173 gnn_node_f f;
00174 gnn_node_df dw;
00175
00176 gnn_node_df dx;
00177
00178
00179 gnn_node *super;
00180 int nsub;
00181 gnn_node **sub;
00182 };
00183
00184
00185
00186
00187
00188
00189
00190 #ifndef _GNN_NODE_VECTOR_H_
00191 #include "gnn_node_vector.h"
00192 #endif
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204 int
00205 gnn_node_is_root (gnn_node *node);
00206
00207 const char *
00208 gnn_node_get_type_name (gnn_node *node);
00209
00210 int
00211 gnn_node_input_get_size (gnn_node *node);
00212
00213 int
00214 gnn_node_output_get_size (gnn_node *node);
00215
00216 int
00217 gnn_node_destroy (gnn_node *node);
00218
00219 int
00220 gnn_node_evaluate_init (gnn_node *node);
00221
00222 int
00223 gnn_node_evaluate_f (gnn_node *node,
00224 const gsl_vector *x,
00225 gsl_vector *y);
00226
00227 int
00228 gnn_node_evaluate_dx (gnn_node *node,
00229 const gsl_vector *dy,
00230 gsl_vector *dx);
00231
00232 int
00233 gnn_node_evaluate_dw (gnn_node *node,
00234 gsl_vector *dw);
00235
00236
00237
00238
00239
00240
00241
00242 int
00243 gnn_node_param_get_size (gnn_node *node);
00244
00245 int
00246 gnn_node_param_get (gnn_node *node, gsl_vector *w);
00247
00248 int
00249 gnn_node_param_set (gnn_node *node, const gsl_vector *w);
00250
00251 int
00252 gnn_node_param_freeze_flags_get (gnn_node *node, gsl_vector_int *f);
00253
00254 int
00255 gnn_node_param_freeze_flags_set (gnn_node *node, const gsl_vector_int *f);
00256
00257 int
00258 gnn_node_param_freeze (gnn_node *node, int i);
00259
00260 int
00261 gnn_node_param_unfreeze (gnn_node *node, int i);
00262
00263 int
00264 gnn_node_param_is_frozen (gnn_node *node, int i);
00265
00266 int
00267 gnn_node_param_are_frozen (gnn_node *node);
00268
00269 int
00270 gnn_node_param_freeze_all (gnn_node *node);
00271
00272 int
00273 gnn_node_param_unfreeze_all (gnn_node *node);
00274
00275 int
00276 gnn_node_param_share (const gnn_node *node, gnn_node *client);
00277
00278
00279
00280
00281
00282
00283
00284 int
00285 gnn_node_init (gnn_node *node,
00286 const char *type,
00287 gnn_node_f f,
00288 gnn_node_df dx,
00289 gnn_node_df dw,
00290 gnn_node_destructor dest);
00291
00292 int
00293 gnn_node_set_sizes (gnn_node *node, int n, int m, int l);
00294
00295 gsl_vector *
00296 gnn_node_local_get_w (gnn_node *node);
00297
00298 gsl_vector *
00299 gnn_node_local_get_dw (gnn_node *node);
00300
00301 gsl_vector_int *
00302 gnn_node_local_get_f (gnn_node *node);
00303
00304 int
00305 gnn_node_local_update (gnn_node *node);
00306
00307 int
00308 gnn_node_eval_f (gnn_node *node, const gsl_vector *x, gsl_vector *y);
00309
00310 int
00311 gnn_node_eval_dx (gnn_node *node, const gsl_vector *dy, gsl_vector *dx);
00312
00313 int
00314 gnn_node_eval_dw (gnn_node *node);
00315
00316
00317
00318
00319
00320
00321
00322 int
00323 gnn_node_sub_install_specific (gnn_node *node, int nsub, ...);
00324
00325 int
00326 gnn_node_sub_install (gnn_node *node, int nsub, va_list subs);
00327
00328 int
00329 gnn_node_sub_install_node_vector (gnn_node *node, gnn_node_vector *vector);
00330
00331 int
00332 gnn_node_sub_get_number (gnn_node *node);
00333
00334 gnn_node *
00335 gnn_node_sub_get_node_at (gnn_node *node, int i);
00336
00337 gnn_pbundle *
00338 gnn_node_sub_search_params (gnn_node *node, const char *type);
00339
00340
00341
00342 #endif