// =============================================================== //
//                                                                 //
//   File      : AP_filter.cxx                                     //
//   Purpose   : alignment filter / weights                        //
//                                                                 //
//   Institute of Microbiology (Technical University Munich)       //
//   http://www.arb-home.de/                                       //
//                                                                 //
// =============================================================== //

#include "AP_filter.hxx"
#include <arbdb.h>

// ------------------
//      AP_filter

void AP_filter::init(size_t size) {
    // Allocate an (uninitialized) mask for 'size' positions and reset all other state.
    // Callers have to fill 'filter_mask' and count the used positions into 'real_len'.
    // Note: does not release a previous mask; in this file it is only reached from constructors.
    filter_mask        = new bool[size];
    filter_len         = size;
    real_len           = 0;

    // optional lookup tables are created lazily
    bootstrap          = NULp;
    filterpos_2_seqpos = NULp;

    // simplify table gets filled by enable_simplify()
    simplify_type      = AWT_FILTER_SIMPLIFY_NOT_INITIALIZED;
    simplify[0]        = 0;   // silence cppcheck-warning

    // stamp this filter with a fresh modification counter
    update             = AP_timer();

#if defined(ASSERTION_USED)
    checked_for_validity = false;
#endif
}

void AP_filter::make_permeable(size_t size) {
    // Initialize as a filter of 'size' positions which lets every position pass.
    init(size);
    for (size_t pos = 0; pos<filter_len; ++pos) {
        filter_mask[pos] = true;
    }
    real_len = filter_len; // all positions are used
}

void AP_filter::init_from_string(const char *ifilter, const char *zerobases, size_t size) {
    /*! initialize mask from filter string 'ifilter'
     * - every character contained in 'zerobases' marks a filtered-out position
     *   (if 'zerobases' is NULp, only '0' does)
     * - all other characters mark used positions
     * - positions beyond the end of 'ifilter' are used
     * - characters of 'ifilter' beyond 'size' are ignored
     */
    init(size);

    // lookup table: character -> position is used?
    bool isUsed[256];
    for (int c = 0; c<256; ++c) isUsed[c] = true;

    if (!zerobases) {
        isUsed['0'] = false;
    }
    else {
        for (const char *z = zerobases; *z; ++z) {
            isUsed[safeCharIndex(*z)] = false;
        }
    }

    real_len = 0;
    size_t pos = 0;

    // positions covered by the filter string
    while (pos < size && ifilter[pos]) {
        const bool used  = isUsed[safeCharIndex(ifilter[pos])];
        filter_mask[pos] = used;
        if (used) ++real_len;
        ++pos;
    }

    // filter string too short: remaining positions pass
    while (pos < size) {
        filter_mask[pos++] = true;
        ++real_len;
    }
}


AP_filter::AP_filter(size_t size) {
    // permeable filter: every one of the 'size' positions is used
    make_permeable(size);
}

AP_filter::AP_filter(const AP_filter& other)
    : filter_mask(new bool[other.filter_len]),
      filter_len(other.filter_len),
      real_len(other.real_len),
      update(other.update),
      simplify_type(other.simplify_type),
      bootstrap(NULp),
      filterpos_2_seqpos(NULp)
{
    /*! deep copy of 'other'
     * Duplicates the mask and the simplify table. The optional bootstrap and
     * filterpos->seqpos tables are duplicated only if 'other' has them; both hold 'real_len' entries.
     *
     * NOTE(review): no copy-assignment operator is defined in this file. Unless AP_filter.hxx
     * deletes or defines one, the implicit operator= would share the owned arrays and
     * delete[] them twice. Confirm in the header.
     */
    memcpy(filter_mask, other.filter_mask, filter_len*sizeof(*filter_mask));

    // 'simplify' is an array member, so sizeof(simplify) already is its size in bytes.
    // Multiplying by sizeof(*simplify) again would over-read/over-write for wider element types.
    memcpy(simplify, other.simplify, sizeof(simplify));

    if (other.bootstrap) {
        bootstrap = new size_t[real_len];
        memcpy(bootstrap, other.bootstrap, real_len*sizeof(*bootstrap));
    }
    if (other.filterpos_2_seqpos) {
        filterpos_2_seqpos = new size_t[real_len];
        memcpy(filterpos_2_seqpos, other.filterpos_2_seqpos, real_len*sizeof(*filterpos_2_seqpos));
    }
#if defined(ASSERTION_USED)
    checked_for_validity = other.checked_for_validity;
#endif
}



