// ============================================================ //
//                                                              //
//   File      : FilteredExport.cxx                             //
//   Purpose   : encapsulate SAI-filtered fasta exporter        //
//                                                              //
//   Coded by Ralf Westram (coder@reallysoft.de) in June 2017   //
//   http://www.arb-home.de/                                    //
//                                                              //
// ============================================================ //

#include "FilteredExport.h"
#include <arbdbt.h>
#include <gb_aci.h>
#include <arb_progress.h>
#include <algorithm>

AP_filter *FilterDefinition::make_filter(GBDATA *gb_main, const char *aliName, size_t aliSize) const {
    /*! generate the filter described by this definition
     * @param gb_main database in which the SAI gets looked up
     * @param aliName name of alignment to filter
     * @param aliSize size passed to the AP_filter constructor (asserted to equal the SAI data size)
     * @return newly allocated filter (has to be deleted by the caller) or
     *         NULp (error is exported in that case)
     */
    AP_filter *result  = NULp;
    GB_ERROR   failure = NULp;

    GBDATA *gb_sai = GBT_expect_SAI(gb_main, sai_name.c_str());
    if (gb_sai) {
        GBDATA *gb_sai_data = GBT_find_sequence(gb_sai, aliName);
        if (gb_sai_data) {
#if defined(ASSERTION_USED)
            long datasize = GB_read_count(gb_sai_data); // may be less than ali-length!
            arb_assert(datasize == long(aliSize));      // @@@ write a test failing this assertion (BLOCK and PASS need to handle this differently)
#endif

            char *sai_content = GB_read_as_string(gb_sai_data); // @@@ NOT_ALL_SAI_HAVE_DATA
            if (sai_content) {
                CharRangeTable listed(characters.c_str());

                // a BLOCK definition blocks the listed characters, any other type blocks
                // everything else; 'inverse' swaps these two meanings.
                const bool blockListed = inverse ? (type != BLOCK) : (type == BLOCK);

                if (blockListed) {
                    result = new AP_filter(sai_content, listed.expandedRange(), aliSize); // blocks characters
                }
                else {
                    // build the blocking filter first, then negate it
                    AP_filter blocking(sai_content, listed.expandedRange(), aliSize);
                    result = new AP_filter(NOT, blocking);
                }
                free(sai_content);
            }
            else {
                failure = GB_await_error();
            }
        }
        else {
            failure = GBS_global_string("SAI '%s' has no data in alignment '%s'", sai_name.c_str(), aliName);
        }
    }
    else {
        failure = GB_await_error();
    }

    arb_assert(contradicted(result, failure)); // either a filter or an error, never both
    if (failure) GB_export_error(failure);
    return result;
}

/*! set up an exporter for one alignment
 * @param gb_main_ database to export from
 * @param aliname_ name of the alignment to export (stored via nulldup(); released in the destructor)
 * @param alisize_ alignment size; also used to construct the initial filter
 *
 * Initial state: missing data is not accepted, the header ACI is "readdb(name)",
 * no sequence ACI is set, the required base count is 0 and no SAI filter has been added.
 */
FilteredExport::FilteredExport(GBDATA *gb_main_, const char *aliname_, size_t alisize_) :
    gb_main(gb_main_),
    aliname(nulldup(aliname_)),
    alisize(alisize_),
    accept_missing_data(false),
    header_ACI(strdup("readdb(name)")),
    sequence_ACI(NULp),
    count_table(NULp),
    minCount(0),
    filter(alisize), // NOTE(review): reads member 'alisize' - relies on it being declared before 'filter' in the class; confirm in header
    filter_added(false)
{}

FilteredExport::~FilteredExport() {
    // release the C-strings held by the exporter
    // (sequence_ACI starts out as NULp in the constructor; free() accepts that)
    free(aliname);
    free(sequence_ACI);
    free(header_ACI);
}

