from functools import partial as _partial
from struct import pack as _pack
from struct import unpack as _unpack
from struct import unpack_from as _unpack_from

def unpack(s, numtypes=1):
    """Decode `numtypes` values from `s` and return them as a tuple.

    `s` starts with the type tags and is followed by the encoded data.
    The tags are parsed by `create_unpacker`; decoding starts at the
    offset it reports.
    """
    decoder, start = create_unpacker(s, numtypes)
    values, _source, _end = decoder(([], s, start))
    return tuple(values)

def unpack_off(s, numtypes=1):
    """Decode `numtypes` values from `s`, also reporting where decoding ended.

    Returns a pair ``(values, offset)``: `values` is the tuple of decoded
    values and `offset` is the index in `s` just past the last byte the
    unpacker consumed.
    """
    decoder, start = create_unpacker(s, numtypes)
    values, _source, end = decoder(([], s, start))
    return tuple(values), end

def create_unpacker(s, numtypes=1):
    """Parse `numtypes` type tags from the start of `s` and build an unpacker.

    Args:
        s: str or buffer whose first bytes are the type tags.
        numtypes: number of top-level type tags to parse (anything `int()`
            accepts, must not be negative).

    Returns:
        A pair ``(unpacker, offset)``.  `unpacker` takes a
        ``(result_list, string, offset)`` triple, decodes one value per
        parsed tag (in tag order), appends each to `result_list` and
        returns the updated triple.  `offset` is the number of bytes of
        `s` taken up by the parsed type tags.

    Raises:
        TypeError: `numtypes` is not an integer or is negative, or `s` is
            not a str or buffer.
        ValueError: `s` ends before `numtypes` tags were read, or an
            unknown tag is found.
        NotImplementedError: a '<' length modifier follows an 'm', 'w',
            'b' or 's' tag, or the 'g' tag is used.
    """
    try:
        numtypes = int(numtypes)
    except (ValueError, TypeError):
        raise TypeError("numtypes must be an integer")
    if numtypes < 0:
        # TypeError is kept for backward compatibility; zero is accepted.
        raise TypeError("numtypes must be >= 0")
    if not isinstance(s, (str, buffer)):
        raise TypeError("s must be a string or buffer")

    # Tags decoded by a single fixed unpacker function.
    fixed = {
        'c': count_unpacker,
        'y': bool_unpacker,
        'k': byte_unpacker,
        'p': ubyte_unpacker,
        'j': short_unpacker,
        'o': ushort_unpacker,
        'i': int_unpacker,
        'u': uint_unpacker,
        'l': long_unpacker,
        'n': ulong_unpacker,
        'f': double_unpacker,
        'h': float_unpacker,
        'v': variant_unpacker,
    }
    # Tags that may one day take a '<' length modifier: (unpacker, name).
    counted = {
        'm': (multi_int_unpacker, "multi-precision integer"),
        'w': (multi_uint_unpacker, "unsigned multi-precision integer"),
        'b': (blob_unpacker, "binary blob"),
        's': (string_unpacker, "string"),
    }
    # Container tags: a factory parses the nested type spec that follows.
    containers = {
        't': make_tuple_unpacker,
        'a': make_array_unpacker,
        'd': make_dict_unpacker,
    }

    steps = []
    offset = 0
    for _ in xrange(numtypes):
        if offset >= len(s):
            raise ValueError("The string passed in is too short for type spec.")
        tag = s[offset]
        offset += 1
        if tag in fixed:
            steps.append(fixed[tag])
        elif tag in counted:
            step, description = counted[tag]
            if (len(s) > offset) and (s[offset] == '<'):
                raise NotImplementedError(
                    "Length modifiers not yet supported for %s ('%s') type."
                    % (description, tag))
            steps.append(step)
        elif tag in containers:
            subunpacker, suboffset = containers[tag](buffer(s, offset))
            offset += suboffset
            steps.append(subunpacker)
        elif tag == 'g':
            raise NotImplementedError("The arbitrary size floating point ('g') type is not yet supported.")
        else:
            raise ValueError("Unknown type tag '%s' encountered." % (tag,))

    def unpacker(dectup):
        # Run the per-tag unpackers in tag order.  A flat loop avoids the
        # one-stack-frame-per-tag recursion of nested function composition.
        for step in steps:
            dectup = step(dectup)
        return dectup

    return unpacker, offset

