// ========================================================= //
//                                                           //
//   File      : SyncRoot.cxx                                //
//   Purpose   : Sync roots of trees                         //
//                                                           //
//   Coded by Ralf Westram (coder@reallysoft.de) in May 20   //
//   http://www.arb-home.de/                                 //
//                                                           //
// ========================================================= //

#include "CT_part.hxx"
#include "SyncRoot.hxx"

#include <TreeRead.h>

#include <vector>

using namespace std;

void RootSynchronizer::beginDeconstructionPhase() {
    arb_assert(!deconstructionPhase());

    get_species_names(species_names);
    speciesSpacePtr = new SpeciesSpace(species_names);
    treePartsPtr    = new TreeParts(*speciesSpacePtr, *this);

    arb_assert(deconstructionPhase());
}

/*! deconstruct tree 'treeIdx' into its PARTs (does nothing if already deconstructed).
 *
 * Enters the deconstruction phase if it is not active yet.
 *
 * @param treeIdx          index of the tree to deconstruct
 * @param provideProgress  forwarded to DeconstructedTree::deconstruct_weighted()
 * @return NULp on success, error message otherwise.
 *         On error, no (partially) deconstructed tree is kept for 'treeIdx'.
 */
GB_ERROR RootSynchronizer::deconstructTree(int treeIdx, bool provideProgress) {
    if (!deconstructionPhase()) beginDeconstructionPhase();

    GB_ERROR error = NULp;
    if (!valid_tree_index(treeIdx)) {
        error = GBS_global_string("invalid tree index %i (valid 0-%i)", treeIdx, int(get_tree_count())-1);
    }
    else {
        if (dtree.size() <= size_t(treeIdx)) dtree.resize(get_tree_count());
        arb_assert(dtree.size()>size_t(treeIdx));

        if (dtree[treeIdx].isNull()) {
            const SizeAwareTree *tree = get_tree(treeIdx);
            if (!tree) {
                error = GBS_global_string("tree at index #%i vanished (internal error)", treeIdx);
            }
            else {
                dtree[treeIdx] = new DeconstructedTree(*speciesSpacePtr);
                error          = dtree[treeIdx]->deconstruct_weighted(tree, treePartsPtr->get_tree_PART(treeIdx), get_tree_info(treeIdx).species_count(), 1.0, provideProgress, speciesSpacePtr->get_allSpecies(), DMODE_ROOTSYNC);
                if (error) {
                    // drop the incomplete result: otherwise the isNull() check above would
                    // treat this tree as deconstructed on the next call.
                    dtree[treeIdx].setNull();
                }
                else {
                    dtree[treeIdx]->start_sorted_retrieval();
                }
            }
        }
    }
    return error;
}

inline void showDeconstructingSubtitle(arb_progress& progress, int treeNr) {
    // show which tree is being deconstructed (index is displayed 1-based)
    const int   displayedNr = treeNr+1;
    const char *subtitle    = GBS_global_string("Deconstructing tree #%i", displayedNr);
    progress.subtitle(subtitle);
}

/*! find the edge in tree 'inTree' which best matches the root edge of tree 'accordingToTree'.
 *
 * Both trees get deconstructed first (if not done yet).
 *
 * @param inTree           index of the tree to search in
 * @param accordingToTree  index of the tree whose root edge is used as reference
 * @param best_dist        receives the distance of the best matching edge (INT_MAX if none was found)
 * @param provideProgress  show progress (50% deconstruction + 50% search)
 * @return error or node representing the best matching edge in 'inTree'
 */