GB_ERROR FilteredExport::add_SAI_filter(const FilterDefinition& filterDef) {
    /*! generate the filter defined by 'filterDef' and merge it into the export filter
     * - the first filter added replaces the initial filter
     * - each further filter is combined with the current one:
     *   PASS definitions are OR-ed, BLOCK definitions are AND-ed.
     * @return NULp on success, otherwise the error reported by filter generation
     */
    AP_filter *generated = filterDef.make_filter(gb_main, aliname, alisize);
    if (!generated) return GB_await_error();

    if (filter_added) {
        switch (filterDef.get_type()) {
            case BLOCK: filter = AP_filter(filter, AND, *generated); break;
            case PASS:  filter = AP_filter(filter, OR,  *generated); break;
        }
    }
    else {
        filter       = *generated;
        filter_added = true;
    }

    delete generated; // 'filter' holds its own copy
    return NULp;
}

int FilteredExport::count_bases(const char *seq) const {
    /*! sum up count_table.isSet() over all characters of 'seq'
     * @param seq 0-terminated sequence data
     * @return accumulated count
     */
    int found = 0;
    for (const char *cursor = seq; *cursor; ++cursor) {
        found += count_table.isSet(*cursor);
    }
    return found;
}

char *FilteredExport::get_filtered_sequence(GBDATA *gb_species, const char*& reason) const {
    /*! read, filter and (optionally) post-process the sequence of one species
     * @param gb_species species to export (must not be NULp)
     * @param reason     always reset to NULp first; set to a static explanation
     *                   if the species shall be skipped without error
     * @return filtered sequence (heap copy, to be freed by the caller) or
     *         NULp (which means "do not export")
     *         - if an error occurred (error is exported only in that case!),
     *         - if species had no data in alignment (and accept_missing_data is true) or
     *         - if filtered sequence does not have min. required base count.
     *         If NULp is returned and no error was exported -> 'reason' is set!
     *
     * No error may be pending when this function is called.
     */
    arb_assert(gb_species);
    arb_assert(!GB_have_error());

    reason = NULp;

    char   *seq     = NULp;
    GBDATA *gb_data = GBT_find_sequence(gb_species, aliname);

    if (!gb_data) {
        if (GB_have_error()) {
            GB_export_error(GB_failedTo_error("read sequence of ", GBT_get_name_or_description(gb_species), GB_await_error()));
        }
        else {
            if (accept_missing_data) {
                reason = "has no data";
            }
            else {
                GB_export_errorf("species '%s' has no data in '%s'", GBT_get_name_or_description(gb_species), aliname);
            }
        }
    }
    else {
        GB_ERROR error = filter.is_invalid();
        if (error) GB_export_error(error);
        else {
            const char *cseq = GB_read_char_pntr(gb_data);
            if (!cseq) {
                // reading the data failed -> report it instead of passing NULp to filter_string()
                GB_ERROR read_error = GB_have_error() ? GB_await_error() : "no sequence data could be read";
                GB_export_error(GB_failedTo_error("read sequence of ", GBT_get_name_or_description(gb_species), read_error));
            }
            else {
                seq = filter.filter_string(cseq);

                // check min. requirements:
                if (minCount>0) { // otherwise check would always succeed
                    int count = count_bases(seq);
                    if (count<minCount) { // too few bases -> do not export
                        freenull(seq);
                        reason = "not enough base-characters left";
                    }
                }
            }
        }
    }

    if (seq && sequence_ACI) {
        // post-process the filtered sequence using the user-defined ACI/SRT
        char *seq_postprocessed = GB_command_interpreter_in_env(seq, sequence_ACI, GBL_simple_call_env(gb_species));
        if (seq_postprocessed) {
            freeset(seq, seq_postprocessed);
        }
        else {
            char *error = strdup(GB_await_error()); // copy the message before exporting a new error
            GB_export_errorf("Failed to post-process sequence data of species '%s' (Reason: %s)", GBT_get_name_or_description(gb_species), error);
            free(error);
            freenull(seq);
        }
    }

    // exactly one of: sequence | exported error | skip-reason
    arb_assert(contradicted(seq, GB_have_error() || reason));
    arb_assert(implicated(!seq, contradicted(GB_have_error(), reason)));

    return seq;
}

