diff --git a/source/skeleton.cpp b/source/skeleton.cpp index 2785757064b07d4c4b9dfb73135f9d1e0f854c11..c7919b23fdf5cdfc38057c5e776ffe25066be00d 100644 --- a/source/skeleton.cpp +++ b/source/skeleton.cpp @@ -432,6 +432,22 @@ protected: } }; +struct key_action_tp +{ + ChangeType type; + key_tp coord; + + void forward(std::vector<LogicalMatrix*>& adjs) + { + uint32_t listID = std::get<0>(coord); + uint32_t from = std::get<1>(coord); + uint32_t to = std::get<2>(coord); + + (*adjs[listID])(from, to) = type; + } +}; + + struct IData { @@ -465,11 +481,14 @@ protected: */ std::vector<key_tp> ord_V_G; + std::vector<key_action_tp> unord_V_G; + /** * store the current adjacency graph and the best MAP model found during the search * > update only when better model is found */ std::vector<LogicalMatrix*> G; + std::vector<LogicalMatrix*> G_sampled; double probabilityMaxOptim = -99999999999.0; std::vector<LogicalMatrix*> G_optim; @@ -502,9 +521,21 @@ public: void SK_SS() { this->partialSK_SS(); + uint32_t sampled_subgraphs = this->unord_V_G.size(); + IntegerVector indices = Rcpp::sample(sampled_subgraphs, this->settings.I); + indices.insert(indices.begin(), 1); + std::sort(indices.begin(), indices.end()); + for(uint32_t i = 0; i < this->settings.I; i++) { Rcout << "SK-SS, Iteration: " << i << "\n"; + + for(uint32_t j = indices[i]; j < indices[i+1]; j++) + { + this->unord_V_G[j-1].forward(this->G_sampled); + } + + this->emptyCurrCountsAndG(); this->SK_SS_phase2_iter(); } @@ -529,7 +560,7 @@ public: protected: void SK_SS_phase2_iter() { - this->V_G.sampleGraph(this->G); // sample adjacency matrix, e.g., sample G_0 + //this->V_G.sampleGraph(this->G); // sample adjacency matrix, e.g., sample G_0 NumericVector twoScores = this->callScoreFull(); IntegerVector ls = Rcpp::sample(2, 2); @@ -707,7 +738,8 @@ protected: { const IntegerMatrix& sk = (*SK_list)[i]; - *(this->G)[i] = LogicalMatrix(sk.nrow(), sk.ncol()); + //*(this->G)[i] = LogicalMatrix(sk.nrow(), sk.ncol()); + *(this->G)[i] = *(this->G_sampled)[i]; Rcpp::rownames(*(this->G)[i]) = Rcpp::rownames(sk); Rcpp::colnames(*(this->G)[i]) = Rcpp::colnames(sk); } @@ -734,6 +766,10 @@ protected: (this->G_optim).push_back(new LogicalMatrix(sk.nrow(), sk.ncol())); Rcpp::rownames(*(this->G_optim)[i]) = Rcpp::rownames(sk); Rcpp::colnames(*(this->G_optim)[i]) = Rcpp::colnames(sk); + + (this->G_sampled).push_back(new LogicalMatrix(sk.nrow(), sk.ncol())); + Rcpp::rownames(*(this->G_sampled)[i]) = Rcpp::rownames(sk); + Rcpp::colnames(*(this->G_sampled)[i]) = Rcpp::colnames(sk); } } @@ -822,7 +858,11 @@ protected: } // Try to add word to Trie - this->V_G.appendChainAddOnly(this->ord_V_G); + //this->V_G.appendChainAddOnly(this->ord_V_G); + key_action_tp to_push; + to_push.type = action; + to_push.coord = kt; + this->unord_V_G.push_back(to_push); } void CheckVisitedGraphOptimum(NumericVector& scoreChanges)