// =============================================================== //
//                                                                 //
//   File      : AP_pos_var.cxx                                    //
//   Purpose   : provides PVP calculation                          //
//                                                                 //
//   Institute of Microbiology (Technical University Munich)       //
//   http://www.arb-home.de/                                       //
//                                                                 //
// =============================================================== //

#include "AP_pos_var.h"

#include <AP_pro_a_nucs.hxx>
#include <TreeNode.h>

#include <arb_progress.h>
#include <arb_strbuf.h>

#include <cctype>

#define ap_assert(cond) arb_assert(cond)

AP_pos_var::AP_pos_var(GBDATA *gb_main_, const char *ali_name_, long ali_len_, bool is_nuc_, const char *tree_name_) :
    gb_main(gb_main_),
    treeLeafs(0),
    treeNodes(0),
    progress(NULp),
    transitions(NULp),
    transversions(NULp),
    ali_len(ali_len_),
    ali_name(ARB_strdup(ali_name_)),
    tree_name(ARB_strdup(tree_name_)),
    is_nuc(is_nuc_)
{
    // Start with empty per-character tables. retrieve() fills the lookup tables
    // and allocates the frequency tables that are actually needed.
    for (int i=0; i<256; i++) {
        frequencies[i]         = NULp;
        char_2_freq[i]         = 0;
        char_2_transition[i]   = 0;
        char_2_transversion[i] = 0; // must be initialized as well: parsimony() reads this table
    }
}

AP_pos_var::~AP_pos_var() {
    // Release everything owned by this object.
    // Unused slots and members are NULp, and free(NULp) is a no-op.
    delete progress;

    free(ali_name);
    free(tree_name);

    free(transitions);
    free(transversions);

    for (int c = 0; c<256; ++c) {
        free(frequencies[c]);
    }
}

const char *AP_pos_var::parsimony(TreeNode *tree, GB_UINT4 *bases, GB_UINT4 *tbases) {
    // Recursive Fitch-style parsimony pass over 'tree'.
    //
    // Leafs:
    //   - count per-column character frequencies into 'frequencies'
    //   - hand the leaf's character bitsets to the caller via 'bases' (transition
    //     bitsets) and 'tbases' (transversion classes)
    //
    // Inner nodes:
    //   - combine both sons' bitsets column by column
    //   - if they share no bit, count a change (in 'transitions' resp. 'transversions')
    //     and pass the union upwards; otherwise pass the intersection upwards
    //
    // 'bases'/'tbases' may be NULp (then nothing is reported upwards). Otherwise they
    // must provide 'ali_len' zero-initialized entries (the inner-node branch below
    // allocates them with ARB_calloc).
    //
    // Returns NULp or an error: a failed sequence read, an error from a subtree, or
    // the one set by progress->inc_and_check_user_abort().
    GB_ERROR error = NULp;

    if (tree->is_leaf()) {
        if (tree->gb_node) {
            GBDATA *gb_data = GBT_find_sequence(tree->gb_node, ali_name);
            if (gb_data) {
                // Only look at the part of the sequence inside the alignment. Positions
                // behind the end of a short sequence keep their zero bitsets.
                size_t seq_len = ali_len;
                if (GB_read_string_count(gb_data) < seq_len) {
                    seq_len = GB_read_string_count(gb_data);
                }

                const unsigned char *sequence = (const unsigned char*)GB_read_char_pntr(gb_data);
                if (!sequence) {
                    error = GB_await_error(); // read failed -> report instead of dereferencing NULp
                }
                else {
                    for (size_t i = 0; i< seq_len; i++) {
                        long L = char_2_freq[sequence[i]];
                        if (L) {
                            ap_assert(frequencies[L]);
                            frequencies[L][i]++;
                        }
                    }

                    if (bases) {
                        for (size_t i = 0; i< seq_len; i++) bases[i] = char_2_transition[sequence[i]];
                    }
                    if (tbases) {
                        for (size_t i = 0; i< seq_len; i++) tbases[i] = char_2_transversion[sequence[i]];
                    }
                }
            }
        }
    }
    else {
        // zero-initialized buffers receive the sons' bitsets
        GB_UINT4 *ls  = ARB_calloc<GB_UINT4>(ali_len);
        GB_UINT4 *rs  = ARB_calloc<GB_UINT4>(ali_len);
        GB_UINT4 *lts = ARB_calloc<GB_UINT4>(ali_len);
        GB_UINT4 *rts = ARB_calloc<GB_UINT4>(ali_len);

        error             = this->parsimony(tree->get_leftson(), ls, lts);
        if (!error) error = this->parsimony(tree->get_rightson(), rs, rts);
        if (!error) {
            for (long i=0; i< ali_len; i++) {
                // transitions: intersection if non-empty, else count a change and use the union
                long l = ls[i];
                long r = rs[i];
                if (l & r) {
                    if (bases) bases[i] = l&r;
                }
                else {
                    transitions[i] ++;
                    if (bases) bases[i] = l|r;
                }

                // transversions: same rule applied to the transversion classes
                l = lts[i];
                r = rts[i];
                if (l & r) {
                    if (tbases) tbases[i] = l&r;
                }
                else {
                    transversions[i] ++;
                    if (tbases) tbases[i] = l|r;
                }
            }
        }

        free(lts);
        free(rts);

        free(ls);
        free(rs);
    }
    progress->inc_and_check_user_abort(error);
    return error;
}


