// ============================================================= //
//                                                               //
//   File      : CT_part.cxx                                     //
//   Purpose   :                                                 //
//                                                               //
//   Institute of Microbiology (Technical University Munich)     //
//   http://www.arb-home.de/                                     //
//                                                               //
// ============================================================= //

/* This module implements the partition data structure (PART).
   Partitions represent the edges of a tree. */
// The partitions are implemented as an array of PELEMs.
// Each leaf in a GBT-Tree is represented by one bit in the partition.

#include "CT_part.hxx"
#include "CT_common.hxx"

// number of bits (i.e. leafs) that fit into one PELEM
#define BITS_PER_PELEM (sizeof(PELEM)*8)

#if defined(DUMP_PART_INIT) || defined(UNIT_TESTS)
static const char *readable_cutmask(PELEM mask) {
    // Render 'mask' as a string of BITS_PER_PELEM '0'/'1' characters,
    // most significant bit first.
    // NOTE: returns a static buffer, which is overwritten by the next call.
    static char readable[BITS_PER_PELEM+1];

    for (size_t pos = 0; pos<BITS_PER_PELEM; ++pos) {
        const size_t shift = BITS_PER_PELEM-1-pos;
        readable[pos]      = ((mask>>shift) & 1) ? '1' : '0';
    }
    readable[BITS_PER_PELEM] = 0;

    return readable;
}
#endif

PartitionSize::PartitionSize(const int len)
    : cutmask(0),
      longs((((len + 7) / 8)+sizeof(PELEM)-1) / sizeof(PELEM)),
      bits(len),
      id(0)
{
    /*! Set up size information for partitions containing 'len' leafs.
     * @param len number of bits (leafs) each partition has to contain
     *
     * Stores 'len' as 'bits' and calculates
     * - longs:   number of PELEMs needed to hold 'len' bits
     * - cutmask: mask selecting the used bits of the last PELEM
     *            (all bits set if 'len' is a multiple of BITS_PER_PELEM)
     */

    // number of bits used in the last PELEM (a completely filled PELEM counts as full, not as 0)
    int usedInLast = len % BITS_PER_PELEM;
    if (usedInLast == 0) usedInLast = BITS_PER_PELEM;

    // shift in one '1' per used bit (avoids a shift by the full PELEM width)
    while (usedInLast--) {
        cutmask = (cutmask<<1) | 1;
    }

#if defined(DEBUG)
    size_t possible = longs*BITS_PER_PELEM;
    arb_assert((possible-bits)<BITS_PER_PELEM); // longs is too big (wasted space)

#if defined(DUMP_PART_INIT)
    printf("leafs=%i\n", len);
    printf("cutmask='%s'\n", readable_cutmask(cutmask));
    printf("longs=%i (can hold %zu bits)\n", longs, possible);
    printf("bits=%i\n", bits);
#endif
#endif
}

#if defined(NTREE_DEBUG_FUNCTIONS)

// leaf names used by PART::print(); NULp => print the raw bit pattern instead
static const CharPtrArray *namesPtr = NULp;

void PART::start_pretty_printing(const CharPtrArray& names) { namesPtr = &names; }
void PART::stop_pretty_printing() { namesPtr = NULp; }

void PART::print() const {
    /*! Debug helper: print this partition to stdout.
     *
     * With pretty printing active (see start_pretty_printing), the first character
     * of each leaf name is printed: first all leafs NOT in the partition,
     * then "---", then all leafs in the partition.
     * Otherwise one '0' or '1' is printed per leaf.
     * In both cases a summary (len, weight, weighted len, leaf-flag,
     * distance to tree center) completes the line.
     */
    const int    longs    = get_longs();
    const int    plen     = info->get_bits();
    const size_t wordBits = sizeof(PELEM)*8;

    if (namesPtr) {
        const CharPtrArray& names = *namesPtr;
        for (int side = 0; side<=1; ++side) {
            if (side) fputs("---", stdout); // separates non-members from members

            int leaf = 0;
            for (int w = 0; w<longs; ++w) {
                PELEM mask = 1;
                for (size_t b = 0; leaf<plen && b<wordBits; ++b, ++leaf, mask <<= 1) {
                    const bool inPart = p[w] & mask;
                    if (inPart == side) {
                        const char *name = names[leaf];
                        fputc(name[0], stdout); // first char of name
                    }
                }
            }
        }
    }
    else {
        int leaf = 0;
        for (int w = 0; w<longs; ++w) {
            PELEM mask = 1;
            for (size_t b = 0; leaf<plen && b<wordBits; ++b, ++leaf, mask <<= 1) {
                const bool inPart = p[w] & mask;
                fputc(inPart ? '1' : '0', stdout);
            }
        }
    }

    printf("  len=%.5f  prob=%5.1f%%  w.len=%.5f  leaf=%i  dist2center=%i\n",
           len, weight*100.0, get_len(), is_leaf_edge(), distance_to_tree_center());
}
#endif