AP_filter::AP_filter(const char *ifilter, const char *zerobases, size_t size) {
    /*! create filter of 'size' positions from filter string 'ifilter'
     * (see init_from_string() for the meaning of 'zerobases').
     * A NULp or empty 'ifilter' results in a permeable filter.
     */
    const bool haveFilterString = ifilter && ifilter[0];
    if (haveFilterString) init_from_string(ifilter, zerobases, size);
    else                  make_permeable(size);
}

AP_filter::AP_filter(AF_Not, const AP_filter& other) {
    // create the complement of 'other': used positions become filtered-out and vice versa
    init(other.get_length());

    const bool *source = other.filter_mask;
    for (size_t pos = 0; pos<filter_len; ++pos) {
        const bool used  = !source[pos];
        filter_mask[pos] = used;
        if (used) ++real_len;
    }
}

AP_filter::AP_filter(const AP_filter& f1, AF_Combine comb, const AP_filter& f2) {
    // combine two filters of equal length position-wise using AND, OR or XOR
    const size_t len = f1.get_length();
    af_assert(len == f2.get_length());

    init(len);

    const bool *lhs = f1.filter_mask;
    const bool *rhs = f2.filter_mask;

    switch (comb) {
        case AND:
            for (size_t pos = 0; pos<len; ++pos) {
                filter_mask[pos] = lhs[pos] && rhs[pos];
                if (filter_mask[pos]) ++real_len;
            }
            break;
        case OR:
            for (size_t pos = 0; pos<len; ++pos) {
                filter_mask[pos] = lhs[pos] || rhs[pos];
                if (filter_mask[pos]) ++real_len;
            }
            break;
        case XOR:
            for (size_t pos = 0; pos<len; ++pos) {
                filter_mask[pos] = lhs[pos] != rhs[pos]; // boolean XOR
                if (filter_mask[pos]) ++real_len;
            }
            break;
    }
}

AP_filter::~AP_filter() {
    // release all owned arrays (the optional tables may be NULp; delete[] on NULp is a no-op)
    delete [] filterpos_2_seqpos;
    delete [] filter_mask;
    delete [] bootstrap;
}

char *AP_filter::to_string() const {
    /*! returns the mask as a string of '1' (used) and '0' (filtered-out) characters.
     * The result is heap-allocated (ARB_alloc); the caller has to free() it.
     */
    af_assert(checked_for_validity);

    char *result = ARB_alloc<char>(filter_len+1);
    char *out    = result;

    for (size_t pos = 0; pos<filter_len; ++pos) {
        *out++ = filter_mask[pos] ? '1' : '0';
    }
    *out = 0;

    return result;
}


