import gdb
import gdb.printing

class SharedPointerPrinter:
    """Print a shared_ptr or weak_ptr.

    Reads the libstdc++ members ``_M_ptr`` (the stored pointer) and
    ``_M_refcount._M_pi`` (the control block pointer; 0 when empty) of *val*.
    """

    class _iterator:
        """Iterate over the (label, value) children of a shared/weak pointer.

        Yields 'Use count' and 'Weak count' (only when a control block is
        present) followed by 'Managed value'.  Yields nothing when the stored
        pointer is null.
        """

        def __init__(self, sharedPointer):
            self.sharedPointer = sharedPointer
            self.managedValue = sharedPointer.val['_M_ptr']
            self.count = 0
            self._children = self._collect()

        def _collect(self):
            """Return the list of (label, value) pairs this iterator yields."""
            if self.managedValue == 0:
                return []
            children = []
            refcounts = self.sharedPointer.val['_M_refcount']['_M_pi']
            # The stored pointer can be non-null while the control block
            # pointer is null; reading the counts through it would then
            # dereference a null pointer, so only show them when present.
            if refcounts != 0:
                usecount = refcounts['_M_use_count']
                weakcount = refcounts['_M_weak_count']
                # Keep in step with to_string(): it reports weakcount - 1
                # while usecount is non-zero and the raw weakcount once the
                # use count has dropped to zero.
                if usecount != 0:
                    weakcount = weakcount - 1
                children.append(('Use count', usecount))
                children.append(('Weak count', weakcount))
            children.append(('Managed value', self.managedValue))
            return children

        def __iter__(self):
            return self

        def __next__(self):
            if self.count >= len(self._children):
                raise StopIteration
            child = self._children[self.count]
            self.count += 1
            return child

        # Python 2 spelling of the iterator protocol, for older GDB builds.
        next = __next__

    def __init__(self, typename, val):
        """Store the display *typename* and the gdb.Value *val* to print."""
        self.typename = typename
        self.val = val

    def children(self):
        """Return an iterator over (label, value) child pairs."""
        return self._iterator(self)

    def to_string(self):
        """Return '<typename> (<state>) to <stored pointer>'.

        state is 'empty' when there is no control block,
        'expired, weakcount N' when the use count is 0, and otherwise
        'usecount U, weakcount W' with W = _M_weak_count - 1.
        """
        state = 'empty'
        refcounts = self.val['_M_refcount']['_M_pi']
        if refcounts != 0:
            usecount = refcounts['_M_use_count']
            weakcount = refcounts['_M_weak_count']
            if usecount == 0:
                state = 'expired, weakcount %d' % weakcount
            else:
                state = 'usecount %d, weakcount %d' % (usecount, weakcount - 1)
        return '%s (%s) to %s' % (self.typename, state, self.val['_M_ptr'])

class UniquePointerPrinter:
    """Print a unique_ptr.

    The stored pointer is read from libstdc++'s ``_M_t`` member; see
    _managed_pointer() for the layouts that are understood.
    """

    class _iterator:
        """Yield a single 'Managed value' child unless the pointer is null."""

        def __init__(self, uniquePointer):
            self.uniquePointer = uniquePointer
            self.managedValue = uniquePointer._managed_pointer(uniquePointer.val)
            self.count = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self.managedValue == 0 or self.count == 1:
                raise StopIteration
            self.count = self.count + 1
            return ('Managed value', self.managedValue)

        # Python 2 spelling of the iterator protocol, for older GDB builds.
        next = __next__

    def __init__(self, typename, val):
        """Store the display *typename* and the gdb.Value *val* to print."""
        self.typename = typename
        self.val = val

    @staticmethod
    def _managed_pointer(val):
        """Return the raw pointer held by the unique_ptr value *val*.

        Older libstdc++ stores a std::tuple<pointer, deleter> directly in
        ``_M_t``.  Newer releases wrap that tuple in a helper object whose own
        ``_M_t`` member is the tuple.  Both layouts are tried.
        """
        try:
            return val['_M_t']['_M_head_impl']
        except gdb.error:
            # No such member: assume the wrapped layout one level down.
            return val['_M_t']['_M_t']['_M_head_impl']

    def children(self):
        """Return an iterator over (label, value) child pairs."""
        return self._iterator(self)

    def to_string(self):
        """Return 'std::unique_ptr<pointee type> containing <pointer>'."""
        v = self._managed_pointer(self.val)
        return ('std::unique_ptr<%s> containing %s' % (str(v.type.target()),
                                                       str(v)))

def mk_pretty_printers():
    """Build the "myprinters" collection for libstdc++ smart pointers.

    Maps std::shared_ptr and std::weak_ptr to SharedPointerPrinter and
    std::unique_ptr to UniquePointerPrinter.  The collection is only built
    and returned here; installing it in GDB is left to the caller.

    Returns:
        gdb.printing.RegexpCollectionPrettyPrinter: the populated collection.
    """

    def with_typename(printer_cls, typename):
        # RegexpCollectionPrettyPrinter instantiates a printer as
        # gen_printer(val), but these printer classes take (typename, val).
        # Bind the display name here so the lookup does not raise TypeError.
        return lambda val: printer_cls(typename, val)

    pp = gdb.printing.RegexpCollectionPrettyPrinter("myprinters")
    pp.add_printer('shared_ptr', '^std::shared_ptr',
                   with_typename(SharedPointerPrinter, 'std::shared_ptr'))
    # SharedPointerPrinter is written for weak_ptr as well (see its
    # 'expired' state), but weak_ptr was never registered.
    pp.add_printer('weak_ptr', '^std::weak_ptr',
                   with_typename(SharedPointerPrinter, 'std::weak_ptr'))
    pp.add_printer('unique_ptr', '^std::unique_ptr',
                   with_typename(UniquePointerPrinter, 'std::unique_ptr'))
    return pp