PART *PartitionSize::create_root() const {
    /*! create a partition with all leaf bits set (111...111),
     * as needed to build the root of a specific ntree.
     * @return PART allocated with 'new'
     */
    PART *root = new PART(this, 1.0);

    // invert() sets all used bits -- assumes a freshly constructed PART is empty (see PART ctor)
    root->invert();

    arb_assert(root->is_valid());
    return root;
}

bool PART::overlaps_with(const PART *other) const {
    /*! test whether 'this' and 'other' overlap (i.e. share at least one leaf bit)
     */
    arb_assert(is_valid());
    arb_assert(other->is_valid());

    const int longs  = get_longs();
    bool      shared = false;
    for (int i = 0; i<longs && !shared; ++i) {
        shared = (p[i] & other->p[i]) != 0;
    }
    return shared;
}

void PART::invert() {
    //! invert a part
    //
    // Each edge in a tree connects two subtrees.
    // These subtrees are represented by inverse partitions.
    //
    // Unused bits of the last PELEM are cleared again (via cutmask) after flipping.

    arb_assert(is_valid());

    const int longs = get_longs();
    for (int i=0; i<longs; i++) { // LOOP_VECTORIZED
        p[i] = ~p[i];
    }
    if (longs) { // a 0-leaf partition has no PELEMs (would access p[-1])
        p[longs-1] &= get_cutmask();
    }

    members = get_maxsize()-members;

    arb_assert(is_valid());
}

void PART::invertInSuperset(const PART *superset) {
    /*! invert this part relative to 'superset'
     *
     * 'this' has to be a subset of 'superset'. Afterwards 'this' contains exactly
     * those members of 'superset' that were NOT members of 'this' before.
     */
    arb_assert(is_valid());
    arb_assert(is_subset_of(superset));

    const int longs = get_longs();
    for (int i=0; i<longs; i++) { // LOOP_VECTORIZED
        p[i] = p[i] ^ superset->p[i];
    }
    if (longs) { // a 0-leaf partition has no PELEMs (would access p[-1])
        p[longs-1] &= get_cutmask();
    }

    members = superset->get_members()-members;

    arb_assert(is_valid());
    arb_assert(is_subset_of(superset));
}


void PART::add_members_from(const PART *source) {
    //! merge all members of 'source' into 'this' (destination = source OR destination)
    arb_assert(source->is_valid());
    arb_assert(is_valid());

    // must be determined before merging the bits
    const bool wasDisjunct = disjunct_from(source);

    const int longs = get_longs();
    for (int i=0; i<longs; i++) { // LOOP_VECTORIZED
        p[i] |= source->p[i];
    }

    // without overlap the member counts simply add up; otherwise recount
    members = wasDisjunct ? members + source->members : count_members();

    arb_assert(is_valid());
}


bool PART::equals(const PART *other) const {
    /*! return true if 'this' and 'other' have identical bit patterns
     */
    arb_assert(is_valid());
    arb_assert(other->is_valid());

    const int longs = get_longs();
    int       i     = 0;
    while (i<longs && p[i] == other->p[i]) ++i;

    return i == longs; // reached the end => no difference found
}


unsigned PART::key() const {
    //! calculate a hashkey from part
    arb_assert(is_valid());

    PELEM ph = 0;
    const int longs = get_longs();
    for (int i=0; i<longs; i++) { // LOOP_VECTORIZED
        ph ^= p[i];
    }

    return ph;
}