ErrorOrSizeAwareTreePtr RootSynchronizer::find_best_root_candidate(int inTree, int accordingToTree, int& best_dist, bool provideProgress) {
    GB_ERROR             error  = NULp;
    const SizeAwareTree *result = NULp;

    best_dist = INT_MAX;

    // only trees not yet deconstructed count as deconstruction steps:
    const bool deconInTree = !has_deconstructed_tree(inTree);
    const bool deconAcTree = !has_deconstructed_tree(accordingToTree);

    SmartPtr<arb_progress> progress;
    if (provideProgress) progress = new arb_progress(2UL); // 50% deconstruct + 50% search

    // deconstruct involved trees:
    {
        SmartPtr<arb_progress> decon_progress;
        if (provideProgress) {
            const size_t steps = deconInTree + deconAcTree;
            if (steps) decon_progress = new arb_progress(steps);
        }
        if (!error) {
            const bool update = deconAcTree && decon_progress.isSet();
            if (update) showDeconstructingSubtitle(*decon_progress, accordingToTree); // same progress level as the step below
            error = deconstructTree(accordingToTree, provideProgress);
            if (update) decon_progress->inc_and_check_user_abort(error);
        }
        if (!error) {
            const bool update = deconInTree && decon_progress.isSet();
            if (update) showDeconstructingSubtitle(*decon_progress, inTree);
            error = deconstructTree(inTree, provideProgress);
            if (update) decon_progress->inc_and_check_user_abort(error);
        }

        if (provideProgress) progress->inc_and_check_user_abort(error);
    }

    if (!error) {
        if (provideProgress) progress->subtitle("Searching best matching root");

        // the edge above the left son of the root represents the root edge:
        const SizeAwareTree *accordingRoot     = get_tree(accordingToTree);
        const PART          *accordingRootPART = dtree[accordingToTree]->find_part(accordingRoot->get_leftson());
        arb_assert(accordingRootPART);

        int best_idx;
        find_best_matching_PART_in(best_dist, best_idx, accordingRootPART, *dtree[inTree], get_tree_PART(accordingToTree), get_tree_PART(inTree), provideProgress);

        if (best_idx == -1) {
            // no existing edge was inspected (e.g. search aborted before the first one) -> never index with -1
            error = "no matching root candidate found";
        }
        else {
            result = DOWNCAST(const SizeAwareTree*, PART_FWD::get_origin(dtree[inTree]->peek_part(best_idx)));
            arb_assert(result);
        }

        if (provideProgress) progress->inc_and_check_user_abort(error);
    }

    if (error && provideProgress) progress->done();

    return ErrorOrSizeAwareTreePtr(error, result);
}

void RootSynchronizer::find_best_matching_PART_in(int& best_dist, int &best_idx, const PART *part, const DeconstructedTree& in, const PART *tree_part, const PART *tree_in, bool provideProgress) {
    // reset result params:
    best_dist = INT_MAX;
    best_idx  = -1;

    SmartPtr<arb_progress> findBestProgress;
    if (provideProgress) {
        findBestProgress = new arb_progress(in.get_part_count());
    }

    for (size_t idx = 0; idx<in.get_part_count(); ++idx) {
        const PART *pin = in.peek_part(idx);
        arb_assert(pin);

        if (represents_existing_edge(pin)) {
            int dist = PART_FWD::calcDistance(part, pin, tree_part, tree_in);
            if (dist<best_dist) {
                best_idx  = idx;
                best_dist = dist;
            }
            else if (best_dist == 0) {
                arb_assert(dist>best_dist); // multiple perfect matches should not occur
            }
        }
        if (provideProgress) {
            ++*findBestProgress;
            if (findBestProgress->aborted()) break;
        }
    }
}

void RootSynchronizer::find_worst_matching_PART_in(int& worst_dist, int &worst_idx, const PART *part, const DeconstructedTree& in, const PART *tree_part, const PART *tree_in) {
    // reset result params:
    worst_dist = INT_MIN;
    worst_idx  = -1;

    for (size_t idx = 0; idx<in.get_part_count(); ++idx) {
        const PART *pin = in.peek_part(idx);
        arb_assert(pin);

        if (!represents_existing_edge(pin)) continue;

        int dist = PART_FWD::calcDistance(part, pin, tree_part, tree_in);
        if (dist>worst_dist) {
            worst_idx  = idx;
            worst_dist = dist;
        }
    }
}

// #define DUMP_AGAIN

