// =========================================================== //
//                                                             //
//   File      : NT_taxonomy.cxx                               //
//   Purpose   : Compare two trees by taxonomy                 //
//                                                             //
//   Coded by Ralf Westram (coder@reallysoft.de) in May 2015   //
//   http://www.arb-home.de/                                   //
//                                                             //
// =========================================================== //

#include "ad_trees.h"
#include "NT_local.h"

#include <aw_window.hxx>
#include <aw_root.hxx>
#include <aw_awar.hxx>
#include <aw_msg.hxx>

#include <sel_boxes.hxx>
#include <TreeCallbacks.hxx>
#include <TreeAdmin.h>

#include <arb_progress.h>
#include <arb_global_defs.h>
#include <item_sel_list.h>

#include <set>
#include <map>

#define TREE_COMPARE_PREFIX     "ad_tree/compare/"

#define AWAR_TREE_COMPARE_ACTION         TREE_COMPARE_PREFIX "action"
#define AWAR_TREE_COMPARE_MIN_TAX_LEVELS TREE_COMPARE_PREFIX "taxdist"
#define AWAR_TREE_COMPARE_WRITE_FIELD    TREE_COMPARE_PREFIX "field"

// Operation applied to the mark flag of each targeted species.
// The value is stored in the int-awar AWAR_TREE_COMPARE_ACTION and is also
// passed unchanged to NT_mark_all_cb, so the numeric values must not be changed.
enum Action { // uses same values as NT_mark_all_cb; see ../SL/TREEDISP/TreeCallbacks.cxx@MARK_MODE_LOWER_BITS
    UNMARK = 0, // clear the mark flag
    MARK   = 1, // set the mark flag
    INVERT = 2, // toggle the mark flag
};

// Selects which species mark_action() operates on
// (one button per value is created in NT_create_compare_taxonomy_window).
enum Target {
    ALL,           // all species (delegated to NT_mark_all_cb)
    TAX,           // species occurring in both trees whose taxonomy difference exceeds the configured level
    COMMON,        // species occurring in both trees
    MISSING_LEFT,  // species occurring in the right tree, but not in the left tree
    MISSING_RIGHT, // species occurring in the left tree, but not in the right tree
};

static TreeNode *findParentGroup(TreeNode *node) { // @@@ DRY vs/using TreeNode::find_parent_with_groupInfo?
    // Returns the nearest ancestor of 'node' that is the root-node of a normal group.
    // 'node' itself is never returned. Result is NULp if no ancestor is a group.

    for (TreeNode *ancestor = node->father; ancestor; ancestor = ancestor->father) {
        if (ancestor->is_normal_group()) {
            return ancestor;
        }
    }
    return NULp;
}

static int countTaxLevel(TreeNode *node) { // @@@ DRY vs TreeNode::calc_clade_level????
    // Returns the number of nested groups 'node' is located in.
    // If 'node' is an inner node and itself the root-node of a group, that group is counted as well.

    int levels = 0;
    if (!node->is_leaf()) {
        levels = node->is_normal_group(); // count 'node' itself (only if it is a group)
    }

    // count all groups above 'node'
    for (TreeNode *group = findParentGroup(node); group; group = findParentGroup(group)) {
        levels++;
    }
    return levels;
}