inline uint8_t bytebitcount(uint8_t byte) {
    // number of set bits in 'byte' (used to fill the bitcounter lookup table)
    uint8_t count = 0;
    while (byte) {
        count = static_cast<uint8_t>(count + (byte & 1u));
        byte >>= 1;
    }
    return count;
}
struct bitcounter {
    // lookup table: bytebits[b] holds the number of set bits of byte value b
    uint8_t bytebits[256];

    bitcounter() {
        unsigned value = 256;
        while (value--) {
            bytebits[value] = bytebitcount(value);
        }
    }
};

inline int bitcount(PELEM e) {
    // number of set bits in 'e', counted bytewise via a lookup table
    static bitcounter counted; // built on first call

    int leafs = 0;
#if defined(DUMP_PART_DISTANCE)
    fprintf(stdout, "bitcount(%04x) = ", e);
#endif
    for (size_t bytesLeft = sizeof(e); bytesLeft; --bytesLeft, e = e>>8) {
        leafs += counted.bytebits[e&0xff];
    }
#if defined(DUMP_PART_DISTANCE)
    fprintf(stdout, "%i\n", leafs);
#endif
    return leafs;
}

int PART::count_members() const {
    //! count the number of leafs (set bits) in partition
    //
    // Unused bits of the last PELEM are ignored (masked via cutmask).
    const int longs = get_longs();
    if (!longs) return 0; // 0-leaf partition: no PELEMs (would access p[-1])

    int leafs = 0;
    for (int i = 0; i<(longs-1); ++i) {
        leafs += bitcount(p[i]);
    }
    leafs += bitcount(p[longs-1] & get_cutmask());
    return leafs;
}

bool PART::is_standardized() const { // @@@ inline
    /*! true if PART is in standard representation.
     * @see standardize()
     */

    // Criterion used: the first leaf (bit 0) is a member.
    // Any criterion that differs between a PART and its inverse would work,
    // but changing it changes generated trees (branch-insertion-order is affected).
    const bool firstLeafIsMember = bit_is_set(0);
    return firstLeafIsMember;
}

void PART::standardize() {
    /*! bring the partition into standard representation
     *
     * A PART and its inverse are equivalent (they describe the same edge).
     * After standardizing, equivalent PARTs are equal, so the result may be used
     * as key (as done in PartRegistry).
     */
    arb_assert(is_valid());

    const bool needsInversion = !is_standardized();
    if (needsInversion) {
        invert();
        arb_assert(is_standardized());
    }

    arb_assert(is_valid());
}

int PART::index() const {
    /*! calculate the first bit set in p,
     *
     * this is only useful if only one bit is set,
     * this is used to identify leafs in a ntree
     *
     * ATTENTION: p has to exist
     */
    arb_assert(is_valid());
    arb_assert(is_leaf_edge());

    int pos   = 0;
    const int longs = get_longs();
    for (int i=0; i<longs; i++) {
        PELEM p_temp = p[i];
        pos = i * sizeof(PELEM) * 8;
        if (p_temp) {
            for (; p_temp; p_temp >>= 1, pos++) {
                ;
            }
            break;
        }
    }
    return pos-1;
}

int PART::insertionOrder_cmp(const PART *other) const {
    // defines order in which edges will be inserted into the consensus tree
    // (criteria are checked in sequence; the first difference decides)

    if (this == other) return 0;

    // leaf edges vs. inner edges
    int cmp = is_leaf_edge() - other->is_leaf_edge();
    if (cmp) return cmp;

    cmp = -double_cmp(weight, other->weight); // insert bigger weight first
    if (cmp) return cmp;

    const int centerdist1 = distance_to_tree_center();
    const int centerdist2 = other->distance_to_tree_center();
    cmp = centerdist1-centerdist2; // insert central edges first
    if (cmp) return cmp;

    // NOW REALLY insert bigger len first
    // (change affected test results: increased in-tree-distance,
    // but reduced parsimony value of result-trees)
    cmp = -double_cmp(get_len(), other->get_len());
    if (cmp) return cmp;

    cmp = id - other->id; // strict by definition
    arb_assert(cmp);
    return cmp;
}
inline int PELEM_cmp(const PELEM& p1, const PELEM& p2) {
    // three-way comparison: -1 if p1<p2, 1 if p1>p2, 0 if equal
    return int(p1>p2) - int(p1<p2);
}