// Calculate the positional variability: control procedure
GB_ERROR AP_pos_var::retrieve(TreeNode *tree) {
    // Steps of a PVP run:
    //  1. set up the character lookup tables (nucleotide or amino-acid, depending on 'is_nuc'):
    //     frequency slot, transition bitset and transversion class per character
    //  2. allocate the per-column counters
    //  3. run the recursive parsimony pass over 'tree'
    ap_assert(!treeNodes); // calling retrieve() multiple times is untested

    if (!tree) return "No tree specified for PVP calculation";

    treeLeafs = GBT_count_leafs(tree);
    treeNodes = leafs_2_nodes(treeLeafs, ROOTED); // used for progress
    ap_assert(treeNodes>0);

    if (is_nuc) {
        // case-insensitive frequency slots; 'T' and 'U' share the 'U' slot
        const char *const raw_chars    = "aAcCgGtTuU";
        const char *const folded_chars = "AACCGGUUUU";
        for (int k = 0; raw_chars[k]; ++k) {
            char_2_freq[(unsigned char)raw_chars[k]] = folded_chars[k];
        }

        unsigned char *dna2bits = (unsigned char *)AP_create_dna_to_ap_bases();
        for (int c = 0; c<256; ++c) {
            int  lookup = (c == '-') ? '.' : c; // a gap '-' is looked up like '.'
            long bits   = char_2_transition[c] = dna2bits[lookup];

            int tv_class = 0;
            if (bits & (AP_A | AP_G)) tv_class |= 1;
            if (bits & (AP_C | AP_T)) tv_class |= 2;
            char_2_transversion[c] = tv_class;
        }
        delete [] dna2bits;
    }
    else {
        AWT_translator *translator = AWT_get_user_translator(gb_main);

        long aa2bits[256] = {}; // characters unknown to the translator stay 0
        int  last_aa      = translator->MaxAA();
        for (int idx = 0; idx<=last_aa; ++idx) {
            long          bitset = translator->index2bitset(idx);
            unsigned char aa     = translator->index2spro(idx);
            aa2bits[aa] = bitset;
        }

        for (int c = 0; c<256; ++c) {
            char_2_transversion[c] = 0;
            long bits = char_2_transition[c] = aa2bits[c];
            if (bits) char_2_freq[c] = toupper(c);
        }
    }

    progress = new arb_progress(treeNodes);

    // one frequency table per used slot (several characters may share a slot)
    for (int c = 0; c<256; ++c) {
        int slot = char_2_freq[c];
        if (slot && !frequencies[slot]) ARB_calloc(frequencies[slot], ali_len);
    }

    ARB_calloc(transitions,   ali_len);
    ARB_calloc(transversions, ali_len);

    return this->parsimony(tree);
}