char *FilteredExport::get_fasta_header(GBDATA *gb_species) const {
    /*! evaluate header_ACI for 'gb_species' (applied to an empty input string)
     * @return result of GB_command_interpreter_in_env() - callers free() it -
     *         or NULp (callers then fetch the error via GB_await_error())
     */
    char *header = GB_command_interpreter_in_env("", header_ACI, GBL_simple_call_env(gb_species));
    return header;
}

GB_ERROR FilteredExport::write_fasta(FILE *out) {
    /*! write the filtered sequences of all species in FASTA format
     * @param out stream to write to
     * @return NULp on success, otherwise an error
     *         (failed sequence/header generation or failed write to 'out')
     *
     * Species for which get_filtered_sequence() reports a 'reason' are skipped;
     * each skip and a final summary get reported on stderr.
     * A record is only written if sequence AND header were generated successfully,
     * i.e. no partial record is emitted for a species that causes an error.
     */
    GB_ERROR error    = NULp;
    int      exported = 0;
    int      skipped  = 0;

    {
        arb_progress progress("Write sequence data", GBT_get_species_count(gb_main));

        for (GBDATA *gb_species = GBT_first_species(gb_main);
             gb_species && !error;
             gb_species = GBT_next_species(gb_species))
        {
            const char *reason;
            char       *filt_seq = get_filtered_sequence(gb_species, reason);

            if (filt_seq) {
                // generate header BEFORE writing anything
                // (otherwise a failing header-ACI leaves a record with empty header in the output)
                char *header = get_fasta_header(gb_species);
                if (header) {
                    fputc('>', out);
                    fputs(header, out);
                    fputc('\n', out);

                    fputs(filt_seq, out);
                    fputc('\n', out);

                    free(header);

                    if (ferror(out)) { // do not silently produce truncated output
                        error = "Failed to write sequence data (I/O error on output stream)";
                    }
                    else {
                        ++exported; // count exported species
                    }
                }
                else {
                    error = GB_await_error();
                }
                free(filt_seq);
            }
            else if (reason) {
                ++skipped; // count skipped
                fprintf(stderr, "Skipped species '%s' (Reason: %s)\n", GBT_get_name_or_description(gb_species), reason);
            }
            else {
                error = GB_await_error();
            }
            ++progress;
        }

        if (error) progress.done(); // loop was aborted -> finish progress explicitly
    }

    if (!error) {
        if (exported) {
            fprintf(stderr, "Summary: %i species exported", exported);
            if (skipped) fprintf(stderr, ", %i species skipped", skipped);
            fputc('\n', stderr);
        }
        else {
            fprintf(stderr, "Summary: all %i species skipped (warning: generated empty file)\n", skipped);
        }
    }

    return error;
}

// --------------------------------------------------------------------------------

#ifdef UNIT_TESTS
#ifndef TEST_UNIT_H
#include <test_unit.h>
#endif

// expects that CharRangeTable(arg).expandedRange() equals 'expected'
#define TEST_EXPECT_CRT_DEFINES(arg,expected) TEST_EXPECT_EQUAL(CharRangeTable(arg).expandedRange(), expected)

// shortcuts for the complete upper-/lowercase alphabets (undefined again below)
#define ABC "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#define abc "abcdefghijklmnopqrstuvwxyz"

