This is xnu-11215.1.10. See this file in:
"""
XNU Collection iterators
"""
from .cvalue import gettype
from .standard import xnu_format

from abc import ABCMeta, abstractmethod, abstractproperty


#
# Note to implementers
# ~~~~~~~~~~~~~~~~~~~~
#
# Iterators must be named iter_* in accordance with typical python API naming
#
# The "root" or "head" of the collection must be passed by reference
# and not by pointer.
#
# Returned elements must be references to the element type,
# and not pointers.
#
# Intrusive data types should ask for an "element type" (as an SBType),
# and a field name or path, and use SBType.xContainerOfTransform
# to cache the transform to apply for the entire iteration.
#


def iter_linked_list(head_value, next_field_or_path,
        first_field_or_path = None):
    """
    Iterates a NULL-terminated linked list.

        Assuming these C types:

            struct container {
                struct node *list1_head;
                struct node *list2_head;
            }

            struct node {
                struct node *next;
            }

        and "v" is an SBValue to a `struct container` type,
        enumerating list1 is:

            iter_linked_list(v, 'next', 'list1_head')

        if "v" is a `struct node *` directly, then the enumeration
        becomes:

            iter_linked_list(v, 'next')


    @param head_value (SBValue)
        a reference to the list head.

    @param next_field_or_path (str)
        The name of (or path to if starting with '.')
        the field linking to the next element.

    @param first_field_or_path (str or None)
        The name of (or path to if starting with '.')
        the field from @c head_value holding the pointer
        to the first element of the list.
    """

    if first_field_or_path is None:
        elt = head_value
    elif first_field_or_path[0] == '.':
        elt = head_value.chkGetValueForExpressionPath(first_field_or_path)
    else:
        elt = head_value.chkGetChildMemberWithName(first_field_or_path)

    if next_field_or_path[0] == '.':
        while elt.GetValueAsAddress():
            elt = elt.Dereference()
            yield elt
            elt = elt.chkGetValueForExpressionPath(next_field_or_path)
    else:
        while elt.GetValueAsAddress():
            elt = elt.Dereference()
            yield elt
            elt = elt.chkGetChildMemberWithName(next_field_or_path)


def iter_queue_entries(head_value, elt_type, field_name_or_path, backwards=False):
    """
    Iterate over a queue of entries (<kern/queue.h> method 1)

    @param head_value (SBValue)
        a reference to the list head (queue_head_t)

    @param elt_type (SBType)
        The type of the elements on the chain

    @param field_name_or_path (str)
        The name of (or path to if starting with '.') the field
        containing the struct mpsc_queue_chain linkage.

    @param backwards (bool)
         Whether the walk is forward or backwards
    """

    stop = head_value.GetLoadAddress()
    elt  = head_value
    key  = 'prev' if backwards else 'next'

    transform = elt_type.xContainerOfTransform(field_name_or_path)

    while True:
        elt  = elt.chkGetChildMemberWithName(key)
        addr = elt.GetValueAsAddress()
        if addr == 0 or addr == stop:
            break
        elt = elt.Dereference()
        yield transform(elt)


def iter_queue(head_value, elt_type, field_name_or_path, backwards=False, unpack=None):
    """
    Iterate over queue of elements (<kern/queue.h> method 2)

    @param head_value (SBValue)
        A reference to the list head (queue_head_t)

    @param elt_type (SBType)
        The type of the elements on the chain

    @param field_name_or_path (str)
        The name of (or path to if starting with '.') the field
        containing the struct mpsc_queue_chain linkage.

    @param backwards (bool)
         Whether the walk is forward or backwards

    @param unpack (Function or None)
        A function to unpack the pointer or None.
    """

    stop = head_value.GetLoadAddress()
    elt  = head_value
    key  = '.prev' if backwards else '.next'
    addr = elt.xGetScalarByPath(key)
    if field_name_or_path[0] == '.':
        path = field_name_or_path + key
    else:
        path = "." + field_name_or_path + key

    while True:
        if unpack is not None:
            addr = unpack(addr)
        if addr == 0 or addr == stop:
            break

        elt  = elt.xCreateValueFromAddress('element', addr, elt_type)
        addr = elt.xGetScalarByPath(path)
        yield elt