def count_unpacker(dectup):
    """Decode one count value (see `unpackcount_off`) and append it.

    `dectup` is a ``(result_list, string, offset)`` triple; the returned
    triple has the offset advanced past the count.
    """
    result, s, offset = dectup
    count, after = unpackcount_off(s, offset)
    result.append(count)
    return result, s, after

def multi_int_unpacker(dectup):
    result, s, offset = dectup
    mlen, newoffset = unpackcount_off(dectup[1], dectup[2])
    if (newoffset + mlen) > len(s):
        raise ValueError("The string passed in is too short to unpack.")
    if mlen <= 0:
        result.append(0L)
        return (result, s, newoffset)

    mint = unpackl(buffer(s, newoffset, mlen))
    if ord(s[newoffset]) & 0x80:
        mint -= 2**(mlen * 8)
    result.append(mint)
    return (result, s, newoffset + mlen)

def multi_uint_unpacker(dectup):
    result, s, offset = dectup
    mlen, newoffset = unpackcount_off(dectup[1], dectup[2])
    if (newoffset + mlen) > len(s):
        raise ValueError("The string passed in is too short to unpack.")
    if mlen <= 0:
        result.append(0L)
        return (result, s, newoffset)
    mint = unpackl(buffer(s, newoffset, mlen))
    result.append(mint)
    return (result, s, newoffset + mlen)

def blob_unpacker(dectup):
    """Decode a length-prefixed binary blob and append it to the results.

    A non-empty blob is appended as a `buffer` over the source string (no
    copy); an empty blob is appended as ''.

    Raises ValueError when the string ends before the announced length.
    """
    result, s, _start = dectup
    size, body = unpackcount_off(dectup[1], dectup[2])
    end = body + size
    if end > len(s):
        raise ValueError("The string passed in is too short to unpack.")
    if size <= 0:
        result.append('')
        return result, s, body
    view = buffer(s, body, size)
    result.append(view)
    return result, s, end

def string_unpacker(dectup):
    """Decode a length-prefixed UTF-8 string and append it to the results.

    A non-empty string is appended as a `unicode` object; an empty one is
    appended as ''.

    Raises ValueError when the string ends before the announced length.
    """
    result, s, _start = dectup
    size, body = unpackcount_off(dectup[1], dectup[2])
    end = body + size
    if end > len(s):
        raise ValueError("The string passed in is too short to unpack.")
    if size <= 0:
        result.append('')
        return result, s, body
    encoded = buffer(s, body, size)
    result.append(unicode(encoded, 'utf-8'))
    return result, s, end

def bool_unpacker(dectup):
    """Decode a one-byte boolean and append it to the results.

    Any non-zero byte is True; a zero byte is False.

    Raises ValueError when no byte is left at the current offset.
    """
    result, s, offset = dectup
    if not (offset < len(s)):
        raise ValueError("The string passed in is too short for type spec.")
    flag = ord(s[offset]) != 0
    result.append(flag)
    return result, s, offset + 1

def pystruct_unpacker(spec, len, dectup):
    """Decode one fixed-size value with the `struct` module and append it.

    Args:
        spec: a `struct` format string describing exactly one value.
        len: size in bytes of the encoded value.  (The name shadows the
            builtin; it is kept so existing callers are unaffected.)
        dectup: ``(result_list, string, offset)`` triple.

    Returns:
        The triple with the decoded value appended to `result_list` and
        the offset advanced by `len`.

    Raises:
        struct.error: fewer bytes than `spec` needs remain at `offset`.
    """
    result, s, offset = dectup
    # unpack_from reads in place at `offset`, so no slice or buffer object
    # has to be created.
    (upresult,) = _unpack_from(spec, s, offset)
    result.append(upresult)
    return result, s, offset + len