void AP_filter::enable_simplify(AWT_FILTER_SIMPLIFY type) {
    /*! (re)build the character translation table 'simplify' for 'type'.
     *
     * Does nothing if 'type' is already active.
     * Base table: characters 0..31 map to '.', all other characters map to themselves.
     * AWT_FILTER_SIMPLIFY_DNA additionally maps g/G -> a/A and t,u/T,U -> c/C
     * (i.e. purines collapse onto A, pyrimidines onto C).
     * AWT_FILTER_SIMPLIFY_PROTEIN is not supported (asserts).
     */
    if (type != simplify_type) {
        int i;
        // control characters are replaced by '.'
        for (i=0; i<32; i++) {
            simplify[i] = '.';
        }
        // identity mapping for the remaining characters
        // (keep this loop's shape: the LOOP_VECTORIZED marker below is checked at build time)
        for (; i<256; i++) { // LOOP_VECTORIZED // tested down to gcc 5.5.0 (may fail on older gcc versions)
            simplify[i] = i;
        }
        switch (type) {
            case AWT_FILTER_SIMPLIFY_DNA:
                simplify[(unsigned char)'g'] = 'a';
                simplify[(unsigned char)'G'] = 'A';
                simplify[(unsigned char)'u'] = 'c';
                simplify[(unsigned char)'t'] = 'c';
                simplify[(unsigned char)'U'] = 'C';
                simplify[(unsigned char)'T'] = 'C';
                break;
            case AWT_FILTER_SIMPLIFY_PROTEIN:
                af_assert(0);                           // not implemented or impossible!?
                break;
            case AWT_FILTER_SIMPLIFY_NONE:
                break;
            default:
                af_assert(0);
                break;
        }

        simplify_type = type;
    }
}

void AP_filter::calc_filterpos_2_seqpos() {
    // (re)build the table mapping each filtered position [0..real_len) to its
    // position in the unfiltered sequence [0..filter_len)
    af_assert(checked_for_validity);
    af_assert(real_len>0);

    delete [] filterpos_2_seqpos;
    filterpos_2_seqpos = new size_t[real_len];

    size_t *next = filterpos_2_seqpos;
    for (size_t seqpos = 0; seqpos<filter_len; ++seqpos) {
        if (filter_mask[seqpos]) *next++ = seqpos;
    }
}

void AP_filter::enable_bootstrap() {
    // (re)create the bootstrap table: 'real_len' independent draws from GB_random(real_len)
    // (draws may repeat, i.e. sampling with replacement; presumably each value is in [0, real_len) - confirm GB_random)
    af_assert(checked_for_validity);
    af_assert(real_len>0);

    delete [] bootstrap;
    bootstrap = new size_t[real_len];

    // real_len <= filter_len, so this also limits the range passed to GB_random
    af_assert(filter_len < RAND_MAX);

    size_t *const end = bootstrap+real_len;
    for (size_t *slot = bootstrap; slot != end; ++slot) {
        const int drawn = GB_random(real_len);
        af_assert(drawn >= 0);     // otherwise overflow in random number generator
        *slot = drawn;
    }
}

char *AP_filter::blowup_string(const char *filtered_string, char fillChar) const {
    /*! blow up 'filtered_string' to unfiltered length
     * by inserting 'fillChar' at filtered positions.
     * 'filtered_string' is consumed in order, one character per used position.
     * The result is heap-allocated (ARB_alloc); the caller has to free() it.
     */
    af_assert(checked_for_validity);

    char       *result = ARB_alloc<char>(filter_len+1);
    const char *source = filtered_string;

    for (size_t pos = 0; pos<filter_len; ++pos) {
        if (use_position(pos)) result[pos] = *source++;
        else                   result[pos] = fillChar;
    }
    result[filter_len] = 0;

    return result;
}

char *AP_filter::filter_string(const char *fulllen_string) const {
    /*! filter given 'fulllen_string'
     * i.e. return only the characters at used positions (in order).
     * The result is heap-allocated (ARB_alloc); the caller has to free() it.
     */

    af_assert(checked_for_validity);

    char *result = ARB_alloc<char>(real_len+1);

    get_filterpos_2_seqpos(); // create if missing
    for (size_t fpos = 0; fpos<real_len; ++fpos) {
        result[fpos] = fulllen_string[filterpos_2_seqpos[fpos]];
    }
    result[real_len] = 0;

    return result;
}

// -------------------
//      AP_weights

AP_weights::AP_weights(const AP_filter *fil)
    : len(fil->get_filtered_length()), // one (implicit) weight per used position
      weights(NULp)                    // no explicit weights stored
{}