/*! search for a Multiroot "better" than 'start' by moving single root nodes to adjacent edges (recursively).
 *
 * A candidate is better if its distanceSum() is smaller than 'best_distSum', or equal to it with a
 * smaller distanceToCenterSum() than 'best_centerDist'.
 *
 * @param start            Multiroot to start from
 * @param best_distSum     distance sum to beat (should be start.distanceSum() or better)
 * @param best_centerDist  center distance sum used as tie breaker for equal distance sums
 * @param movesPerTree     array (start.size() entries): remaining allowed moves per tree.
 *                         Entries are temporarily decremented during the search and restored before returning.
 * @param progress         optional; if it reports aborted, trying further neighbors of the current tree stops
 * @return first better Multiroot found (depth-first), or a null MultirootPtr if none was found
 */
MultirootPtr RootSynchronizer::find_better_multiroot(const Multiroot& start, int best_distSum, int best_centerDist, int *movesPerTree, arb_progress *progress) {
    // best_distSum should be start.distanceSum() or better

    Multiroot    modified(start); // candidate under test: 'start' with node t replaced by a neighbor
    MultirootPtr best;
    const int    nodes = start.size();

    // total number of moves left (before consuming one below):
    int leftMoves = 0;
    for (int t = 0; t<nodes; ++t) {
        leftMoves += movesPerTree[t];
    }

    for (int t = 0; t<nodes && best.isNull(); ++t) {
        if (movesPerTree[t]>0) {
            --movesPerTree[t]; // consume one move of tree t (restored at end of this iteration)

            ConstSizeAwareTreePtr    node = start.get_node(t);
            ConstSizeAwareTreeVector neighbors;

            // store (up to) 4 neighbors nodes (representing adjacent edges):
            {
                if (!node->is_leaf()) { // try branches to both sons
                    neighbors.push_back(node->get_leftson());
                    neighbors.push_back(node->get_rightson());
                }

                ConstSizeAwareTreePtr brother = node->get_brother();
                arb_assert(brother);

                if (node->is_son_of_root()) {
                    if (!brother->is_leaf()) { // try branches to both sons of brother
                        neighbors.push_back(brother->get_leftson());
                        neighbors.push_back(brother->get_rightson());
                    }
                }
                else { // try branches from father to brother and grandpa (or uncle at root)
                    neighbors.push_back(brother);
                    neighbors.push_back(node->get_father());
                }

                arb_assert(neighbors.size()>0);
                arb_assert(neighbors.size()<=4);
            }

            // iterate all neighbors:
            for (ConstSizeAwareTreeVector::const_iterator n = neighbors.begin(); n != neighbors.end() && best.isNull(); ++n) {
                ConstSizeAwareTreePtr next_node = *n;
                modified.replace_node(t, next_node);

                // calc current distance and keep best found Multiroot:
                int mod_distSum = modified.distanceSum(*this);
                if (mod_distSum<=best_distSum) {
                    bool takeModified = mod_distSum<best_distSum;
                    if (!takeModified) {
                        // equal distance sum: prefer the candidate closer to the tree centers
                        arb_assert(mod_distSum == best_distSum);
                        const int mod_centerDist = modified.distanceToCenterSum(*this);
                        if (mod_centerDist<best_centerDist) {
#if defined(DUMP_AGAIN)
                            fprintf(stderr, "- again found mod_distSum=%i (center dist: %i -> %i)\n", mod_distSum, best_centerDist, mod_centerDist);
#endif
                            best_centerDist = mod_centerDist;
                            takeModified    = true;
                        }
                    }
                    if (takeModified) {
                        best_distSum = mod_distSum;
                        best         = new Multiroot(modified);
                    }
                }

                // on abort only the neighbor loop of tree t is left (outer loop over trees continues):
                if (progress && progress->aborted()) {
                    break;
                }

                // recurse only if moves remain after the one consumed above:
                if (leftMoves>1 && best.isNull()) {
                    MultirootPtr recursed = find_better_multiroot(modified, best_distSum, best_centerDist, movesPerTree, progress);
                    if (recursed.isSet()) {
                        int recursed_distSum = recursed->distanceSum(*this);
                        if (recursed_distSum<=best_distSum) {
                            bool takeRecursed = recursed_distSum<best_distSum;
                            if (!takeRecursed) {
                                arb_assert(recursed_distSum == best_distSum);
                                const int rec_centerDist = recursed->distanceToCenterSum(*this);
                                if (rec_centerDist<best_centerDist) {
#if defined(DUMP_AGAIN)
                                    fprintf(stderr, "- again found recursed_distSum=%i (center dist: %i -> %i)\n", recursed_distSum, best_centerDist, rec_centerDist);
#endif
                                    best_centerDist = rec_centerDist;
                                    takeRecursed    = true;
                                }
                            }
                            if (takeRecursed) {
                                best_distSum = recursed_distSum;
                                best         = recursed;
                            }
                        }
                    }
                }
            }
            // NOTE(review): 'modified' still holds the last tried neighbor for tree t here (node t is not
            //               reset to start.get_node(t)), so candidates for later trees start from that
            //               displaced node although movesPerTree[t] gets restored below. Confirm if intended.

            ++movesPerTree[t];
        }
    }
    return best;
}