# Fixed-size big-endian unpackers, one per numeric type tag.
byte_unpacker = _partial(pystruct_unpacker, '>b', 1)
# Unsigned byte needs 'B'; 'b' would decode 0x80..0xFF as negative numbers.
ubyte_unpacker = _partial(pystruct_unpacker, '>B', 1)
short_unpacker = _partial(pystruct_unpacker, '>h', 2)
ushort_unpacker = _partial(pystruct_unpacker, '>H', 2)
int_unpacker = _partial(pystruct_unpacker, '>i', 4)
uint_unpacker = _partial(pystruct_unpacker, '>I', 4)
long_unpacker = _partial(pystruct_unpacker, '>q', 8)
ulong_unpacker = _partial(pystruct_unpacker, '>Q', 8)
float_unpacker = _partial(pystruct_unpacker, '>f', 4)
double_unpacker = _partial(pystruct_unpacker, '>d', 8)

def variant_unpacker(dectup):
    """Decode a variant: a value that carries its own type tag inline.

    The bytes at the current offset are handed to `unpack_off`, which
    reads one type tag followed by the value for that tag.  The decoded
    value is appended to the results and the offset is advanced by the
    number of bytes `unpack_off` reports as used.
    """
    result, s, offset = dectup
    decoded, used = unpack_off(buffer(s, offset), 1)
    (inner,) = decoded
    result.append(inner)
    return result, s, offset + used

def make_tuple_unpacker(s):
    """Parse a tuple type spec ``(<types...>)`` and build its unpacker.

    Args:
        s: the type spec text immediately after the 't' tag.

    Returns:
        A pair ``(unpacker, consumed)``.  `unpacker` takes a
        ``(result_list, string, offset)`` triple, decodes one value per
        element type and appends them to `result_list` as a single tuple.
        `consumed` is the number of bytes of `s` used by the spec,
        including both parentheses.

    Raises:
        ValueError: `s` is empty, does not start with '(' or has no
            closing ')'.
    """
    if len(s) < 1 or s[0] != '(':
        raise ValueError("Tuple type tag ('t') must be followed by '('")
    offset = 1
    element_unpackers = []
    while (offset < len(s)) and (s[offset] != ')'):
        elunpacker, consumed = create_unpacker(buffer(s, offset))
        offset += consumed
        element_unpackers.append(elunpacker)
    if offset >= len(s):
        raise ValueError("Ran off the end of the string to unpack while parsing tuple")

    def unpack_tuple(dectup):
        result, data, pos = dectup
        # Decode the elements into a fresh list.  Decoding straight into
        # `result` would fold values decoded earlier (a preceding tag, a
        # dictionary key, an enclosing tuple's elements) into this tuple.
        state = ([], data, pos)
        for elunpacker in element_unpackers:
            state = elunpacker(state)
        elements, data, pos = state
        result.append(tuple(elements))
        return result, data, pos

    return unpack_tuple, offset + 1

def make_array_unpacker(s):
    """Parse an array type spec ``[<type>]`` and build its unpacker.

    Args:
        s: the type spec text immediately after the 'a' tag.

    Returns:
        A pair ``(unpacker, consumed)``.  `unpacker` takes a
        ``(result_list, string, offset)`` triple and appends the decoded
        list to `result_list`.  `consumed` is the number of bytes of `s`
        used by the spec, including both brackets.

    Raises:
        ValueError: `s` is empty, does not start with '[' or does not
            close with ']' after exactly one element type.
    """
    # The length check turns an empty spec into the same ValueError as any
    # other malformed spec, rather than an IndexError.
    if len(s) < 1 or s[0] != '[':
        raise ValueError("Array type tag ('a') must be followed by '['")
    offset = 1
    elunpacker, consumed = create_unpacker(buffer(s, offset))
    offset += consumed
    if (offset >= len(s)) or (s[offset] != ']'):
        raise ValueError("Array type tag must end with ']' after one type tag.")

    def unpack_array(dectup):
        result, s, offset = dectup
        array = []
        if len(s) <= offset:
            raise ValueError("End of string while parsing array.")
        # Every element is preceded by a non-zero marker byte; a zero byte
        # ends the array.
        while s[offset] != '\0':
            out, s, offset = elunpacker(([], s, offset + 1))
            if len(s) <= offset:
                raise ValueError("End of string while parsing array.")
            array.extend(out)
        result.append(array)
        # Skip the terminating zero byte.
        return result, s, offset + 1

    return unpack_array, offset + 1