int PART::topological_cmp(const PART *other) const {
    // define a strict order on topologies defined by edges:
    // first by number of members, then PELEM-wise by bit pattern

    if (this == other) return 0;

    arb_assert(is_standardized());
    arb_assert(other->is_standardized());

    const int longs = get_longs();
    int       cmp   = members - other->members;
    for (int i = 0; cmp == 0 && i<longs; ++i) {
        cmp = PELEM_cmp(p[i], other->p[i]);
    }

    arb_assert(contradicted(cmp, equals(other)));

    return cmp;
}

#if defined(DUMP_PART_DISTANCE)
static void dumpbits(const PELEM p) {
    // print all bits of 'p' (least significant first): '1' = set, '-' = unset
    for (size_t bit = 0; bit<sizeof(PELEM)*8; ++bit) {
        fputc(((p>>bit) & 1) ? '1' : '-', stdout);
    }
}
#endif

int PART::distanceTo(const PART *other, const PART *this_superset, const PART *other_superset) const {
    /*! calculate the distance between two PARTs.
     * 'this' is the first part to compare
     * @param other second PART to compare
     * @param this_superset whole tree (of which 'this' represents one edge)
     * @param other_superset whole tree (of which 'other' represents one edge)
     *
     * The distance D is calculated as follows:
     *     D    = O + min(d1, d2)
     * where
     *     O  := number of species present in one superset only
     *     d1 := |union(t0, o0)| - |intersection(t0,o0)| + |union(ti, oi)| - |intersection(ti,oi)|
     *     d2 := |union(t0, oi)| - |intersection(t0,oi)| + |union(ti, o0)| - |intersection(ti,o0)|
     * where
     *     t0 := 'this'      ti := inverse of 'this'  in this_superset
     *     o0 := 'other'     oi := inverse of 'other' in this_superset
     */

#if defined(DUMP_PART_DISTANCE)
    fputs("this:          ", stdout); print();
    fputs("other:         ", stdout); other->print();
    fputs("this_superset: ", stdout); this_superset->print();
    fputs("other_superset:", stdout); other_superset->print();
#endif


#if defined(ASSERTION_USED)
    if (this != this_superset) { // avoid that calls from calcTreeDistance fail here
        if (!is_real_son_of(this_superset)) { // if 'this' is NOT inside tree 'this_superset' ...
            PART *thisInverse = clone();
            thisInverse->invert();
            arb_assert(thisInverse->is_real_son_of(this_superset)); // assert inverse of 'this'  is inside tree 'this_superset'
            delete thisInverse;
        }
    }
    if (other != other_superset) { // avoid that calls from calcTreeDistance fail here
        if (!other->is_real_son_of(other_superset)) { // if 'other' is NOT inside tree 'other_superset' ...
            PART *otherInverse = other->clone();
            otherInverse->invert();
            arb_assert(otherInverse->is_real_son_of(other_superset)); // assert inverse of 'other' is inside tree 'other_superset'
            delete otherInverse;
        }
    }
#endif

    int dist = 0;

    const int longs = get_longs();
    for (int i = 0; i<longs; ++i) {
        PELEM ts = this_superset->p[i];
        PELEM os = other_superset->p[i];

        if (i == (longs-1)) {
            const PELEM cutmask = this_superset->get_cutmask(); // should be identical for all involved PARTs

            ts = ts & cutmask;
            os = os & cutmask;
        }

        const PELEM O  = ts ^ os; // calculate superset difference
        const PELEM si = ts & os; // calculate superset intersection

        const PELEM t0 = p[i] & si;
        const PELEM o0 = other->p[i] & si;

        const PELEM ti = t0 ^ si; // like invertInSuperset, but only performed in superset intersection
        const PELEM oi = o0 ^ si;

        // calculate all 4 possible difference-parts:
        const PELEM d00 = t0 ^ o0; // union(t0, o0) - intersection(t0,o0)
        const PELEM d0i = t0 ^ oi;
        const PELEM di0 = ti ^ o0;
        const PELEM dii = ti ^ oi;

        const int d1 = bitcount(d00) + bitcount(dii); // calculate absolute values and sum pairwise
        const int d2 = bitcount(d0i) + bitcount(di0);

        const int idist = bitcount(O) + std::min(d1, d2); // calculate whole difference (of current PELEM)

#if defined(DUMP_PART_DISTANCE)

#define DUMPBITS(var) do { fprintf(stdout, "%5s = %04x = ", #var, var); dumpbits(var); fputc('\n', stdout); } while(0)
#define DUMPINT(var)  fprintf(stdout, "%5s = %i\n", #var, var)

        DUMPINT(i);

        DUMPBITS(ts);
        DUMPBITS(os);
        DUMPBITS(t0);
        DUMPBITS(o0);
        DUMPBITS(ti);
        DUMPBITS(oi);
        DUMPBITS(O);
        DUMPBITS(d00);
        DUMPBITS(d0i);
        DUMPBITS(di0);
        DUMPBITS(dii);

        DUMPINT(d1);
        DUMPINT(d2);
        DUMPINT(idist);
#endif

        dist += idist; // sum up
    }

#if defined(DUMP_PART_DISTANCE)
    fprintf(stdout, "resulting dist=%i\n", dist);
#endif

    return dist;
}