GB_ERROR RootSynchronizer::deconstruct_all_trees(bool provideProgress) {
    // Deconstruct every tree in index order and stop at the first error.
    // Trees already deconstructed are skipped by deconstructTree().
    SmartPtr<arb_progress> progress;
    if (provideProgress) progress = new arb_progress("Deconstructing trees", get_tree_count());

    const int treeCount = get_tree_count();
    GB_ERROR  error     = NULp;

    int t = 0;
    while (!error && t<treeCount) {
        if (provideProgress) showDeconstructingSubtitle(*progress, t);
        error = deconstructTree(t, provideProgress);
        if (provideProgress) progress->inc_and_check_user_abort(error);
        ++t;
    }

    if (provideProgress && error) progress->done();
    return error;
}

// NOTE(review): DUMP_DEPTH is enabled here, so find_good_roots_for_trees() always reports its search
//               steps on stderr (DUMP_AGAIN is commented out). Comment out if that output is unwanted.
#define DUMP_DEPTH

/*! search a combination of root edges (one per tree) with a small sum of pairwise edge distances.
 *
 * Two candidates are optimized using find_better_multiroot(): the current roots and the innermost edges.
 * In each round the currently best candidate is optimized first (pass 1), then the others (pass 2).
 * Each tree may move its root up to depth+1 times. 'depth' is increased once all candidates were
 * checked without improvement, and decreased again after an improvement was found.
 *
 * @param MAX_DEPTH  search ends when no improvement is found at this depth (values < 0 behave like 0)
 * @param progress   optional; displays current distances and allows the user to abort
 * @return error or the best Multiroot found
 */
