#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

/* This really isn't a bit trie anymore, but...
 *
 * A normal bit-trie splits two ways, by looking at exactly one
 * bit.  Obviously, this can be wasteful if all the keys begin with the
 * same pattern (e.g. all 0s).  To alleviate that problem a bit,
 * bit-trie first dispatches on the first bit set at the root, and then
 * walks much smaller binary trees.
 *
 * Why not perform the bit traversal trick iteratively, at each node?
 *
 * There are a couple issues with that idea, mostly that we end up
 * biasing our data structure for keys that are sparse in 1s.
 *
 * Instead, we also store a key in each node, and perform the bit
 * traversal on (key ^ needle).  This implicitly takes care of
 * bit-patterns in keys, and bounds not only the size of the
 * structure, but also its depth!
 *
 * In addition, we know that the number of children will usually be
 * very low, if only because the number of bits considered in the key
 * shrinks with the search depth.  A tagging scheme (the low three bits
 * of each child reference) records the capacity of a node's vector of
 * children, so nodes can be grown efficiently with realloc.  The
 * capacities are a bit strange (not powers of two) to compensate for
 * the two header words holding the key and value.
 *
 * TODO: I'm pretty sure we can still do prefix and predecessor
 * searches.
 *
 * Here's why this is significantly better than a radix tree (never
 * mind bit patterns or random keys, which this is better able to
 * exploit).
 *
 * Let n(l, d) be the maximum number of nodes at (exactly) depth d,
 * given keys of length l.
 *
 * n(l, 0) = 1;
 * n(l, d) = 0                              for d > l;
 * n(l, d) = \sum_{i=0}^{l-1} n(i, d-1)     otherwise,
 * i.e. n(l, d) = \binom{l}{d}, which gives the ratios below.
 *
 * In particular,
 * n(64, 64) = 1
 * n(64, 63) = 64     = 64/1*1
 * n(64, 62) = 2016   = 63/2*64
 * n(64, 61) = 41664  = 62/3*2016
 * n(64, 60) = 635376 = 61/4*41664
 *
 * And that function is symmetric in the depth (n(64, d) =
 * n(64, 64-d)).  So, assuming a *completely full* trie, half the items
 * are at depth 32 or less.  Better: the middle is much heavier than
 * either tail, so, e.g., a very large percentage will be found at depth
 * 40 or less.
 *
 * Now, what if we have smaller tries; what are the worst cases?  The
 * worst case, depth 64, can still only happen for one key; more
 * interestingly, it can only happen for tries of size at least 65
 * (and the average depth is then 32).  For depth 63, the worst case
 * (64 such keys), can only happen if there are at least 1+2+3+...+63
 * + 63 = 2079 keys (?).
 *
 * Overall, it's possible to have a small number of very deep nodes,
 * but there can't be too many of them without having significantly
 * larger datasets.
 */

struct bit_trie_node;
/* A (possibly empty) reference to a trie node.  The low three bits of
 * `addr` hold a tag -- an index into `sizes` -- describing the capacity
 * of the referenced node's child vector; they must be masked off (see
 * node_data) before `data` is dereferenced.  An `addr` of 0 denotes an
 * empty subtree.
 * NOTE(review): storing through `data` and reading back `addr` assumes
 * pointers are 64 bits wide; confirm before porting to 32-bit targets. */
typedef union { struct bit_trie_node * data; uint64_t addr; } bit_trie_node_t;

/* One trie entry: a stored key/value pair followed by the child vector.
 * The vector's length is sizes[tag] for the tag carried by the reference
 * that points at this node (realloc_node_for_size may allocate fewer
 * slots than that, never more). */
struct bit_trie_node {
        uint64_t key;
        void * value;
        bit_trie_node_t children[];
};

/* Child-vector capacity for each 3-bit node tag (index = tag).  Tag 0 is
 * a leaf with no children; tag 7 can hold one child per key bit.  The
 * odd values leave room for the key/value header words so that the
 * total allocation stays close to a power of two.  Read-only. */
static const unsigned sizes[] = { 0, 2, 4, 6,
                                  14, 30, 46, 64};

static inline
unsigned node_tag (bit_trie_node_t node)
{
        /* The tag occupies the three least-significant address bits. */
        const uint64_t tag_mask = 7;
        const uint64_t tag_bits = node.addr & tag_mask;
        return (unsigned)tag_bits;
}