def iter_circle_queue(head_value, elt_type, field_name_or_path, backwards=False):
    """
    Iterate over a queue of entries (<kern/circle_queue.h>)

    @param head_value (SBValue)
        a reference to the list head (circle_queue_head_t)

    @param elt_type (SBType)
        The type of the elements on the chain

    @param field_name_or_path (str)
        The name of (or path to if starting with '.') the field
        containing the struct mpsc_queue_chain linkage.

    @param backwards (bool)
         Whether the walk is forward or backwards
    """

    elt  = head_value.chkGetChildMemberWithName('head')
    stop = elt.GetValueAsAddress()
    key  = 'prev' if backwards else 'next'

    transform = elt_type.xContainerOfTransform(field_name_or_path)

    if stop:
        if backwards:
            elt  = elt.chkGetValueForExpressionPath('->prev')
            stop = elt.GetValueAsAddress()

        while True:
            elt = elt.Dereference()
            yield transform(elt)

            elt  = elt.chkGetChildMemberWithName(key)
            addr = elt.GetValueAsAddress()
            if addr == 0 or addr == stop:
                break


def iter_mpsc_queue(head_value, elt_type, field_name_or_path):
    """
    Iterates a struct mpsc_queue_head

    @param head_value (SBValue)
        A struct mpsc_queue_head value.

    @param elt_type (SBType)
        The type of the elements on the chain.

    @param field_name_or_path (str)
        The name of (or path to if starting with '.') the field
        containing the struct mpsc_queue_chain linkage.

    @returns (generator)
        An iterator for the MPSC queue.
    """

    transform = elt_type.xContainerOfTransform(field_name_or_path)

    return (transform(e) for e in iter_linked_list(
        head_value, 'mpqc_next', '.mpqh_head.mpqc_next'
    ))


class iter_priority_queue(object):
    """
    Iterates any of the priority queues from <kern/priority_queue.h>

    @param head_value (SBValue)
        A struct priority_queue* value.

    @param elt_type (SBType)
        The type of the elements on the chain.

    @param field_name_or_path (str)
        The name of (or path to if starting with '.') the field
        containing the struct priority_queue_entry* linkage.

    @returns (generator)
        An iterator for the priority queue.
    """

    def __init__(self, head_value, elt_type, field_name_or_path):
        self.transform = elt_type.xContainerOfTransform(field_name_or_path)

        self.gen   = iter_linked_list(head_value, 'next', 'pq_root')
        self.queue = []

    def __iter__(self):
        return self

    def __next__(self):
        queue = self.queue
        elt   = next(self.gen, None)

        if elt is None:
            if not len(queue):
                raise StopIteration
            elt = queue.pop()
            self.gen = iter_linked_list(elt, 'next', 'next')

        addr = elt.xGetScalarByName('child')
        if addr:
            queue.append(elt.xCreateValueFromAddress(
                None, addr & 0xffffffffffffffff, elt.GetType()))

        return self.transform(elt)

def iter_SLIST_HEAD(head_value, link_field):
    """ Specialized version of iter_linked_list() for <sys/queue.h> SLIST_HEAD """

    next_path = ".{}.sle_next".format(link_field)
    return (e for e in iter_linked_list(head_value, next_path, "slh_first"))

def iter_LIST_HEAD(head_value, link_field):
    """ Specialized version of iter_linked_list() for <sys/queue.h> LIST_HEAD """

    next_path = ".{}.le_next".format(link_field)
    return (e for e in iter_linked_list(head_value, next_path, "lh_first"))

def iter_STAILQ_HEAD(head_value, link_field):
    """ Specialized version of iter_linked_list() for <sys/queue.h> STAILQ_HEAD """

    next_path = ".{}.stqe_next".format(link_field)
    return (e for e in iter_linked_list(head_value, next_path, "stqh_first"))

def iter_TAILQ_HEAD(head_value, link_field):
    """ Specialized version of iter_linked_list() for <sys/queue.h> TAILQ_HEAD """

    next_path = ".{}.tqe_next".format(link_field)
    return (e for e in iter_linked_list(head_value, next_path, "tqh_first"))

def iter_SYS_QUEUE_HEAD(head_value, link_field):
    """ Specialized version of iter_linked_list() for any <sys/queue> *_HEAD """

    field  = head_value.rawGetType().GetFieldAtIndex(0).GetName()
    next_path = ".{}.{}e_next".format(link_field, field.partition('_')[0])
    return (e for e in iter_linked_list(head_value, next_path, field))


