Main Page | Directories | File List | File Members | Related Pages

Node.c

Go to the documentation of this file.
00001 
00009 #include "party.h"
00010 
00011 
/**
    Weighted column means of a column-major n x q matrix.
    For every column j the function stores
    sum_i weights[i] * y[j * n + i] / sweights in ans[j].
    *\param y values of the n x q matrix, stored column by column
    *\param n number of rows of y (length of weights)
    *\param q number of columns of y (length of ans)
    *\param weights one weight per row
    *\param sweights divisor applied to every weighted column sum
    *\param ans output vector of length q, overwritten
*/
void C_prediction(const double *y, int n, int q, const double *weights, 
                  const double sweights, double *ans) {

    int col, row;

    for (col = 0; col < q; col++) {
        /* offset of the first element of the current column */
        const int start = col * n;
        double total = 0.0;

        for (row = 0; row < n; row++) {
            total += weights[row] * y[start + row];
        }

        ans[col] = total / sweights;
    }
}
00035 
00036 
00048 void C_Node(SEXP node, SEXP learnsample, SEXP weights, 
00049             SEXP fitmem, SEXP controls, int TERMINAL) {
00050     
00051     int nobs, ninputs, jselect, yORDERED, q, j;
00052     double mincriterion, sweights, *dprediction;
00053     double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00054     double *standstat, *splitstat;
00055     SEXP responses, inputs, y, x, expcovinf, thisweights, linexpcov;
00056     SEXP varctrl, splitctrl, gtctrl, tgctrl, split, joint;
00057     
00058     nobs = get_nobs(learnsample);
00059     ninputs = get_ninputs(learnsample);
00060     varctrl = get_varctrl(controls);
00061     splitctrl = get_splitctrl(controls);
00062     gtctrl = get_gtctrl(controls);
00063     tgctrl = get_tgctrl(controls);
00064     mincriterion = get_mincriterion(gtctrl);
00065     responses = GET_SLOT(learnsample, PL2_responsesSym);
00066     inputs = GET_SLOT(learnsample, PL2_inputsSym);
00067     yORDERED = is_ordinal(responses, 1); 
00068     y = get_transformation(responses, 1);
00069     q = ncol(y);
00070     joint = GET_SLOT(responses, PL2_jointtransfSym);
00071 
00072     /* <FIXME> we compute C_GlobalTest even for TERMINAL nodes! </FIXME> */
00073 
00074     /* compute the test statistics and the node criteria for each input */        
00075     C_GlobalTest(learnsample, weights, fitmem, varctrl,
00076                  gtctrl, get_minsplit(splitctrl), 
00077                  REAL(S3get_teststat(node)), REAL(S3get_criterion(node)));
00078     
00079     /* sum of weights: C_GlobalTest did nothing if sweights < mincriterion */
00080     sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym), 
00081                              PL2_sumweightsSym))[0];
00082 
00083     /* compute the prediction of this node */
00084     dprediction = REAL(S3get_prediction(node));
00085 
00086     /* <FIXME> feed raw numeric values OR dummy encoded factors as y 
00087        Problem: what happens for survival times ? */
00088     C_prediction(REAL(joint), nobs, ncol(joint), REAL(weights), 
00089                      sweights, dprediction);
00090     /* </FIXME> */
00091 
00092     teststat = REAL(S3get_teststat(node));
00093     pvalue = REAL(S3get_criterion(node));
00094 
00095     /* try the two out of ninputs best inputs variables */
00096     /* <FIXME> be more flexible and add a parameter controlling
00097                the number of inputs tried </FIXME> */
00098     for (j = 0; j < 2; j++) {
00099 
00100         smax = C_max(pvalue, ninputs);
00101         REAL(S3get_maxcriterion(node))[0] = smax;
00102     
00103         /* if the global null hypothesis was rejected */
00104         if (smax > mincriterion && !TERMINAL) {
00105 
00106             /* the input variable with largest association to the response */
00107             jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00108 
00109             /* get the raw numeric values or the codings of a factor */
00110             x = get_variable(inputs, jselect);
00111             if (has_missings(inputs, jselect)) {
00112                 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect), 
00113                                     PL2_expcovinfSym);
00114                 thisweights = get_weights(fitmem, jselect);
00115             } else {
00116                 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00117                thisweights = weights;
00118             }
00119 
00120             /* <FIXME> handle ordered factors separatly??? </FIXME> */
00121             if (!is_nominal(inputs, jselect)) {
00122             
00123                 /* search for a split in a ordered variable x */
00124                 split = S3get_primarysplit(node);
00125                 
00126                 /* check if the n-vector of splitstatistics 
00127                    should be returned for each primary split */
00128                 if (get_savesplitstats(tgctrl)) {
00129                     C_init_orderedsplit(split, nobs);
00130                     splitstat = REAL(S3get_splitstatistics(split));
00131                 } else {
00132                     C_init_orderedsplit(split, 0);
00133                     splitstat = REAL(get_splitstatistics(fitmem));
00134                 }
00135 
00136                 C_split(REAL(x), 1, REAL(y), q, REAL(weights), nobs,
00137                         INTEGER(get_ordering(inputs, jselect)), 
00138                         REAL(VECTOR_ELT(GET_SLOT(responses, PL2_scoresSym), 0)),
00139                         yORDERED, splitctrl, 
00140                         GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00141                         expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00142                         splitstat);
00143                 S3set_variableID(split, jselect);
00144              } else {
00145            
00146                  /* search of a set of levels (split) in a numeric variable x */
00147                  split = S3get_primarysplit(node);
00148                  
00149                 /* check if the n-vector of splitstatistics 
00150                    should be returned for each primary split */
00151                 if (get_savesplitstats(tgctrl)) {
00152                     C_init_nominalsplit(split, 
00153                         LENGTH(get_levels(inputs, jselect)), 
00154                         nobs);
00155                     splitstat = REAL(S3get_splitstatistics(split));
00156                 } else {
00157                     C_init_nominalsplit(split, 
00158                         LENGTH(get_levels(inputs, jselect)), 
00159                         0);
00160                     splitstat = REAL(get_splitstatistics(fitmem));
00161                 }
00162           
00163                  linexpcov = get_varmemory(fitmem, jselect);
00164                  standstat = Calloc(get_dimension(linexpcov), double);
00165                  C_standardize(REAL(GET_SLOT(linexpcov, 
00166                                              PL2_linearstatisticSym)),
00167                                REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00168                                REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00169                                get_dimension(linexpcov), get_tol(splitctrl), 
00170                                standstat);
00171  
00172                  C_splitcategorical(INTEGER(x), 
00173                                     LENGTH(get_levels(inputs, jselect)), 
00174                                     REAL(y), q, REAL(weights), 
00175                                     nobs, REAL(VECTOR_ELT(GET_SLOT(responses, 
00176                                                PL2_scoresSym), 0)),
00177                                     yORDERED, standstat, splitctrl, 
00178                                     GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00179                                     expcovinf, &cutpoint, 
00180                                     INTEGER(S3get_splitpoint(split)),
00181                                     &maxstat, splitstat);
00182                  Free(standstat);
00183             }
00184             if (maxstat == 0) {
00185                 warning("no admissible split found\n");
00186             
00187                 if (j == 1) {          
00188                     S3set_nodeterminal(node);
00189                 } else {
00190                     pvalue[jselect - 1] = 0.0;
00191                 }
00192             } else {
00193                 S3set_variableID(split, jselect);
00194                 break;
00195             }
00196         } else {
00197             S3set_nodeterminal(node);
00198             break;
00199         }
00200     }
00201 }       
00202 
00203 
00212 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00213             
00214      SEXP ans;
00215      
00216      PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00217      C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample), 
00218                  get_maxsurrogate(get_splitctrl(controls)),
00219                  ncol(GET_SLOT(GET_SLOT(learnsample, PL2_responsesSym), 
00220                       PL2_jointtransfSym)));
00221 
00222      C_Node(ans, learnsample, weights, fitmem, controls, 0);
00223      UNPROTECT(1);
00224      return(ans);
00225 }

Generated on Thu Jun 23 14:31:48 2005 for party by  doxygen 1.4.2