AP_weights::AP_weights(const GB_UINT4 *w, size_t wlen, const AP_filter *fil)
    : len(fil->get_filtered_length()),
      weights(NULp)
{
    /*! store the weights of all positions used by 'fil'
     * 'w' contains 'wlen' weights for the unfiltered alignment;
     * 'wlen' has to equal fil->get_length().
     */
    ARB_alloc_aligned(weights, len);

    const size_t flen = fil->get_length();
    af_assert(wlen == flen);

    // Never scan past the end of 'w' or past the filter, even if assertions are disabled.
    // 'len' counts the used positions within 'flen' (assuming use_position() reflects the mask
    // counted into get_filtered_length()), so this also keeps the writes to 'weights' in bounds.
    const size_t scanlen = wlen<flen ? wlen : flen;

    size_t i = 0;
    for (size_t j = 0; j<scanlen; ++j) {
        if (fil->use_position(j)) {
            weights[i++] = w[j];
        }
    }
    af_assert(i == fil->get_filtered_length());
}

AP_weights::AP_weights(const AP_weights& other)
    : len(other.len),
      weights(NULp)
{
    // deep copy; a source without explicit weights yields a copy without explicit weights
    if (!other.weights) return;

    ARB_alloc_aligned(weights, len);
    memcpy(weights, other.weights, len*sizeof(*weights));
}

AP_weights::~AP_weights() {
    free(weights); // allocated by ARB_alloc_aligned (or NULp; free(NULp) is a no-op)
}

long AP_timer() {
    /*! returns a strictly increasing stamp (1, 2, 3, ...) on each call.
     * Used as modification counter (e.g. AP_filter::update); unrelated to wall-clock time.
     * The counter is a plain function-local static, so concurrent calls are not synchronized.
     */
    static long counter = 0;
    counter += 1;
    return counter;
}


// --------------------------------------------------------------------------------

#ifdef UNIT_TESTS
#ifndef TEST_UNIT_H
#include <test_unit.h>
#endif

// expect both filters to be valid and to have identical masks (compared via their to_string() representation)
#define TEST_EXPECT_EQUAL_FILTERS(f1,f2) do{            \
        TEST_EXPECT_NO_ERROR((f1).is_invalid());        \
        TEST_EXPECT_NO_ERROR((f2).is_invalid());        \
        char *m1 = (f1).to_string();                    \
        char *m2 = (f2).to_string();                    \
        TEST_EXPECT_EQUAL(m1, m2);                      \
        free(m2);                                       \
        free(m1);                                       \
    }while(0)