def make_dict_unpacker(s):
    """Parse a dictionary type spec ``{<keytype><valuetype>}`` and build its unpacker.

    Args:
        s: the type spec text immediately after the 'd' tag.

    Returns:
        A pair ``(unpacker, consumed)``.  `unpacker` takes a
        ``(result_list, string, offset)`` triple and appends the decoded
        dict to `result_list`.  `consumed` is the number of bytes of `s`
        used by the spec, including both braces.

    Raises:
        ValueError: `s` is empty, does not start with '{', lacks a second
            subtype, or does not close with '}' after two subtypes.
    """
    # The length check turns an empty spec into the same ValueError as any
    # other malformed spec, rather than an IndexError.
    if len(s) < 1 or s[0] != '{':
        raise ValueError("Dictionary type tag ('d') must be followed by '{'")
    offset = 1
    keyunpacker, consumed = create_unpacker(buffer(s, offset))
    offset += consumed
    if offset >= len(s):
        raise ValueError("Dictionary type tags must contain two subtypes.")
    valunpacker, consumed = create_unpacker(buffer(s, offset))
    offset += consumed
    if (offset >= len(s)) or (s[offset] != '}'):
        # The leftover debugging print to stdout was removed here; library
        # code should only report the problem through the exception.
        raise ValueError("Dictionary type tag must end with '}' after two type tag.")

    def unpack_dict(dectup):
        result, s, offset = dectup
        dictionary = {}
        if len(s) <= offset:
            raise ValueError("End of string while parsing dictionary.")
        # Every key/value pair is preceded by a non-zero marker byte; a
        # zero byte ends the dictionary.
        while s[offset] != '\0':
            out, s, offset = valunpacker(keyunpacker(([], s, offset + 1)))
            if len(s) <= offset:
                raise ValueError("End of string while parsing dictionary.")
            dictionary[out[0]] = out[1]
        result.append(dictionary)
        # Skip the terminating zero byte.
        return result, s, offset + 1

    return unpack_dict, offset + 1

def packl(lnum, pad = 1):
    """Pack a non-negative integer into a big-endian byte string.

    Args:
        lnum: the integer to encode; must not be negative.
        pad: the result is left-padded with zero bytes so that its length
            is a multiple of `pad`.

    Returns:
        The encoded string.  Zero is encoded as `pad` zero bytes.

    Raises:
        ValueError: `lnum` is negative.
    """
    if lnum < 0:
        # NOTE(review): the original raised `RangeError`, a name that is
        # neither a Python builtin nor defined in the visible module, so a
        # negative input ended in a NameError.
        raise ValueError("Cannot use packl to convert a negative integer "
                         "to a string.")
    # Split the number into 64-bit words, least significant word first.
    count = 0
    l = []
    while lnum > 0:
        l.append(lnum & 0xffffffffffffffff)
        count += 1
        lnum >>= 64
    if count <= 0:
        return '\0' * pad
    elif pad >= 8:
        # Whole words are kept as they are; struct's 'x' pad bytes supply
        # the zero bytes needed to reach a multiple of `pad`.
        lens = 8 * count % pad
        pad = (pad - lens) if lens != 0 else 0
        l.append('>' + 'x' * pad + 'Q' * count)
        # After the reverse the format string comes first, followed by the
        # words in most-significant-first order.
        l.reverse()
        return _pack(*l)
    else:
        l.append('>' + 'Q' * count)
        l.reverse()
        # Drop leading zero bytes, then pad back up to a multiple of `pad`.
        s = _pack(*l).lstrip('\0')
        lens = len(s)
        if (lens % pad) != 0:
            return '\0' * (pad - lens % pad) + s
        else:
            return s