int PART_FWD::calcDistance(const PART *e1, const PART *e2, const PART *t1, const PART *t2) {
    /*! distance between edge e1 (of tree t1) and edge e2 (of tree t2).
     * Forwards to PART::distanceTo (see there for details).
     * The result is the number of species that were added, removed and/or moved to the
     * other side of the partition.
     * @param e1 first PART to compare
     * @param e2 second PART to compare
     * @param t1 whole tree (of which e1 represents one edge)
     * @param t2 whole tree (of which e2 represents one edge)
     */
    const int distance = e1->distanceTo(e2, t1, t2);
    return distance;
}

const TreeNode *PART_FWD::get_origin(const PART *part) {
    // NULp-tolerant forwarder to PART::get_origin()
    if (!part) return NULp;
    return part->get_origin();
}

int PART_FWD::get_members(const PART *part) {
    // forwarder to PART::get_members() ('part' must not be NULp)
    const int memberCount = part->get_members();
    return memberCount;
}

void PART_FWD::destroy_part(PART* part) {
    // forwarder deleting a PART (presumably for callers that only see a forward declaration of PART)
    PART *doomed = part;
    delete doomed;
}


// --------------------------------------------------------------------------------

#ifdef UNIT_TESTS
#ifndef TEST_UNIT_H
#include <test_unit.h>
#endif

void TEST_PartRegistry() {
    {
        PartitionSize reg(0);
        TEST_EXPECT_EQUAL(reg.get_bits(), 0);
        TEST_EXPECT_EQUAL(reg.get_longs(), 0);
        // cutmask doesnt matter
    }

    {
        PartitionSize reg(1);
        TEST_EXPECT_EQUAL(reg.get_bits(), 1);
        TEST_EXPECT_EQUAL(reg.get_longs(), 1);
        TEST_EXPECT_EQUAL(readable_cutmask(reg.get_cutmask()), "00000000000000000000000000000001");
    }

    {
        PartitionSize reg(31);
        TEST_EXPECT_EQUAL(reg.get_bits(), 31);
        TEST_EXPECT_EQUAL(reg.get_longs(), 1);
        TEST_EXPECT_EQUAL(readable_cutmask(reg.get_cutmask()), "01111111111111111111111111111111");
    }

    {
        PartitionSize reg(32);
        TEST_EXPECT_EQUAL(reg.get_bits(), 32);
        TEST_EXPECT_EQUAL(reg.get_longs(), 1);
        TEST_EXPECT_EQUAL(readable_cutmask(reg.get_cutmask()), "11111111111111111111111111111111");
    }

    {
        PartitionSize reg(33);
        TEST_EXPECT_EQUAL(reg.get_bits(), 33);
        TEST_EXPECT_EQUAL(reg.get_longs(), 2);
        TEST_EXPECT_EQUAL(readable_cutmask(reg.get_cutmask()), "00000000000000000000000000000001");
    }

    {
        PartitionSize reg(95);
        TEST_EXPECT_EQUAL(reg.get_bits(), 95);
        TEST_EXPECT_EQUAL(reg.get_longs(), 3);
        TEST_EXPECT_EQUAL(readable_cutmask(reg.get_cutmask()), "01111111111111111111111111111111");
    }
}

#endif // UNIT_TESTS

// --------------------------------------------------------------------------------