static int calcTaxDifference(TreeNode *g1, TreeNode *g2) {
    // returns difference in taxonomy-levels
    //
    // difference defined such that:
    //   diff("/A/B", "/A/B") == 0
    //   diff("/A/B", "/A"  ) == 1
    //   diff("/A/B", "/A/C") == 2
    //   diff("/A/B/C", "/A/D/E") == 4
    //   diff("/A/B/C", "/A/D/C") == 4
    //   diff("/A/B/C", "/A/C") == 3
    //
    // In other words: if both taxonomies (=chains of nested group names, starting at the
    // topmost group) share a common prefix of 'common' levels, the result is
    //
    //   depth(g1) + depth(g2) - 2*common
    //
    // A group only counts as 'common' if its name AND the names of all groups above it match.
    //
    // Note: this used to be calculated by a three-way recursion over all pairs of ancestor groups,
    // which yields the same value but needs exponential time for deeply nested groups.
    // The calculation below only walks each chain of ancestors a few times.

    nt_assert(g1->is_normal_group() && g2->is_normal_group()); // has to be called with root-nodes of groups!

    // determine nesting depth of both groups (topmost group has depth 1)
    int depth1 = 0;
    for (TreeNode *g = g1; g; g = findParentGroup(g)) ++depth1;
    int depth2 = 0;
    for (TreeNode *g = g2; g; g = findParentGroup(g)) ++depth2;

    // levels below the depth of the shallower taxonomy can never match -> each of them differs by 1
    int taxdiff = 0;
    while (depth1>depth2) { g1 = findParentGroup(g1); --depth1; ++taxdiff; }
    while (depth2>depth1) { g2 = findParentGroup(g2); --depth2; ++taxdiff; }

    // now g1 and g2 are at same depth -> walk upwards in parallel.
    // The last mismatch seen is the topmost one; it terminates the common prefix.
    int common = depth1;
    for (int depth = depth1; depth>0; --depth) {
        if (strcmp(g1->name, g2->name) != 0) {
            common = depth-1;
        }
        g1 = findParentGroup(g1);
        g2 = findParentGroup(g2);
    }

    // each level (at same depth) that is not part of the common prefix, differs in both taxonomies
    taxdiff += 2*(depth1-common);

    return taxdiff;
}

class SpeciesInTwoTrees {
    TreeNode *tree1;
    TreeNode *tree2;

public:
    SpeciesInTwoTrees()
        : tree1(NULp),
          tree2(NULp)
    {}

    void setSpecies(TreeNode *species, bool first) {
        nt_assert(species->is_leaf());
        if (first) {
            nt_assert(!tree1);
            tree1 = species;
        }
        else {
            nt_assert(!tree2);
            tree2 = species;
        }
    }

    bool occursInBothTrees() const { return tree1 && tree2; }
    int calcTaxDiffLevel() const {
        TreeNode *parent_group1 = findParentGroup(tree1);
        TreeNode *parent_group2 = findParentGroup(tree2);

        int taxdiff = 0;

        if (parent_group1) {
            if (parent_group2) {
                taxdiff = calcTaxDifference(parent_group1, parent_group2);
            }
            else {
                taxdiff = countTaxLevel(parent_group1);
            }
        }
        else {
            if (parent_group2) {
                taxdiff = countTaxLevel(parent_group2);
            }
            // else -> both outside any group -> no diff
        }

        return taxdiff;
    }
};

// Containers keyed by species name.
// The keys are the name pointers taken from the tree nodes (they are not copied here),
// i.e. they are only usable while the trees they were taken from exist.
// charpLess is declared elsewhere (presumably orders C-strings by content, not by address -- TODO confirm).
typedef std::set<const char *, charpLess> NameSet;                           // set of species names
typedef std::map<const char *, SpeciesInTwoTrees, charpLess> TwoTreeMap;     // species name -> leafs in both trees

static void mapTree(TreeNode *node, TwoTreeMap& tmap, bool first) {
    // Registers all leafs of the subtree 'node' in 'tmap' (key = name of leaf).
    // 'first' selects whether the leafs are stored for the first or for the second tree.

    if (!node->is_leaf()) {
        // inner node -> descend into both subtrees
        mapTree(node->leftson, tmap, first);
        mapTree(node->rightson, tmap, first);
        return;
    }

    nt_assert(node->name);
    SpeciesInTwoTrees& entry = tmap[node->name]; // creates entry on first occurrence
    entry.setSpecies(node, first);
}