def unpackl(s):
    n = 0L
    if len(s) <= 0:
        return n
    count8 = len(s) // 8
    count1 = len(s) % 8
    upfmt = '>' + 'Q' * count8 + 'B' * count1
    l = _unpack(upfmt, s)
    for val in l[0:count8]:
        n <<= 64
        n |= val
    for val in l[count8:]:
        n <<= 8
        n |= val
    return n

def packcount(count):
    """packcount(non-negative integer) -> string representing the count

The count cannot be less than 0.  The encoding is designed so that if
the count is a message length, a very small percentage of the total
message will be consumed by the count.

Encoding used:
  * counts below 223 take one byte holding the count itself;
  * counts below 223 + 8192 take two bytes: 223 plus the upper five
    bits of (count - 223), then the lower eight bits;
  * larger counts take the byte 255, then a byte holding the length of
    the count in 16-bit units, then the count itself as produced by
    packl(count, 2).

It is suggested that the message length itself not be counted as part
of the message length.  This means that counts of 0 should be
sensible, and indicate a message with no content.

This function should NOT be used for values that are really arbitrary
sized integers.  The maximum sized integer that this function can
represent is 4080 bits.  Public key values will most likely exceed
this.  To store a public key (or other arbitrary sized integer) value,
encode that value in some other way (perhaps using packl), then use
this function to encode the length of that encoding.

A ValueError is raised for a negative count or one that needs more
than 4080 bits."""
    if count < 0:
        # NOTE(review): the original raised `RangeError`, a name that is
        # neither a Python builtin nor defined in the visible module, so
        # this path ended in a NameError.
        raise ValueError("A count cannot be negative!")
    if count < 223:
        return chr(count)
    elif count < (223 + 8192):
        upper, lower = divmod(count - 223, 256)
        assert upper < 32
        return chr(upper + 223) + chr(lower)
    else:
        s = packl(count, 2)
        assert len(s) % 2 == 0
        countlen = len(s) // 2
        if countlen > 255:
            raise ValueError("A count must take fewer than 4080 bits to "
                             "represent!")
        return chr(255) + chr(countlen) + s

def unpackcount_off(s, off = 0):
    """unpackcount(a string, offset in string of something created by packcount)
   -> a tuple
     the tuple is (count, offset of byte after count).

See the documentation for packcount for a little more of an
explanation.  A ValueError exception will be thrown for strings that
don't hold a valid count value at the given offset, including strings
that end at or before that offset."""
    available = len(s) - off
    if available < 1:
        # Without this check ord(s[off]) would raise an IndexError rather
        # than the documented ValueError.
        raise ValueError("The passed in string is too short, and isn't "
                         "a valid count.")
    fc = ord(s[off])
    if fc < 223:
        # One byte count
        return (fc, off + 1)

    # Both remaining forms need a second octet.
    if available < 2:
        raise ValueError("The passed in string is too short, and isn't "
                         "a valid count.")
    second = ord(s[off + 1])
    if fc < (223 + 32):
        # Two byte count: upper five bits in the first octet, lower eight
        # bits in the second, offset by 223.
        upper_5bits = fc - 223
        count = ((upper_5bits << 8) | second) + 223
        return (count, off + 2)

    # Variable length count, length (in 16-bit units) in second octet
    length = second * 2
    if length <= 0:
        raise ValueError("The passed in string isn't a valid count!")
    # Skip past length of count itself
    off += 2
    # Enough octets for rest of count?
    if length > (len(s) - off):
        raise ValueError("The passed in string is too short, and isn't "
                         "a valid count.")
    # Use the handy unpackl function to turn the octet string into a number
    count = unpackl(s[off:(off + length)])
    # If count fits into an int, turn it into an int
    if count < 2**31:
        count = int(count)
    return (count, off + length)

def unpackcount(s):
    """unpackcount(a string beginning with something created by packcount) -> a tuple
    the tuple is (count, remainder of s).

The count is decoded from the very start of s with unpackcount_off, and
everything after the encoded count is returned as the remainder.  See
the documentation for packcount for the encoding.  A ValueError
exception will be thrown for strings that don't start with a valid
count value."""
    decoded = unpackcount_off(s, 0)
    count = decoded[0]
    remainder = s[decoded[1]:]
    return (count, remainder)