class RB_HEAD(object):
    """
    Class providing utilities to manipulate a collection made with RB_HEAD()
    """

    def __init__(self, rbh_root_value, link_field, cmp):
        """
        Create an rb-tree wrapper

        @param rbh_root_value (SBValue)
            A value to an RB_HEAD() field

        @param link_field (str)
            The name of the RB_ENTRY() field in the elements

        @param cmp (function)
            The comparison function between a linked element,
            and a key
        """

        self.root_sbv = rbh_root_value.chkGetChildMemberWithName('rbh_root')
        self.field    = link_field
        self.cmp      = cmp
        self.etype    = self.root_sbv.GetType().GetPointeeType()

    def _parent(self, elt):
        RB_COLOR_MASK = 0x1

        path   = "." + self.field + ".rbe_parent"
        parent = elt.chkGetValueForExpressionPath(path)
        paddr  = parent.GetValueAsAddress()

        if paddr == 0:
            return None, 0

        if paddr & RB_COLOR_MASK == 0:
            return parent.Dereference(), paddr

        paddr &= ~RB_COLOR_MASK
        return parent.xCreateValueFromAddress(None, paddr, self.etype), paddr

    def _sibling(self, elt, rbe_left, rbe_right):

        field = self.field
        lpath = "." + field + rbe_left
        rpath = "." + field + rbe_right

        e_r = elt.chkGetValueForExpressionPath(rpath)
        if e_r.GetValueAsAddress():
            #
            # IF (RB_RIGHT(elm, field)) {
            #     elm = RB_RIGHT(elm, field);
            #     while (RB_LEFT(elm, field))
            #         elm = RB_LEFT(elm, field);
            #

            path = "->" + field + rbe_left
            res = e_r
            e_l = res.chkGetValueForExpressionPath(path)
            while e_l.GetValueAsAddress():
                res = e_l
                e_l = res.chkGetValueForExpressionPath(path)

            return res.Dereference()

        eaddr = elt.GetLoadAddress()
        e_p, paddr = self._parent(elt)

        if paddr:
            #
            # IF (RB_GETPARENT(elm) &&
            #     (elm == RB_LEFT(RB_GETPARENT(elm), field)))
            #         elm = RB_GETPARENT(elm)
            #

            if e_p.xGetScalarByPath(lpath) == eaddr:
                return e_p

            #
            # WHILE (RB_GETPARENT(elm) &&
            #     (elm == RB_RIGHT(RB_GETPARENT(elm), field)))
            #         elm = RB_GETPARENT(elm);
            # elm = RB_GETPARENT(elm);
            #

            while paddr:
                if e_p.xGetScalarByPath(rpath) != eaddr:
                    return e_p

                eaddr = paddr
                e_p, paddr = self._parent(e_p)

        return None

    def _find(self, key):
        elt = self.root_sbv
        f   = self.field
        le  = None
        ge  = None

        while elt.GetValueAsAddress():
            elt = elt.Dereference()
            rc  = self.cmp(elt, key)
            if rc == 0:
                return elt, elt, elt

            if rc < 0:
                ge  = elt
                elt = elt.chkGetValueForExpressionPath("->" + f + ".rbe_left")
            else:
                le  = elt
                elt = elt.chkGetValueForExpressionPath("->" + f + ".rbe_right")

        return le, None, ge

    def _minmax(self, direction):
        """ Returns the first element in the tree """

        elt  = self.root_sbv
        res  = None
        path = "->" + self.field + direction

        while elt.GetValueAsAddress():
            res = elt
            elt = elt.chkGetValueForExpressionPath(path)

        return res.Dereference() if res is not None else None

    def find_lt(self, key):
        """ Find the element smaller than the specified key """

        elt, _, _ = self._find(key)
        return self.prev(elt) if elt is not None else None

    def find_le(self, key):
        """ Find the element smaller or equal to the specified key """

        elt, _, _ = self._find(key)
        return elt

    def find(self, key):
        """ Find the element with the specified key """

        _, elt, _ = self._find(key)
        return elt

    def find_ge(self, key):
        """ Find the element greater or equal to the specified key """

        _, _, elt = self._find(key)
        return elt

    def find_gt(self, key):
        """ Find the element greater than the specified key """

        _, _, elt = self._find(key)
        return self.next(elt.Dereference()) if elt is not None else None

    def first(self):
        """ Returns the first element in the tree """

        return self._minmax(".rbe_left")

    def last(self):
        """ Returns the last element in the tree """

        return self._minmax(".rbe_right")

    def next(self, elt):
        """ Returns the next element in rbtree order or None """

        return self._sibling(elt, ".rbe_left", ".rbe_right")

    def prev(self, elt):
        """ Returns the next element in rbtree order or None """

        return self._sibling(elt, ".rbe_right", ".rbe_left")

    def iter(self, min_key=None, max_key=None):
        """
        Iterates all elements in this red-black tree
        with min_key <= key < max_key
        """

        e = self.first() if min_key is None else self.find_ge(min_key)

        if max_key is None:
            while e is not None:
                yield e
                e = self.next(e)
        else:
            cmp = self.cmp
            while e is not None and cmp(e, max_key) >= 0:
                yield e
                e = self.next(e)

    def __iter__(self):
        return self.iter()