static void mark_action(AW_window *aws, TREE_canvas *ntw, Target target) {
    AW_root *aw_root = aws->get_root();

    Action action = Action(aw_root->awar(AWAR_TREE_COMPARE_ACTION)->read_int());

    arb_progress progress("Mark species");
    if (target == ALL) {
        NT_mark_all_cb(NULp, ntw, action);
    }
    else {
        progress.subtitle("Loading trees");

        const char *treename_left  = TreeAdmin::source_tree_awar(aw_root)->read_char_pntr();
        const char *treename_right = TreeAdmin::dest_tree_awar(aw_root)->read_char_pntr();

        GBDATA         *gb_main = ntw->gb_main;
        GB_transaction  ta(gb_main);

        GBDATA *gb_species_data = GBT_get_species_data(gb_main);

        TreeNode *tree_left  = GBT_read_tree(gb_main, treename_left,  new SimpleRoot);
        TreeNode *tree_right = NULp;

        GB_ERROR load_error = NULp;
        if (!tree_left) {
            load_error = GB_await_error();
        }
        else {
            tree_right = GBT_read_tree(gb_main, treename_right, new SimpleRoot);
            if (!tree_right) load_error = GB_await_error();
        }

        size_t missing   = 0;
        size_t targetted = 0;
        bool   had_error = false;

        if (load_error) {
            aw_message(load_error);
            had_error = load_error;
        }
        else {
            nt_assert(tree_left && tree_right);
            if (target == TAX) {
                int min_tax_levels = atoi(aw_root->awar(AWAR_TREE_COMPARE_MIN_TAX_LEVELS)->read_char_pntr());

                TwoTreeMap tmap;
                mapTree(tree_left,  tmap, true);
                mapTree(tree_right, tmap, false);

                size_t commonSpeciesCount = 0;
                for (TwoTreeMap::iterator s = tmap.begin(); s != tmap.end(); ++s) {
                    if (s->second.occursInBothTrees()) commonSpeciesCount++;
                }

                const char *fieldName    = prepare_and_get_selected_itemfield(aw_root, AWAR_TREE_COMPARE_WRITE_FIELD, gb_main, SPECIES_get_selector(), FIF_ALLOW_NONE);
                bool        writeToField = fieldName;
                GB_ERROR    error        = GB_incur_error();

                if (!error) {
                    arb_progress subprogress("Comparing taxonomy info", commonSpeciesCount);
                    for (TwoTreeMap::iterator s = tmap.begin(); s != tmap.end() && !error; ++s) {
                        const SpeciesInTwoTrees& species = s->second;

                        if (species.occursInBothTrees()) {
                            int taxDiffLevel = species.calcTaxDiffLevel();
                            if (taxDiffLevel>min_tax_levels) {
                                ++targetted;
                                GBDATA *gb_species = GBT_find_species_rel_species_data(gb_species_data, s->first);
                                if (!gb_species) {
                                    ++missing;
                                }
                                else {
                                    switch (action) {
                                        case UNMARK: GB_write_flag(gb_species, 0);                         break;
                                        case MARK:   GB_write_flag(gb_species, 1);                         break;
                                        case INVERT: GB_write_flag(gb_species, !GB_read_flag(gb_species)); break;
                                    }
                                    if (writeToField) {
                                        GBDATA *gb_field     = GBT_searchOrCreate_itemfield_according_to_changekey(gb_species, fieldName, SPECIES_get_selector().change_key_path);
                                        if (!gb_field) error = GB_await_error();
                                        if (!error) error    = GB_write_lossless_int(gb_field, taxDiffLevel);
                                    }
                                }
                            }
                            subprogress.inc_and_check_user_abort(error);
                        }
                    }
                }
                aw_message_if(error);
                had_error = error;
            }
            else {
                progress.subtitle("Intersecting tree members");

                NameSet in_left;
                NameSet in_right;
                {
                    size_t   count_left, count_right;
                    GB_CSTR *names_left  = GBT_get_names_of_species_in_tree(tree_left,  &count_left);
                    GB_CSTR *names_right = GBT_get_names_of_species_in_tree(tree_right, &count_right);

                    for(size_t i= 0; i<count_left;  ++i) in_left .insert(names_left[i]);
                    for(size_t i= 0; i<count_right; ++i) in_right.insert(names_right[i]);

                    free(names_right);
                    free(names_left);
                }

                {
                    NameSet& in_one   = target == MISSING_LEFT ? in_right : in_left;
                    NameSet& in_other = target == MISSING_LEFT ? in_left : in_right;

                    for (NameSet::const_iterator i = in_one.begin(); i != in_one.end(); ++i) {
                        bool is_in_other = in_other.find(*i) != in_other.end();
                        bool is_target   = is_in_other == (target == COMMON);

                        if (is_target) {
                            ++targetted;
                            GBDATA *gb_species = GBT_find_species_rel_species_data(gb_species_data, *i);
                            if (!gb_species) {
                                ++missing;
                            }
                            else {
                                switch (action) {
                                    case UNMARK: GB_write_flag(gb_species, 0);                         break;
                                    case MARK:   GB_write_flag(gb_species, 1);                         break;
                                    case INVERT: GB_write_flag(gb_species, !GB_read_flag(gb_species)); break;
                                }
                            }
                        }
                    }
                }
            }
        }

        if (!had_error) {
            if (!targetted) {
                aw_message("Warning: no species targetted");
            }
            else if (missing) {
                aw_message(GBS_global_string("Warning: %zu targeted species could not be found\n"
                                             "(might be caused by zombies in your trees)", missing));
            }
        }

        destroy(tree_right);
        destroy(tree_left);
    }
}