/* Strip the 3-bit tag from a node reference and return the pointer to
 * the node itself (NULL for an empty reference). */
static inline
struct bit_trie_node * node_data (bit_trie_node_t node)
{
        bit_trie_node_t untagged = node;
        /* Build the mask at 64-bit width.  `~7UL` is only 32 bits wide
         * where `long` is 32 bits (LLP64/ILP32), and once widened to
         * uint64_t it would also clear the upper half of the address. */
        untagged.addr &= ~(uint64_t)7;
        return untagged.data;
}

#ifdef DEBUG
# include <stdio.h>
#endif

/* Look up `key` in the subtree referenced by `node`.
 *
 * At each node the most significant bit in which `key` differs from the
 * node's own key (the leading 1 of key ^ node->key, counted from the
 * MSB) selects the child.  Child indices are relative to the number of
 * leading bits already known to match (`fathomed_bits`).
 *
 * Returns the stored value, or `missing` if `key` is not present. */
static
void * bit_trie_find_ (bit_trie_node_t node, const uint64_t key, void * const missing)
{
        unsigned fathomed_bits = 0;
        while (1) {
#ifdef DEBUG
                printf("Node: %llx\n", (unsigned long long)node.addr);
#endif
                if (node.addr == 0) return missing;
                struct bit_trie_node * data = node_data(node);
                const uint64_t diff = key ^ data->key;
                if (diff == 0) {
#ifdef DEBUG
                        printf("Found at %llx\n", (unsigned long long)node.addr);
#endif
                        return data->value;
                }

                // Position (from the MSB) of the most significant differing
                // bit.  diff is 64-bit, so use the `ll` variant: `long` is only
                // 32 bits on LLP64/ILP32 targets, and clzl would truncate diff
                // (and hit the undefined clz(0) if its low half were zero).
                const unsigned split = __builtin_clzll(diff);
                const unsigned idx = split - fathomed_bits;
                unsigned tag = node_tag(node);
                unsigned size = sizes[tag];
                // No slot that far out: the key cannot be below this node.
                if (idx >= size) return missing;
                node = data->children[idx];
                // Every key in that child agrees with `key` on bits 0..split.
                fathomed_bits = split+1;
        }
}

/* Allocate a zero-filled node with room for sizes[tag] children and
 * return a reference to it with `tag` packed into the low three bits.
 * The tagging scheme requires those bits of the allocation to be zero
 * (8-byte alignment), which malloc-family allocators provide on common
 * targets.
 *
 * Aborts on allocation failure: bit_trie_insert_ has no error channel,
 * and a NULL result would otherwise be dereferenced by the caller. */
static
bit_trie_node_t make_node_for_tag (unsigned tag)
{
        const unsigned size = sizes[tag];
        struct bit_trie_node * data = calloc(1,
                                             sizeof(struct bit_trie_node)
                                             + size*sizeof(bit_trie_node_t));
        if (data == NULL) abort();

        bit_trie_node_t temp;
        temp.data = data;
        temp.addr |= tag;
        return temp;
}

/* Grow `node`'s child vector so that slot `idx` exists, and return the
 * re-tagged reference to the (possibly moved) node.
 *
 * The new tag is the smallest one above `old_tag` whose capacity exceeds
 * `idx`.  The search stays in range because sizes[7] == 64 and idx < 64.
 * The allocation is clamped to `max_size` slots (the number of key bits
 * still unfathomed at this node).  The tag itself is not clamped, so it
 * may advertise more slots than exist.  This is safe because callers only
 * ever probe indices below max_size at a given node.
 *
 * Newly added slots are zeroed (empty).  Aborts on allocation failure;
 * `node` would still be valid, but there is no way to report the error. */
static
bit_trie_node_t realloc_node_for_size (struct bit_trie_node * node, unsigned max_size,
                                       unsigned idx, unsigned old_tag)
{
        unsigned tag, size, old_size = sizes[old_tag];
        for (tag = old_tag+1; __builtin_expect(idx >= (size = sizes[tag]), 0); tag++);

        if (size > max_size) size = max_size;

        struct bit_trie_node * data = realloc(node,
                                              sizeof(struct bit_trie_node)
                                              + size*sizeof(bit_trie_node_t));
        if (data == NULL) abort();

        for (unsigned i = old_size; i < size; i++)
                data->children[i].addr = 0;

        bit_trie_node_t temp;
        temp.data = data;
        temp.addr |= tag;
        return temp;
}