def iter_RB_HEAD(rbh_root_value, link_field):
    """ Conveniency wrapper for RB_HEAD iteration """

    return RB_HEAD(rbh_root_value, link_field, None).iter()


def iter_smr_queue(head_value, elt_type, field_name_or_path):
    """
    Iterate over an SMR queue of entries (<kern/smr.h>)

    @param head_value (SBValue)
        a reference to the list head (struct smrq_*_head)

    @param elt_type (SBType)
        The type of the elements on the chain

    @param field_name_or_path (str)
        The name of (or path to if starting with '.') the field
        containing the struct mpsc_queue_chain linkage.
    """

    transform = elt_type.xContainerOfTransform(field_name_or_path)

    return (transform(e) for e in iter_linked_list(
        head_value, '.next.__smr_ptr', '.first.__smr_ptr'
    ))

class _Hash(object, metaclass=ABCMeta):
    @abstractproperty
    def buckets(self):
        """
        Returns the number of buckets in the hash table
        """
        pass

    @abstractproperty
    def count(self):
        """
        Returns the number of elements in the hash table
        """
        pass

    @abstractproperty
    def rehashing(self):
        """
        Returns whether the hash is currently rehashing
        """
        pass

    @abstractmethod
    def iter(self, detailed=False):
        """
        @param detailed (bool)
            whether to enumerate just elements, or show bucket info too
            when bucket info is requested, enumeration returns a tuple of:
            (0, bucket, index_in_bucket, element)
        """
        pass

    def __iter__(self):
        return self.iter()

    def describe(self):
        fmt = (
            "Hash table info\n"
            " address              : {1:#x}\n"
            " element count        : {0.count}\n"
            " bucket count         : {0.buckets}\n"
            " rehashing            : {0.rehashing}"
        )
        print(xnu_format(fmt, self, self.hash_value.GetLoadAddress()))

        if self.rehashing:
            print()
            return

        b_len = {}
        for _, b_idx, e_idx, _ in self.iter(detailed=True):
            b_len[b_idx] = e_idx + 1

        stats = {i: 0 for i in range(max(b_len.values()) + 1) }
        for v in b_len.values():
            stats[v] += 1
        stats[0] = self.buckets - len(b_len)

        fmt = (
            " histogram            :\n"
            "  {:>4}  {:>6}  {:>6}  {:>6}  {:>5}"
        )
        print(xnu_format(fmt, "size", "count", "(cum)", "%", "(cum)"))

        fmt = "  {:>4,d}  {:>6,d}  {:>6,d}  {:>6.1%}  {:>5.0%}"
        tot = 0
        for sz, n in stats.items():
            tot += n
            print(xnu_format(fmt, sz, n, tot, n / self.buckets, tot / self.buckets))


class SMRHash(_Hash):
    """
    Class providing utilities to manipulate SMR hash tables
    """

    def __init__(self, hash_value, traits_value):
        """
        Create an smr hash table iterator

        @param hash_value (SBValue)
            a reference to a struct smr_hash instance.

        @param traits_value (SBValue)
            a reference to the traits for this hash table
        """
        super().__init__()
        self.hash_value = hash_value
        self.traits_value = traits_value

    @property
    def buckets(self):
        hash_arr = self.hash_value.xGetScalarByName('smrh_array')
        return 1 << (64 - ((hash_arr >> 48) & 0xffff))

    @property
    def count(self):
        return self.hash_value.xGetScalarByName('smrh_count')

    @property
    def rehashing(self):
        return self.hash_value.xGetScalarByName('smrh_resizing')

    def iter(self, detailed=False):
        obj_null = self.traits_value.chkGetChildMemberWithName('smrht_obj_type')
        obj_ty   = obj_null.GetType().GetArrayElementType().GetPointeeType()
        lnk_offs = self.traits_value.xGetScalarByPath('.smrht.link_offset')
        hash_arr = self.hash_value.xGetScalarByName('smrh_array')
        hash_sz  = 1 << (64 - ((hash_arr >> 48) & 0xffff))
        hash_arr = obj_null.xCreateValueFromAddress(None,
            hash_arr | 0xffff000000000000, gettype('struct smrq_slist_head'));

        if detailed:
            return (
                (0, head_idx, e_idx, e.xCreateValueFromAddress(None, e.GetLoadAddress() - lnk_offs, obj_ty))
                for head_idx, head in enumerate(hash_arr.xIterSiblings(0, hash_sz))
                for e_idx, e in enumerate(iter_linked_list(head, '.next.__smr_ptr', '.first.__smr_ptr'))
            )

        return (
            e.xCreateValueFromAddress(None, e.GetLoadAddress() - lnk_offs, obj_ty)
            for head in hash_arr.xIterSiblings(0, hash_sz)
            for e in iter_linked_list(head, '.next.__smr_ptr', '.first.__smr_ptr')
        )