GB_ERROR AP_pos_var::delete_aliEntry_from_SAI(const char *sai_name) {
    // Removes the sub-container 'ali_name' from SAI 'sai_name'.
    // A missing SAI or a missing sub-container is not an error (returns NULp).
    GBDATA *gb_sai = GBT_find_SAI(gb_main, sai_name);
    if (!gb_sai) return NULp;

    GBDATA *gb_ali = GB_search(gb_sai, ali_name, GB_FIND);
    if (!gb_ali) return NULp;

    return GB_delete(gb_ali);
}

GB_ERROR AP_pos_var::save_aliEntry_to_SAI(const char *sai_name) {
    // save whole SAI
    // or
    // add alignment sub-container to existing SAI

    GB_ERROR  error       = NULp;
    GBDATA   *gb_extended = GBT_find_or_create_SAI(gb_main, sai_name);

    if (!gb_extended) error = GB_await_error();
    else {
        GBDATA *gb_ali     = GB_search(gb_extended, ali_name, GB_DB);
        if (!gb_ali) error = GB_await_error();
        else {
            const char *description =
                GBS_global_string("PVP: Positional Variability by Parsimony: tree '%s' ntaxa %li",
                                  tree_name, treeLeafs);

            error = GBT_write_string(gb_ali, "_TYPE", description);
        }

        if (!error) {
            char *data = ARB_calloc<char>(ali_len+1);
            int  *sum  = ARB_calloc<int>(ali_len);

            for (int j=0; j<256 && !error; j++) {                   // get sum of frequencies
                if (frequencies[j]) {
                    for (int i=0; i<ali_len; i++) { // LOOP_VECTORIZED
                        sum[i] += frequencies[j][i];
                    }

                    if (j >= 'A' && j <= 'Z') {
                        GBDATA *gb_freq     = GB_search(gb_ali, GBS_global_string("FREQUENCIES/N%c", j), GB_INTS);
                        if (!gb_freq) error = GB_await_error();
                        else    error       = GB_write_ints(gb_freq, frequencies[j], ali_len);
                    }
                }
            }

            if (!error) {
                GBDATA *gb_transi     = GB_search(gb_ali, "FREQUENCIES/TRANSITIONS", GB_INTS);
                if (!gb_transi) error = GB_await_error();
                else    error         = GB_write_ints(gb_transi, transitions, ali_len);
            }
            if (!error) {
                GBDATA *gb_transv     = GB_search(gb_ali, "FREQUENCIES/TRANSVERSIONS", GB_INTS);
                if (!gb_transv) error = GB_await_error();
                else    error         = GB_write_ints(gb_transv, transversions, ali_len);
            }

            if (!error) {
                int    max_categ = 0;
                double logbase   = sqrt(2.0);
                double lnlogbase = log(logbase);
                double b         = .75;
                double max_rate  = 1.0;

                int tenPercentOfLeafs = treeLeafs*0.1;
                if ((10*tenPercentOfLeafs) < treeLeafs) ++tenPercentOfLeafs;

                for (int i=0; i<ali_len; i++) {
                    if (sum[i] < tenPercentOfLeafs) { // less than 10% valid characters; as documented in ../../HELP_SOURCE/source/pos_var_pars.hlp@valid
                        data[i] = '.';
                        continue;
                    }
                    if (transitions[i] == 0) {
                        data[i] = '-';
                        continue;
                    }
                    double rate = transitions[i] / (double)sum[i];
                    if (rate >= b * .95) {
                        rate = b * .95;
                    }
                    rate = -b * log(1-rate/b);
                    if (rate > max_rate) rate = max_rate;
                    rate /= max_rate;       // scaled  1.0 == fast rate
                    // ~0.0 slow rate
                    double dcat = -log(rate)/lnlogbase;
                    int icat = (int)dcat;
                    if (icat > 35) icat = 35;
                    if (icat >= max_categ) max_categ = icat + 1;
                    data[i] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"[icat];
                }

                error = GBT_write_string(gb_ali, "data", data);

                if (!error) {
                    // Generate Categories
                    GBS_strstruct out(1000);
                    for (int i = 0; i<max_categ; i++) {
                        out.putfloat(pow(1.0/logbase, i));
                        out.put(' ');
                    }

                    error = GBT_write_string(gb_ali, "_CATEGORIES", out.get_data());
                }
            }

            free(sum);
            free(data);
        }
    }

    return error;
}