ErrorOrMultirootPtr RootSynchronizer::find_good_roots_for_trees(const int MAX_DEPTH, arb_progress *progress) {
    GB_ERROR error = deconstruct_all_trees(false);
    arb_assert(deconstructionPhase());

    if (error) {
        MultirootPtr none;
        return ErrorOrMultirootPtr(error, none);
    }

    int depth = 0;

    ErrorOrMultirootPtr emr = get_current_roots();
    if (!emr.hasError()) {
        const int    CANDIDATES = 2;
        MultirootPtr mr[CANDIDATES];
        int          mr_dist[CANDIDATES];
        int          mr_centerDist[CANDIDATES];

        mr[0] = emr.getValue();
        mr[1] = get_innermost_edges(); // add second, speculative multiroot (at centermost branches)!

        // select best candidate (smaller distance; ties broken by smaller center distance):
        int best_c = -1;
        {
            int best_dist       = INT_MAX;
            int best_centerDist = INT_MAX;

            for (int c = 0; c<CANDIDATES; ++c) {
                arb_assert(mr[c].isSet());
                mr_dist[c]       = mr[c]->distanceSum(*this);
                mr_centerDist[c] = mr[c]->distanceToCenterSum(*this);

                if (mr_dist[c]<best_dist || (mr_dist[c] == best_dist && mr_centerDist[c]<best_centerDist)) {
                    best_c          = c;
                    best_dist       = mr_dist[c];
                    best_centerDist = mr_centerDist[c];
                }
            }
        }
        arb_assert(best_c != -1);

        bool done = false;
        while (!done) {
            if (progress) {
                progress->subtitle(GBS_global_string("distance=%i / center distance=%i", mr_dist[best_c], mr_centerDist[best_c]));
                if (progress->aborted()) {
#if defined(DUMP_DEPTH)
                    fprintf(stderr, "Aborting recursion (user abort)\n");
#endif
                    break;
                }
            }

            int cand_checked = 0;
            for (int pass = 1; pass<=2 && !done; ++pass) { // pass1 = optimize best_c; pass2=optimize rest
                for (int c = 0; c<CANDIDATES && !done; ++c) {
                    bool search = pass == 1 ? (c == best_c) : (c != best_c);
                    if (search) {
                        // allow depth+1 moves per tree (std::vector instead of a non-standard variable length array):
                        const int        nodes = mr[c]->size();
                        std::vector<int> movesPerTree(size_t(nodes), depth+1);

                        MultirootPtr better_mr = find_better_multiroot(*(mr[c]), mr_dist[c], mr_centerDist[c], &movesPerTree[0], progress);
                        ++cand_checked;
                        if (better_mr.isNull()) {
#if defined(DUMP_DEPTH)
                            fprintf(stderr, "Found no better multiroot[%i] at depth=%i (dist=%i; center-dist=%i)\n", c, depth, mr_dist[c], mr_centerDist[c]);
#endif
                            if (cand_checked == CANDIDATES) { // do not increase depth if not all candidates checked yet
                                if (depth >= MAX_DEPTH) { // '>=' (not '==') also terminates for MAX_DEPTH<0
                                    done = true; // no improvement -> done
                                }
                                else {
                                    ++depth; // search deeper
#if defined(DUMP_DEPTH)
                                    fprintf(stderr, "Increasing depth to %i\n", depth);
#endif
                                }
                            }
                        }
                        else {
                            mr[c]            = better_mr;
                            mr_dist[c]       = better_mr->distanceSum(*this);
                            mr_centerDist[c] = better_mr->distanceToCenterSum(*this);

#if defined(DUMP_DEPTH)
                            fprintf(stderr, "Found better multiroot[%i] at depth=%i (dist=%i; center-dist=%i)\n", c, depth, mr_dist[c], mr_centerDist[c]);
#endif
                            if (c != best_c) {
                                if (mr_dist[c]<mr_dist[best_c] || (mr_dist[c] == mr_dist[best_c] && mr_centerDist[c]<mr_centerDist[best_c])) {
                                    best_c = c;
                                }
                            }

                            // decrement depth again after better root-combi was found:
                            if (depth>0) --depth;
#if defined(DUMP_DEPTH)
                            fprintf(stderr, "[continuing with depth=%i]\n", depth);
#endif
                        }
                    }
                }
            }
        }

        return ErrorOrMultirootPtr(NULp, mr[best_c]);
    }
    return emr;
}

ErrorOrMultirootPtr RootSynchronizer::get_current_roots() const {
    // Build a Multiroot from this synchronizer (used as the "current roots" candidate).
    // Fails if fewer than two trees are present.
    if (get_tree_count()>=2) {
        MultirootPtr current = new Multiroot(*this);
        return ErrorOrMultirootPtr(NULp, current);
    }

    MultirootPtr none;
    return ErrorOrMultirootPtr("Need at least two trees", none);
}

MultirootPtr RootSynchronizer::get_innermost_edges() const {
    // Build a Multiroot whose node for each tree is the origin of the PART
    // reported by find_innermost_part() of that tree's deconstruction.
    arb_assert(allTreesDeconstructed());

    MultirootPtr innermost = new Multiroot(*this);

    const size_t treeCount = get_tree_count();
    for (size_t t = 0; t<treeCount; ++t) {
        const PART *centerPart = dtree[t]->find_innermost_part();
        arb_assert(centerPart);

        innermost->replace_node(t, DOWNCAST(const SizeAwareTree*, PART_FWD::get_origin(centerPart)));
    }
    return innermost;
}