class SMRScalableHash(_Hash):

    def __init__(self, hash_value, traits_value):
        """
        Create an smr hash table iterator

        @param hash_value (SBValue)
            a reference to a struct smr_shash instance.

        @param traits_value (SBValue)
            a reference to the traits for this hash table
        """
        super().__init__()
        self.hash_value = hash_value
        self.traits_value = traits_value

    @property
    def buckets(self):
        shift = self.hash_value.xGetScalarByPath('.smrsh_state.curshift')
        return (0xffffffff >> shift) + 1;

    @property
    def count(self):
        sbv      = self.hash_value.chkGetChildMemberWithName('smrsh_count')
        addr     = sbv.GetValueAsAddress() | 0x8000000000000000
        target   = sbv.GetTarget()
        ncpus    = target.chkFindFirstGlobalVariable('zpercpu_early_count').xGetValueAsInteger()
        pg_shift = target.chkFindFirstGlobalVariable('page_shift').xGetValueAsInteger()

        return sum(
            target.xReadInt64(addr + (cpu << pg_shift))
            for cpu in range(ncpus)
        )

    @property
    def rehashing(self):
        curidx = self.hash_value.xGetScalarByPath('.smrsh_state.curidx');
        newidx = self.hash_value.xGetScalarByPath('.smrsh_state.newidx');
        return curidx != newidx

    def iter(self, detailed=False):
        """
        @param detailed (bool)
            whether to enumerate just elements, or show bucket info too
            when bucket info is requested, enumeration returns a tuple of:
            (table_index, bucket, index_in_bucket, element)
        """

        hash_value   = self.hash_value
        traits_value = self.traits_value

        obj_null = traits_value.chkGetChildMemberWithName('smrht_obj_type')
        obj_ty   = obj_null.GetType().GetArrayElementType().GetPointeeType()
        lnk_offs = traits_value.xGetScalarByPath('.smrht.link_offset')
        hashes   = []

        curidx   = hash_value.xGetScalarByPath('.smrsh_state.curidx');
        newidx   = hash_value.xGetScalarByPath('.smrsh_state.newidx');
        arrays   = hash_value.chkGetChildMemberWithName('smrsh_array')

        array    = arrays.chkGetChildAtIndex(curidx)
        shift    = hash_value.xGetScalarByPath('.smrsh_state.curshift')
        hashes.append((curidx, array.Dereference(), shift))
        if newidx != curidx:
            array    = arrays.chkGetChildAtIndex(newidx)
            shift    = hash_value.xGetScalarByPath('.smrsh_state.newshift')
            hashes.append((newidx, array.Dereference(), shift))

        seen = set()

        def _iter_smr_shash_bucket(head):
            addr  = head.xGetScalarByName('lck_ptr_bits')
            tg    = head.GetTarget()

            while addr & 1 == 0:
                addr &= 0xfffffffffffffffe
                e = head.xCreateValueFromAddress(None, addr - lnk_offs, obj_ty)
                if addr not in seen:
                    seen.add(addr)
                    yield e
                addr = tg.xReadULong(addr);

        if detailed:
            return (
                (hash_idx, head_idx, e_idx, e)
                for hash_idx, hash_arr, hash_shift in hashes
                for head_idx, head in enumerate(hash_arr.xIterSiblings(
                    0, 1 + (0xffffffff >> hash_shift)))
                for e_idx, e in enumerate(_iter_smr_shash_bucket(head))
            )

        return (
            e
            for hash_idx, hash_arr, hash_shift in hashes
            for head in hash_arr.xIterSiblings(
                0, 1 + (0xffffffff >> hash_shift))
            for e in _iter_smr_shash_bucket(head)
        )


__all__ = [
    iter_linked_list.__name__,
    iter_queue_entries.__name__,
    iter_queue.__name__,
    iter_circle_queue.__name__,

    iter_mpsc_queue.__name__,

    iter_priority_queue.__name__,

    iter_SLIST_HEAD.__name__,
    iter_LIST_HEAD.__name__,
    iter_STAILQ_HEAD.__name__,
    iter_TAILQ_HEAD.__name__,
    iter_SYS_QUEUE_HEAD.__name__,
    iter_RB_HEAD.__name__,
    RB_HEAD.__name__,

    iter_smr_queue.__name__,
    SMRHash.__name__,
    SMRScalableHash.__name__,
]