// Documents how CharRangeTable expands character range definitions.
// As the expectations show, the expanded range lists each character once, in ascending character order.
void TEST_CharRangeTable() {
    // plain character lists and simple ranges:
    TEST_EXPECT_CRT_DEFINES("abc", "abc");
    TEST_EXPECT_CRT_DEFINES("cba", "abc");
    TEST_EXPECT_CRT_DEFINES("c-a", "-ac"); // unwanted reverse range does not expand!
    TEST_EXPECT_CRT_DEFINES("a-c", "abc");

    // overlapping ranges; a reverse range contributes its three characters literally:
    TEST_EXPECT_CRT_DEFINES("a-db-e", "abcde");
    TEST_EXPECT_CRT_DEFINES("a-de-b", "-abcde");

    // a leading or trailing '-' is taken literally:
    TEST_EXPECT_CRT_DEFINES("-ab", "-ab");
    TEST_EXPECT_CRT_DEFINES("a-b", "ab");
    TEST_EXPECT_CRT_DEFINES("ab-", "-ab");

    // ranges consisting of a single character:
    TEST_EXPECT_CRT_DEFINES("a-ac-c", "ac");

    TEST_EXPECT_CRT_DEFINES("a-zA-Z", ABC abc);

    // dangerous ranges are not expanded:
    TEST_EXPECT_CRT_DEFINES(".-=", "-.=");
    TEST_EXPECT_CRT_DEFINES("=-.", "-.=");

    // ranges mixing upper- and lowercase are not expanded either:
    TEST_EXPECT_CRT_DEFINES("a-Z", "-Za");
    TEST_EXPECT_CRT_DEFINES("A-z", "-Az");
}

#undef abc
#undef ABC

// Evaluates 'create_seqcopy' (an expression yielding a heap-allocated string), expects a result
// without an exported error, expects the string to have 'expected_length' characters and frees it.
#define TEST_EXPECT_SEQWITHLENGTH__NOERROREXPORTED(create_seqcopy, expected_length) do {        \
        char *seqcopy;                                                                          \
        TEST_EXPECT_RESULT__NOERROREXPORTED(seqcopy = create_seqcopy);                          \
        TEST_EXPECT_EQUAL(strlen(seqcopy), expected_length);                                    \
        free(seqcopy);                                                                          \
    } while(0)