void NT_create_compare_taxonomy_awars(AW_root *aw_root, AW_default props) {
    // Creates the awars used by the 'Compare taxonomy' window:
    // - AWAR_TREE_COMPARE_ACTION         (int):    mark operation (see enum Action); default = MARK
    // - AWAR_TREE_COMPARE_MIN_TAX_LEVELS (string): taxonomy difference that has to be exceeded; default = "0"
    // - AWAR_TREE_COMPARE_WRITE_FIELD    (string): field to write the difference to; default = no field
    //
    // (a formerly performed read of AWAR_TREE_NAME was removed: its result was never used)

    aw_root->awar_int   (AWAR_TREE_COMPARE_ACTION,         MARK,              props);
    aw_root->awar_string(AWAR_TREE_COMPARE_MIN_TAX_LEVELS, "0",               props);
    aw_root->awar_string(AWAR_TREE_COMPARE_WRITE_FIELD,    NO_FIELD_SELECTED, props);
}

AW_window *NT_create_compare_taxonomy_window(AW_root *aw_root, TREE_canvas *ntw) {
    // Builds the 'Compare taxonomy' window (layout is taken from compare_taxonomy.fig).
    // All target buttons call mark_action() for the canvas 'ntw'.

    AW_window_simple *aws = new AW_window_simple;
    aws->init(aw_root, "COMPARE_TAXONOMY", "Compare taxonomy");
    aws->load_xfig("compare_taxonomy.fig");

    // standard buttons
    aws->at("close");
    aws->callback(AW_POPDOWN);
    aws->create_button("CLOSE", "CLOSE", "C");

    aws->at("help");
    aws->callback(makeHelpCallback("compare_taxonomy.hlp"));
    aws->create_button("HELP", "HELP", "H");

    // selector for the mark operation
    aws->at("action");
    aws->create_toggle_field(AWAR_TREE_COMPARE_ACTION);
    aws->insert_default_toggle("mark",   "m", MARK);
    aws->insert_toggle        ("unmark", "u", UNMARK);
    aws->insert_toggle        ("invert", "i", INVERT);
    aws->update_toggle_field();

    // one button for each target (the id is used as xfig-position and as button-id)
    static const struct {
        const char *id;
        const char *label;
        Target      target;
    } target_button[] = {
        { "all",       "all species",                   ALL           },
        { "tax",       "species with taxonomy changed", TAX           },
        { "common",    "common species",                COMMON        },
        { "missleft",  "species missing in left tree",  MISSING_LEFT  },
        { "missright", "species missing in right tree", MISSING_RIGHT },
    };
    const size_t target_buttons = sizeof(target_button)/sizeof(target_button[0]);

    for (size_t b = 0; b<target_buttons; ++b) {
        aws->at(target_button[b].id);
        aws->callback(makeWindowCallback(mark_action, ntw, target_button[b].target));
        aws->create_autosize_button(target_button[b].id, target_button[b].label);
    }

    // minimum taxonomy difference + field to store the difference
    aws->at("levels");
    aws->create_input_field(AWAR_TREE_COMPARE_MIN_TAX_LEVELS, 5);

    create_itemfield_selection_button(aws, FieldSelDef(AWAR_TREE_COMPARE_WRITE_FIELD, ntw->gb_main, SPECIES_get_selector(), FIELD_FILTER_INT_WRITEABLE, "taxdiff-field", SF_ALLOW_NEW), "field");

    NT_create_twoTreeSelection(aws);

    return aws;
}