int RootSynchronizer::calcEdgeDistance(int i1, const SizeAwareTree *n1, int i2, const SizeAwareTree *n2) const {
    arb_assert(deconstructionPhase());

    arb_assert(valid_tree_index(i1));
    arb_assert(valid_tree_index(i2));

    arb_assert(!n1->is_root_node());
    arb_assert(!n2->is_root_node());

    const PART *p1 = dtree[i1]->find_part(n1);
    const PART *p2 = dtree[i2]->find_part(n2);

    arb_assert(p1);
    arb_assert(p2);

    const PART *t1 = get_tree_PART(i1);
    const PART *t2 = get_tree_PART(i2);

    return PART_FWD::calcDistance(p1, p2, t1, t2);
}

int RootSynchronizer::calcTreeDistance(int i1, int i2) const {
    const PART *t1 = get_tree_PART(i1);
    const PART *t2 = get_tree_PART(i2);

    return PART_FWD::calcDistance(t1, t2, t1, t2);
}

int RootSynchronizer::minDistanceSum() const {
    // Sum of calcTreeDistance() over all unordered pairs of distinct trees.
    const size_t treeCount = get_tree_count();

    int total = 0;
    for (size_t i = 1; i<treeCount; ++i) {
        for (size_t j = 0; j<i; ++j) total += calcTreeDistance(i, j);
    }
    return total;
}

int Multiroot::lazy_eval_distance(const RootSynchronizer& rsync, int i, int j) const {
    // Return the cached edge distance between nodes i and j; calculate and cache it
    // via rsync.calcEdgeDistance() if it is still UNKNOWN_DISTANCE.
    const int cached = distance.get(i, j);
    if (cached != UNKNOWN_DISTANCE) {
        arb_assert(cached >= 0); // distance should be up-to-date now!
        return cached;
    }

    const int computed = rsync.calcEdgeDistance(i, get_node(i), j, get_node(j));
    distance.set(i, j, computed);
    arb_assert(computed >= 0); // distance should be up-to-date now!
    return computed;
}

int Multiroot::distanceSum(const RootSynchronizer& rsync) const {
    // Sum of (lazily evaluated) distances over all unordered pairs of nodes.
    arb_assert(rsync.deconstructionPhase());

    const int count = size();
    int       total = 0;
    for (int a = 1; a<count; ++a) {
        for (int b = 0; b<a; ++b) total += lazy_eval_distance(rsync, a, b);
    }
    return total;
}

int Multiroot::distanceToCenterSum(const RootSynchronizer& rsync) const {
    int sum = 0;
    for (int i = 0; i<size(); ++i) {
        const PART *part  = rsync.get_edge_PART(i, get_node(i));
        sum              += part->distance_to_tree_center();
    }
    return sum;
}


int Multiroot::singleTreeDistanceSum(const RootSynchronizer& rsync, int idx) {
    // Sum of (lazily evaluated) distances from node 'idx' to all other nodes.
    arb_assert(idx>=0 && idx<size());

    const int count = size();
    int       total = 0;
    for (int other = 0; other<count; ++other) {
        if (other == idx) continue;
        total += lazy_eval_distance(rsync, other, idx);
    }
    return total;
}

/*! replace the node (edge) selected for tree 'idx' and invalidate all cached distances involving it.
 * @param idx      tree index (0 <= idx < size())
 * @param newNode  new node; must not be NULp
 */
void Multiroot::replace_node(int idx, ConstSizeAwareTreePtr newNode) {
    arb_assert(newNode);               // missing node
    arb_assert(idx>=0 && idx<size());  // check lower bound as well (like singleTreeDistanceSum does)

    node[idx] = newNode;
    // invalidate distances affected by replaced node:
    for (int i = 0; i<size(); ++i) {
        if (i != idx) {
            distance.set(i, idx, UNKNOWN_DISTANCE);
        }
    }
}