// Tests FilteredExport and FilterDefinition against the database 'TEST_prot_tiny.arb'.
// The checks below depend on each other: 'exporter' and 'exporter_incomplete' change
// state during the test (header ACI, base count requirement, SAI filters, sequence ACI).
void TEST_FilteredExport() {
    // see also ../../TOOLS/arb_test.cxx@SAI_FILTERED_EXPORT_TESTS

    GB_shell  shell;
    GBDATA   *gb_main = GB_open("TEST_prot_tiny.arb", "r"); // ../../UNIT_TESTER/run/TEST_prot_tiny.arb

    // only "CytLyti6" has data in 'ali_dna_incomplete'

    {
        char   *ali_name = NULp;
        size_t  ali_size = 0;

        {
            // determine name and length of the default alignment
            GB_transaction ta(gb_main);
            ali_name = GBT_get_default_alignment(gb_main);
            TEST_REJECT_NULL(ali_name);
            ali_size = GBT_get_alignment_len(gb_main, ali_name);
            TEST_REJECT(ali_size<=0);
        }

        {
            FilteredExport exporter(gb_main, ali_name, ali_size);
            FilteredExport exporter_incomplete(gb_main, "ali_dna_incomplete", ali_size);

            {
                GB_transaction ta(gb_main);

                GBDATA *gb_spec1 = GBT_find_species(gb_main, "StrRamo3"); // ~ 60 AA
                GBDATA *gb_spec2 = GBT_find_species(gb_main, "MucRacem"); // more AA
                GBDATA *gb_spec3 = GBT_find_species(gb_main, "BctFra12");
                GBDATA *gb_spec4 = GBT_find_species(gb_main, "CytLyti6");

                TEST_REJECT_NULL(gb_spec1);
                TEST_REJECT_NULL(gb_spec2);
                TEST_REJECT_NULL(gb_spec3);
                TEST_REJECT_NULL(gb_spec4);

                // without any SAI filter added, the complete alignment length gets exported:
                const char *reason;
                TEST_EXPECT_SEQWITHLENGTH__NOERROREXPORTED(exporter.get_filtered_sequence(gb_spec1, reason), ali_size);
                TEST_EXPECT_SEQWITHLENGTH__NOERROREXPORTED(exporter.get_filtered_sequence(gb_spec2, reason), ali_size);

                // missing data is an error by default (and no 'reason' is given then) ..
                TEST_EXPECT_NORESULT__ERROREXPORTED_CONTAINS(exporter_incomplete.get_filtered_sequence(gb_spec3, reason), "has no data in 'ali_dna_incomplete'");
                TEST_EXPECT_NULL(reason);

                exporter_incomplete.do_accept_missing_data(); // changes state of exporter_incomplete!

                // .. but only a skip-reason after missing data has been accepted:
                TEST_EXPECT_NORESULT__NOERROREXPORTED(exporter_incomplete.get_filtered_sequence(gb_spec3, reason));
                TEST_EXPECT_EQUAL(reason, "has no data");

                TEST_EXPECT_SEQWITHLENGTH__NOERROREXPORTED(exporter.get_filtered_sequence(gb_spec4, reason), ali_size);

                // test header-generation:
                TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(exporter.get_fasta_header(gb_spec1), "StrRamo3");
                TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(exporter.get_fasta_header(gb_spec2), "MucRacem");

                exporter.set_header_ACI("readdb(name);\",\";readdb(full_name)"); // use real ACI for header
                TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(exporter.get_fasta_header(gb_spec3), "BctFra12,Bacteroides fragilis");

                exporter.set_header_ACI("readdb(name);\",\";readdb(ali_dna);\",\";readdb(nosuchfield);\",\";readdb(full_name)");   // wrong accepted use (try to read a container and an unknown field)
                TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(exporter.get_fasta_header(gb_spec4), "CytLyti6,,,Cytophaga lytica"); // both produce empty output

                exporter.set_header_ACI("readdb(name);\",\";bugme");                                                          // wrong rejected use (unknown command)
                TEST_EXPECT_NORESULT__ERROREXPORTED_CONTAINS(exporter.get_fasta_header(gb_spec4), "Unknown command 'bugme'"); // aborts with error

                // test sequences are skipped if too few bases remain:
                exporter.set_required_baseCount("ACGT", 185); // (bit more than 3*60)
                TEST_EXPECT_NORESULT__NOERROREXPORTED(exporter.get_filtered_sequence(gb_spec1, reason));
                TEST_EXPECT_EQUAL(reason, "not enough base-characters left");
                exporter.reset_required_baseCount();

                // test FilterDefinition
                FilterDefinition pvp_blck_variab("POS_VAR_BY_PARSIMONY", BLOCK, true,  "-=.0123");
                FilterDefinition pvp_pass_variab("POS_VAR_BY_PARSIMONY", PASS,  true,  "-=.0123");
                FilterDefinition pvp_blck_cnsrvd("POS_VAR_BY_PARSIMONY", BLOCK, false, "-=.012345");
                FilterDefinition pvp_pass_cnsrvd("POS_VAR_BY_PARSIMONY", PASS , false, "-=.012345");

                {
                    AP_filter *filt_blck_variab;
                    AP_filter *filt_pass_variab;
                    AP_filter *filt_blck_cnsrvd;
                    AP_filter *filt_pass_cnsrvd;

                    TEST_EXPECT_RESULT__NOERROREXPORTED(filt_blck_variab = pvp_blck_variab.make_filter(gb_main, ali_name, ali_size));
                    TEST_EXPECT_RESULT__NOERROREXPORTED(filt_pass_variab = pvp_pass_variab.make_filter(gb_main, ali_name, ali_size));
                    TEST_EXPECT_RESULT__NOERROREXPORTED(filt_blck_cnsrvd = pvp_blck_cnsrvd.make_filter(gb_main, ali_name, ali_size));
                    TEST_EXPECT_RESULT__NOERROREXPORTED(filt_pass_cnsrvd = pvp_pass_cnsrvd.make_filter(gb_main, ali_name, ali_size));

                    TEST_EXPECT_EQUAL(filt_blck_variab->get_length(), ali_size);
                    TEST_EXPECT_EQUAL(filt_pass_variab->get_length(), ali_size);
                    TEST_EXPECT_EQUAL(filt_blck_cnsrvd->get_length(), ali_size);
                    TEST_EXPECT_EQUAL(filt_pass_cnsrvd->get_length(), ali_size);

                    // BLOCK and PASS filters using the same characters are complementary:
                    TEST_EXPECT_EQUAL(filt_blck_variab->get_filtered_length(), 135);
                    TEST_EXPECT_EQUAL(filt_pass_variab->get_filtered_length(), ali_size-135);
                    TEST_EXPECT_EQUAL(filt_blck_cnsrvd->get_filtered_length(), ali_size-45);
                    TEST_EXPECT_EQUAL(filt_pass_cnsrvd->get_filtered_length(), 45);

                    // make_filter() results have to be deleted by the caller
                    delete filt_pass_cnsrvd;
                    delete filt_blck_cnsrvd;
                    delete filt_pass_variab;
                    delete filt_blck_variab;
                }

                // test get_filtered_sequence with real filters:
                TEST_EXPECT_NO_ERROR(exporter.add_SAI_filter(pvp_blck_variab));
                TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(exporter.get_filtered_sequence(gb_spec1, reason), "TCCACGCACCAAGTCCT--GGACTTTG--GACCGGGTCCCTGACG-ATG-CCGGGGACG---GACTGG--TC--GGGCCGCT-ACGG---CGCTCGGCCGCG-------GCCG-CTGGG----CCGCCGT.....");

                exporter.clear_SAI_filters();

                TEST_EXPECT_NO_ERROR(exporter.add_SAI_filter(pvp_pass_cnsrvd));
                TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(exporter.get_filtered_sequence(gb_spec1, reason), "TCACCAAGTTGGGTCGACCGGGAATGGGCCAGCCCGCGCTGGCCC");

                exporter.clear_SAI_filters();

                // two BLOCK filters combined (135-45 = 90 positions remain):
                TEST_EXPECT_NO_ERROR(exporter.add_SAI_filter(pvp_blck_cnsrvd));
                TEST_EXPECT_NO_ERROR(exporter.add_SAI_filter(pvp_blck_variab));
                TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(exporter.get_filtered_sequence(gb_spec1, reason), "CGCACCC--ACTTTG--GACCGGCCTCG-ATG-GCG---GCG--TC--GCGT-CG---CGTCGG-------GCG-CG----GCGT.....");

                // test sequence post-processing
                exporter.set_sequence_ACI(":.=-"); // convert dots to hyphens (using SRT)
                TEST_EXPECT_EQUAL_STRINGCOPY__NOERROREXPORTED(exporter.get_filtered_sequence(gb_spec1, reason), "CGCACCC--ACTTTG--GACCGGCCTCG-ATG-GCG---GCG--TC--GCGT-CG---CGTCGG-------GCG-CG----GCGT-----");

                exporter.set_sequence_ACI(":bad"); // malformed expression
                TEST_EXPECT_NORESULT__ERROREXPORTED_CONTAINS(exporter.get_filtered_sequence(gb_spec1, reason), "SRT ERROR: no '=' found in command 'bad'");
            }
        }
        free(ali_name);
    }
    GB_close(gb_main);
}

#endif // UNIT_TESTS

// --------------------------------------------------------------------------------


