00001
00009 #include "party.h"
00010
00011
/**
    Weighted column sums of a column-major matrix, each divided by a
    common scalar.

    @param y         matrix values stored column by column; column j
                     occupies y[j * n] ... y[j * n + n - 1]
    @param n         number of rows of y (and length of weights)
    @param q         number of columns of y (and length of ans)
    @param weights   one weight per row
    @param sweights  divisor applied to every weighted column sum
                     (the caller in this file passes the stored sum of weights)
    @param ans       output of length q; ans[j] is overwritten with
                     sum_i weights[i] * y[j * n + i] / sweights
*/
void C_prediction(const double *y, int n, int q, const double *weights,
                  const double sweights, double *ans) {

    int col, row;

    for (col = 0; col < q; col++) {
        const double *ycol = y + col * n;   /* start of column `col` */
        double *out = ans + col;            /* result slot for this column */

        /* accumulate directly in the output slot, row by row */
        *out = 0.0;
        for (row = 0; row < n; row++) {
            *out += weights[row] * ycol[row];
        }
        *out /= sweights;
    }
}
00035
00036
00048 void C_Node(SEXP node, SEXP learnsample, SEXP weights,
00049 SEXP fitmem, SEXP controls, int TERMINAL) {
00050
00051 int nobs, ninputs, jselect, q, j, k, i;
00052 double mincriterion, sweights, *dprediction;
00053 double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
00054 double *standstat, *splitstat;
00055 SEXP responses, inputs, x, expcovinf, linexpcov;
00056 SEXP varctrl, splitctrl, gtctrl, tgctrl, split, testy, predy;
00057 double *dxtransf, *dweights, *thisweights;
00058 int *itable;
00059
00060 nobs = get_nobs(learnsample);
00061 ninputs = get_ninputs(learnsample);
00062 varctrl = get_varctrl(controls);
00063 splitctrl = get_splitctrl(controls);
00064 gtctrl = get_gtctrl(controls);
00065 tgctrl = get_tgctrl(controls);
00066 mincriterion = get_mincriterion(gtctrl);
00067 responses = GET_SLOT(learnsample, PL2_responsesSym);
00068 inputs = GET_SLOT(learnsample, PL2_inputsSym);
00069 testy = get_test_trafo(responses);
00070 predy = get_predict_trafo(responses);
00071 q = ncol(testy);
00072
00073
00074
00075
00076 C_GlobalTest(learnsample, weights, fitmem, varctrl,
00077 gtctrl, get_minsplit(splitctrl),
00078 REAL(S3get_teststat(node)), REAL(S3get_criterion(node)));
00079
00080
00081 sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym),
00082 PL2_sumweightsSym))[0];
00083 REAL(VECTOR_ELT(node, S3_SUMWEIGHTS))[0] = sweights;
00084
00085
00086 dprediction = REAL(S3get_prediction(node));
00087
00088
00089
00090 C_prediction(REAL(predy), nobs, ncol(predy), REAL(weights),
00091 sweights, dprediction);
00092
00093
00094 teststat = REAL(S3get_teststat(node));
00095 pvalue = REAL(S3get_criterion(node));
00096
00097
00098
00099
00100 for (j = 0; j < 2; j++) {
00101
00102 smax = C_max(pvalue, ninputs);
00103 REAL(S3get_maxcriterion(node))[0] = smax;
00104
00105
00106 if (smax > mincriterion && !TERMINAL) {
00107
00108
00109 jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
00110
00111
00112 x = get_variable(inputs, jselect);
00113 if (has_missings(inputs, jselect)) {
00114 expcovinf = GET_SLOT(get_varmemory(fitmem, jselect),
00115 PL2_expcovinfSym);
00116 thisweights = C_tempweights(jselect, weights, fitmem, inputs);
00117 } else {
00118 expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
00119 thisweights = REAL(weights);
00120 }
00121
00122
00123 if (!is_nominal(inputs, jselect)) {
00124
00125
00126 split = S3get_primarysplit(node);
00127
00128
00129
00130 if (get_savesplitstats(tgctrl)) {
00131 C_init_orderedsplit(split, nobs);
00132 splitstat = REAL(S3get_splitstatistics(split));
00133 } else {
00134 C_init_orderedsplit(split, 0);
00135 splitstat = REAL(get_splitstatistics(fitmem));
00136 }
00137
00138 C_split(REAL(x), 1, REAL(testy), q, thisweights, nobs,
00139 INTEGER(get_ordering(inputs, jselect)), splitctrl,
00140 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00141 expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
00142 splitstat);
00143 S3set_variableID(split, jselect);
00144 } else {
00145
00146
00147 split = S3get_primarysplit(node);
00148
00149
00150
00151 if (get_savesplitstats(tgctrl)) {
00152 C_init_nominalsplit(split,
00153 LENGTH(get_levels(inputs, jselect)),
00154 nobs);
00155 splitstat = REAL(S3get_splitstatistics(split));
00156 } else {
00157 C_init_nominalsplit(split,
00158 LENGTH(get_levels(inputs, jselect)),
00159 0);
00160 splitstat = REAL(get_splitstatistics(fitmem));
00161 }
00162
00163 linexpcov = get_varmemory(fitmem, jselect);
00164 standstat = Calloc(get_dimension(linexpcov), double);
00165 C_standardize(REAL(GET_SLOT(linexpcov,
00166 PL2_linearstatisticSym)),
00167 REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
00168 REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
00169 get_dimension(linexpcov), get_tol(splitctrl),
00170 standstat);
00171
00172 C_splitcategorical(INTEGER(x),
00173 LENGTH(get_levels(inputs, jselect)),
00174 REAL(testy), q, thisweights,
00175 nobs, standstat, splitctrl,
00176 GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
00177 expcovinf, &cutpoint,
00178 INTEGER(S3get_splitpoint(split)),
00179 &maxstat, splitstat);
00180
00181
00182
00183
00184
00185 itable = INTEGER(S3get_table(split));
00186 dxtransf = REAL(get_transformation(inputs, jselect));
00187 for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
00188 itable[k] = 0;
00189 for (i = 0; i < nobs; i++) {
00190 if (dxtransf[k * nobs + i] * thisweights[i] > 0) {
00191 itable[k] = 1;
00192 continue;
00193 }
00194 }
00195 }
00196
00197 Free(standstat);
00198 }
00199 if (maxstat == 0) {
00200
00201 if (j == 1) {
00202 S3set_nodeterminal(node);
00203 } else {
00204
00205 pvalue[jselect - 1] = R_NegInf;
00206 }
00207 } else {
00208 S3set_variableID(split, jselect);
00209 break;
00210 }
00211 } else {
00212 S3set_nodeterminal(node);
00213 break;
00214 }
00215 }
00216 }
00217
00218
00227 SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
00228
00229 SEXP ans;
00230
00231 PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
00232 C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample),
00233 get_maxsurrogate(get_splitctrl(controls)),
00234 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
00235
00236 C_Node(ans, learnsample, weights, fitmem, controls, 0);
00237 UNPROTECT(1);
00238 return(ans);
00239 }