/* Insert `value` under `key` into the subtree referenced by `*dest`.
 *
 * Follows the same path as bit_trie_find_.  If `key` is already present,
 * its value is replaced and the previous value is returned.  Otherwise a
 * new childless node is linked into the first empty slot reached, and
 * `missing` is returned.  A node whose child vector is too short for the
 * required index is reallocated, and the slot that referenced it is
 * updated with the new tagged reference. */
static
void * bit_trie_insert_ (bit_trie_node_t * dest,
                         const uint64_t key, void * const value,
                         void * const missing)
{
        unsigned fathomed_bits = 0;
        while (1) {
                bit_trie_node_t node = *dest;
#ifdef DEBUG
                printf("Node: %llx\n", (unsigned long long)node.addr);
#endif
                if (node.addr == 0) {
                        bit_trie_node_t new_node = make_node_for_tag(0);
#ifdef DEBUG
                        printf("insert at %llx\n", (unsigned long long)new_node.addr);
#endif
                        struct bit_trie_node * data = node_data(new_node);
                        data->key = key;
                        data->value = value;
                        *dest = new_node;
                        return missing;
                }

                struct bit_trie_node * data = node_data(node);
                const uint64_t diff = key ^ data->key;
                if (diff == 0) {
#ifdef DEBUG
                        printf("Found at %llx\n", (unsigned long long)node.addr);
#endif
                        void * old = data->value;
                        data->value = value;
                        return old;
                }
                // Most significant differing bit, counted from the MSB.  Use
                // the `ll` variant: diff is 64-bit but `long` may be 32-bit.
                const unsigned split = __builtin_clzll(diff);
                const unsigned idx = split - fathomed_bits;
#ifdef DEBUG
                printf("State %llx %llx %u %u\n", (unsigned long long)key,
                       (unsigned long long)data->key, idx, fathomed_bits);
#endif
                unsigned tag = node_tag(node);
                if (idx >= sizes[tag]) {
                        // Grow the child vector.  The node may move, so the
                        // slot that references it must be rewritten.
                        *dest = node = realloc_node_for_size(data, 64-fathomed_bits, idx, tag);
                        data = node_data(node);
                }

                dest = data->children+idx;
                fathomed_bits = split+1;
        }
}

/* Public trie handle: just the root reference.  It must start out zeroed
 * (head.addr == 0), which the lookup and insert paths treat as an empty
 * trie. */
typedef struct bit_trie { bit_trie_node_t head; } bit_trie_t;

void * bit_trie_find (const bit_trie_t * trie, uint64_t key, void * missing)
{
        return bit_trie_find_(trie->head, key, missing);
}

/* Associate `value` with `key` in `trie`.  Returns the value previously
 * stored under `key`, or `missing` if the key was not present. */
void * bit_trie_insert (bit_trie_t * trie, uint64_t key, void * value, void * missing)
{
        bit_trie_node_t * const root_slot = &trie->head;
        void * const previous = bit_trie_insert_(root_slot, key, value, missing);
        return previous;
}

#ifdef TEST_ME
#include <assert.h>

/* Smoke test: insert a few small keys (including 0), check that each is
 * found, that re-inserting replaces and returns the old value, and that
 * an absent key yields the `missing` sentinel.  Calls are kept outside
 * assert() so they still run when NDEBUG is defined. */
int main ()
{
        uint64_t keys[] = {1, 2, 3, 4, 5, 6, 0, 7};
        const size_t nkeys = sizeof(keys) / sizeof(keys[0]);
        /* The root must start empty (addr == 0); an uninitialized automatic
         * bit_trie_t would hand garbage to the insert/find walks. */
        bit_trie_t trie = { { .addr = 0 } };

        for (size_t i = 0; i < nkeys; i++) {
                void * prev = bit_trie_insert(&trie, keys[i], keys+i, NULL);
                assert(prev == NULL);
                (void)prev;
        }

        for (size_t i = 0; i < nkeys; i++)
                assert(keys+i == bit_trie_find(&trie, keys[i], NULL));

        /* Replacing an existing key returns the previous value. */
        void * old = bit_trie_insert(&trie, keys[0], keys+1, NULL);
        assert(old == keys+0);
        (void)old;
        assert(bit_trie_find(&trie, keys[0], NULL) == keys+1);

        /* A key that was never inserted yields the sentinel. */
        assert(bit_trie_find(&trie, 42, NULL) == NULL);

        return 0;
}
#endif