// unit test for AP_filter: construction from strings, filtering/blowup of sequences,
// inversion (NOT) and combination (AND, OR, XOR) of filters
void TEST_filter() {
    const int   LEN            = 20;
    const int   MASK_BITCOUNT  = 9;                        // number of '1' in 'mask'
    const char *mask           = "01100001110000110011";
    const char *mask_inv       = "10011110001111001100";   // complement of 'mask'
    const char *mask_some      = "00100101100101011001";
    const char *seq            = "MSKTAYTKVLFDRGSALDGK";
    const char *seq_masked     =  "SK"  "KVL"  "SA""GK";   // 'seq' filtered by 'mask'
    const char *blow_mask      = "_SK____KVL____SA__GK";
    const char *seq_masked_inv = "M""TAYT" "FDRG""LD";     // 'seq' filtered by 'mask_inv'
    const char *blow_mask_inv  = "M__TAYT___FDRG__LD__";

    AP_filter f1(LEN);                  // permeable
    AP_filter f2(mask, "0", LEN);       // '0' filters out => f2 == mask
    AP_filter n2(mask, "1", LEN);       // '1' filters out => n2 == mask_inv
    AP_filter f3(mask_inv, "0", LEN);
    AP_filter n3(mask_inv, "1", LEN);

    TEST_EXPECT_EQUAL(f1.get_length(), LEN);
    TEST_EXPECT_EQUAL(f2.get_length(), LEN);
    TEST_EXPECT_EQUAL(f3.get_length(), LEN);
    TEST_EXPECT_EQUAL(n2.get_length(), LEN);
    TEST_EXPECT_EQUAL(n3.get_length(), LEN);

    TEST_EXPECT_EQUAL(f1.get_filtered_length(), LEN);
    TEST_EXPECT_EQUAL(f2.get_filtered_length(), MASK_BITCOUNT);
    TEST_EXPECT_EQUAL(f3.get_filtered_length(), LEN-MASK_BITCOUNT);
    TEST_EXPECT_EQUAL(n2.get_filtered_length(), LEN-MASK_BITCOUNT);
    TEST_EXPECT_EQUAL(n3.get_filtered_length(), MASK_BITCOUNT);

    TEST_EXPECT_NO_ERROR(f1.is_invalid());
    TEST_EXPECT_NO_ERROR(f2.is_invalid());
    TEST_EXPECT_NO_ERROR(f3.is_invalid());
    TEST_EXPECT_NO_ERROR(n2.is_invalid());
    TEST_EXPECT_NO_ERROR(n3.is_invalid());

    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f1.to_string(), "11111111111111111111");
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f2.to_string(), mask);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f3.to_string(), mask_inv);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(n2.to_string(), mask_inv);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(n3.to_string(), mask);

    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f1.filter_string(seq), seq);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f2.filter_string(seq), seq_masked);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f3.filter_string(seq), seq_masked_inv);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(n2.filter_string(seq), seq_masked_inv);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(n3.filter_string(seq), seq_masked);

    // blowup_string is the inverse of filter_string (filtered-out positions become '_')
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f1.blowup_string(seq,            '_'), seq);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f2.blowup_string(seq_masked,     '_'), blow_mask);
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(f3.blowup_string(seq_masked_inv, '_'), blow_mask_inv);

    // test inverting filters:
    AP_filter i2(NOT, f2);
    AP_filter i3(NOT, f3);

    TEST_EXPECT_EQUAL_FILTERS(i2, n2);
    TEST_EXPECT_EQUAL_FILTERS(i3, n3);

    // test filter combination (AND + OR):
    AP_filter s2(mask_some, "0", LEN);
    AP_filter s3(NOT, s2);

    // a filter ANDed with its complement uses no position, ORed it uses all positions
    AP_filter as23(s2, AND, s3);
    AP_filter os23(s2, OR,  s3);

    TEST_EXPECT_ERROR_CONTAINS(as23.is_invalid(), "Sequence completely filtered out (no columns left)");

    TEST_EXPECT_EQUAL(as23.get_filtered_length(), 0);
    TEST_EXPECT_EQUAL(os23.get_filtered_length(), LEN);

    // splitting f2 by s2 and re-joining both parts yields f2 again
    AP_filter fs22(f2, AND, s2);
    AP_filter fs23(f2, AND, s3);
    AP_filter o2223(fs22, OR, fs23);

    TEST_EXPECT_EQUAL_FILTERS(o2223, f2);

    // test XOR against two equivalent AND/OR/NOT expansions:
    AP_filter x(fs22, XOR, fs23);
    AP_filter xa1(AP_filter(fs22, AND, AP_filter(NOT, fs23)), OR,  AP_filter(AP_filter(NOT, fs22), AND, fs23)); // = (a&&!b) || (!a&&b)
    AP_filter xa2(AP_filter(fs22, OR,  fs23),                 AND, AP_filter(NOT, AP_filter(fs22, AND, fs23))); // = (a||b) && !(a&&b)

    TEST_EXPECT_EQUAL_FILTERS(x, xa1);
    TEST_EXPECT_EQUAL_FILTERS(x, xa2);

    // fs22 and fs23 are disjoint parts of f2, so their XOR equals f2 (== mask)
    TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(x.to_string(), mask);
}

#endif // UNIT_TESTS

// --------------------------------------------------------------------------------
