Hackbright Code Challenges

Binary Search Tree Validator: Solution







Recursion, Trees



Consider our valid tree:

          4
        /   \
       2     6
      / \   / \
     1   3 5   7

All of the left-hand descendants of 4 must be less than (or equal to) 4. All of the right-hand descendants of 4 must be greater than (or equal to) 4.

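This constraint carries down the tree. The children of 2 must be checked against 2, but they must also be less than (or equal to) 4. The children of 6 must be checked against 6, but they must also be greater than (or equal to) 4. Comparing each node only with its parent is not enough.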
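To see why the bounds have to travel down more than one level, here is a sketch of a deliberately naive check that only compares each node with its own children. The minimal `Node` class (with `data`, `left`, `right`) is an assumption, since the solution elides the constructor, and `is_valid_naive` is a hypothetical name used only for illustration.

```python
class Node:
    """Minimal BST node (assumed shape: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def is_valid_naive(node):
    """Deliberately WRONG: only compares a node with its direct children."""
    if node is None:
        return True
    if node.left is not None and node.left.data > node.data:
        return False
    if node.right is not None and node.right.data < node.data:
        return False
    return is_valid_naive(node.left) and is_valid_naive(node.right)


# 5 sits in the left subtree of 4, so this is NOT a valid BST,
# but every parent/child pair looks fine on its own.
bad = Node(4, Node(2, Node(1), Node(5)), Node(6))
print(is_valid_naive(bad))  # True, even though the tree is invalid
```

The value 5 is larger than 4, so it cannot appear anywhere in 4's left subtree. Only a check that carries 4 down as an upper bound will catch it.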

We can solve this by keeping track of two variables:

lt: a value the current node (and everything below it) must be less than (or equal to)

gt: a value the current node (and everything below it) must be greater than (or equal to)

As you go down the levels, whenever you follow a left child, lt becomes the data of the parent node you’re traveling left from. When you follow a right child, gt becomes the data of the parent node you’re traveling right from. The other bound is passed down unchanged. At each node, you compare its data against these bounds to make sure it fits.
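A small sketch that records the `(data, lt, gt)` bounds each node is checked against may make this propagation concrete. `trace_bounds` is a hypothetical helper written for illustration, and the minimal `Node` class is assumed.

```python
class Node:
    """Minimal BST node (assumed shape: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def trace_bounds(node, lt=None, gt=None, out=None):
    """Return (data, lt, gt) for each node, visiting node, left, then right.

    Going left, the parent's data becomes the new `lt` (upper bound);
    going right, it becomes the new `gt` (lower bound). None = no bound.
    """
    if out is None:
        out = []
    if node is not None:
        out.append((node.data, lt, gt))
        trace_bounds(node.left, node.data, gt, out)
        trace_bounds(node.right, lt, node.data, out)
    return out


tree = Node(4, Node(2, Node(1), Node(3)), Node(6, Node(5), Node(7)))
print(trace_bounds(tree)[3])  # (3, 4, 2)
```

Node 3 shows the point of the previous paragraphs: it must be at least 2 (its parent) and also at most 4 (its grandparent).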

A straightforward version is:

class Node(object): # ...
    def is_valid(self):
        """Is this tree a valid BST?"""


        def _ok(n, lt, gt):
            """Check this node & recurse to children

                lt: this node's data must be <= lt (None: no bound)
                gt: this node's data must be >= gt (None: no bound)
            """

            if n is None:
                # base case: this isn't a node
                return True

            if lt is not None and n.data > lt:
                # base case: bigger than allowed
                #  we'll fail fast here
                return False

            if gt is not None and n.data < gt:
                # base case: smaller than allowed
                #  we'll fail fast here
                return False

            if not _ok(n.left, n.data, gt):
                # general case: check our left child
                #   all descendants of left child must be
                #   less than our data (and greater than
                #   whatever we had to be greater than).
                #   if not, fail fast.
                return False

            if not _ok(n.right, lt, n.data):
                # general case: check our right child
                #   all descendants of right child must be
                #   greater than our data (and less than
                #   whatever we had to be less than)
                #   if not, fail fast.
                return False

            # If we reach here, we're either a leaf node with
            # valid data for lt/gt, or we're higher up, but
            # our recursive calls downward succeeded. Either way,
            # this is our winning base case.
            return True

        # Call our recursive function, starting here.
        # Since we haven't yet gone left or right, we don't know
        # our `lt` or `gt` values yet, so pass None for these.

        return _ok(self, None, None)
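CPython's default recursion limit is 1000 frames, which a very deep, unbalanced tree can exceed. As an alternative (not part of the original solution), the same bounds-passing check can be driven by an explicit stack instead of recursion. This is a sketch; `is_valid_iterative` and the minimal `Node` class are assumptions.

```python
class Node:
    """Minimal BST node (assumed shape: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def is_valid_iterative(root):
    """Bounds-passing BST check using an explicit stack (no recursion)."""
    # Each entry is (node, lt, gt): the node's data must be <= lt
    # and >= gt; None means "no bound on this side".
    stack = [(root, None, None)]
    while stack:
        n, lt, gt = stack.pop()
        if n is None:
            continue
        if lt is not None and n.data > lt:
            return False  # bigger than allowed: fail fast
        if gt is not None and n.data < gt:
            return False  # smaller than allowed: fail fast
        # Going left tightens the upper bound; going right, the lower.
        stack.append((n.left, n.data, gt))
        stack.append((n.right, lt, n.data))
    return True


valid = Node(4, Node(2, Node(1), Node(3)), Node(6, Node(5), Node(7)))
invalid = Node(4, Node(2, Node(1), Node(5)), Node(6))
print(is_valid_iterative(valid), is_valid_iterative(invalid))  # True False
```

It still runs in O(n) and still fails fast; the only change is that pending work lives in a list rather than on Python's call stack.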

Using Exceptions

Sometimes it is clearer to use an exception to signal failure than to pass False back up the recursive call stack. Here’s a version that shows that technique:

class Node(object): # ...
    def is_valid_exception(self):
        """Is tree a valid BST?

        This recurses similarly to `is_valid`, but it uses
        exceptions to exit immediately when an invalid value is
        found. Using exceptions for quick control-passing can
        sometimes make for clearer code.
        """

        def _ok(n, lt, gt):
            """Check this node & recurse to children

                lt: this node's data must be <= lt (None: no bound)
                gt: this node's data must be >= gt (None: no bound)
            """

            if n is None:
                # base case: this isn't a node
                return

            if ((lt is not None and n.data > lt) or
                    (gt is not None and n.data < gt)):
                # base case: we're either smaller or bigger
                # than allowed. Raise exception to return
                # back to `is_valid_exception`
                raise ValueError

            # Check our children (see `is_valid` for comments)
            _ok(n.left, n.data, gt)
            _ok(n.right, lt, n.data)

        # Call our recursive function. If it returns normally,
        # the tree is valid; if it raises a ValueError, it's
        # invalid.
        try:
            _ok(self, None, None)
            return True

        except ValueError:
            return False
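One possible refinement: catching the built-in `ValueError` means an unrelated `ValueError` raised somewhere during the check would also be reported as "invalid tree". A dedicated exception class keeps the signal unambiguous. This is a sketch with a hypothetical `InvalidBST` exception and the minimal `Node` class assumed.

```python
class Node:
    """Minimal BST node (assumed shape: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


class InvalidBST(Exception):
    """Raised internally to unwind the recursion at the first bad node."""


def is_valid_exception(root):
    def _ok(n, lt, gt):
        if n is None:
            return
        if (lt is not None and n.data > lt) or (gt is not None and n.data < gt):
            raise InvalidBST(n.data)  # carry the offending value along
        _ok(n.left, n.data, gt)
        _ok(n.right, lt, n.data)

    try:
        _ok(root, None, None)
        return True
    except InvalidBST:
        return False


print(is_valid_exception(Node(4, Node(2, Node(1), Node(5)), Node(6))))  # False
```

Because only `InvalidBST` is caught, any other error still propagates to the caller instead of being silently turned into `False`.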

As An Expression

With some clever thinking, we can write this entire method as a single (recursive) expression — though it’s pretty hairy to follow:

class Node(object): # ...
    def is_valid_expression(self, lt=None, gt=None):
        """Is tree a valid BST?

        This uses a single expression --- the logic is the same
        as `is_valid`, but packed into an expression.

        This is a useful demonstration of how powerful logical
        expressions can be, but it's probably a terrible way to
        write this.
        """

        return (
            not (lt is not None and self.data > lt) and
            not (gt is not None and self.data < gt) and
            (self.left is None or
             self.left.is_valid_expression(self.data, gt)) and
            (self.right is None or
             self.right.is_valid_expression(lt, self.data))
        )
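Here is a standalone sketch of the same single-expression idea as a plain function, so it can be tried without the rest of the class. The name `valid_expr` and the minimal `Node` class are assumptions. Applying De Morgan's law, `not (lt is not None and data > lt)` becomes `lt is None or data <= lt`, and handling a `None` node up front removes the `left is None` / `right is None` guards. Like `is_valid`, it allows values equal to a bound.

```python
class Node:
    """Minimal BST node (assumed shape: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def valid_expr(n, lt=None, gt=None):
    """Single-expression BST check; None means "no bound"."""
    return n is None or (
        (lt is None or n.data <= lt)
        and (gt is None or n.data >= gt)
        and valid_expr(n.left, n.data, gt)
        and valid_expr(n.right, lt, n.data)
    )


print(valid_expr(Node(2, Node(2), Node(2))))  # True: equal values allowed
```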

Iteration Solutions

Another way to think about the problem is: “If I gather the node values in the order an in-order traversal visits them, is that list sorted?”

We can write a method that will let us loop over the tree using Python for-loops and list comprehensions. To do this, we’ll create a special method, __iter__.

class Node(object): # ...
    def __iter__(self):
        """Iterate over nodes in BST in proper order.

        The __iter__ method is called when you iterate
        over an object. It should yield successive
        values (for information on yielding, learn about
        generators).

        Our BST can be iterated over to get the values
        in order. For example, for this tree::

                4
             2     6
            1 3   5 7

        We can loop over it::

            >>> t = Node(4,
            ...       Node(2, Node(1), Node(3)),
            ...       Node(6, Node(5), Node(7))
            ... )

            >>> for n in t:
            ...     print(n.data, end=' ')
            1 2 3 4 5 6 7 

        This method of navigating a BST by left-recurse, self,
        right-recurse, is often called "in-order traversal".
        """

        # walk the left descendants recursively:
        for n in self.left or []:
            yield n

        # hand back this node
        yield self

        # walk the right descendants recursively:
        for n in self.right or []:
            yield n

(This requires an understanding of generators and the yield keyword.)
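In Python 3.3 and later, the inner `for` loops can be written with `yield from`, which delegates to the child node's own iterator. Here is a sketch of the same in-order `__iter__` in that style, with a minimal constructor assumed.

```python
class Node:
    """Minimal BST node (assumed shape: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right

    def __iter__(self):
        """Yield nodes in order: left subtree, this node, right subtree."""
        if self.left is not None:
            yield from self.left  # calls iter(self.left), i.e. its __iter__
        yield self
        if self.right is not None:
            yield from self.right


t = Node(4, Node(2, Node(1), Node(3)), Node(6, Node(5), Node(7)))
print([n.data for n in t])  # [1, 2, 3, 4, 5, 6, 7]
gen = iter(t)
print(next(gen).data)  # 1: nodes are produced lazily, one at a time
```

Because `__iter__` is a generator function, nothing is visited until a caller asks for the next node; this laziness is what lets a loop over the tree stop early.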

Now we can loop over a node and get it and all of its descendants in traversal order. We can compare the list of their data to a sorted copy of that list to see if they’re the same:

class Node(object): # ...
    def is_valid_using_iter_sort(self):
        """Is tree a valid BST?

        Compare the iteration order with the sort order; if
        they're different, it's not a valid tree.

        This method of checking for validity isn't nearly as
        efficient --- we have to walk the tree (O(n))
        and then sort the nodes (O(n log n)). Our runtime is
        therefore O(n log n), which is greater than O(n) for
        the other methods. We also can't fail-fast, unlike the
        other methods --- they quit as soon as they find
        an invalid value, whereas this iterates over the
        entire tree.

        It is a good example of how having an __iter__ method
        can be useful, though.
        """

        # Get node data in traversal-order
        values = [n.data for n in self]

        return values == sorted(values)

This solution is clever and short — but as noted in the comments, the runtime becomes O(n log n) because of the sort.
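To make the comparison concrete, here is what the traversal-order values look like for an invalid tree. This is a standalone sketch; the generator `in_order_data` and the minimal `Node` class are assumptions, not part of the original.

```python
class Node:
    """Minimal BST node (assumed shape: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def in_order_data(n):
    """Yield node data in in-order (left, self, right) order."""
    if n is not None:
        yield from in_order_data(n.left)
        yield n.data
        yield from in_order_data(n.right)


# 5 is in 4's left subtree, so the tree is invalid.
invalid = Node(4, Node(2, Node(1), Node(5)), Node(6))
values = list(in_order_data(invalid))
print(values)                    # [1, 2, 5, 4, 6]
print(sorted(values))            # [1, 2, 4, 5, 6]
print(values == sorted(values))  # False
```

The misplaced 5 shows up as a spot where the traversal "goes backwards" (5 then 4), which is exactly what the sort comparison detects.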

We can keep the simple use of iteration but get an O(n) solution by dropping the sort and checking directly that the values come out in sorted order:

class Node(object): # ...
    def is_valid_using_iter_check(self):
        """Is tree a valid BST?

        Another way to use our __iter__ method --- this time,
        walking over the iteration, and just making sure it
        doesn't ever go backwards.

        This solution is O(n) and does let us fail fast.
        """

        last = None

        for n in self:
            if last is not None and n.data < last:
                return False
            last = n.data

        # Made it through without problems, in the right order!
        return True

        # Another possibility in similar style. This is still
        # O(n), but it no longer fails fast, since building the
        # list traverses the entire tree first:
        #
        #     values = [n.data for n in self]
        #     return all(values[i] >= values[i - 1]
        #                for i in range(1, len(values)))
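The `all(...)` one-liner loses fail-fast only because it builds a list first. One alternative that keeps the one-liner style and still stops at the first out-of-order pair is to pair each value with its successor lazily using `itertools.tee` (on Python 3.10 and later, `itertools.pairwise` does the same pairing). This is a sketch; the `is_sorted` helper, the `in_order_data` generator standing in for `__iter__`, and the minimal `Node` class are assumptions.

```python
from itertools import tee


class Node:
    """Minimal BST node (assumed shape: data, left, right)."""

    def __init__(self, data, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def in_order_data(n):
    """Lazily yield node data in in-order (left, self, right) order."""
    if n is not None:
        yield from in_order_data(n.left)
        yield n.data
        yield from in_order_data(n.right)


def is_sorted(iterable):
    """True if no item is smaller than the one before it.

    Pulls items one at a time, so it stops at the first
    out-of-order pair instead of traversing everything.
    """
    a, b = tee(iterable)
    next(b, None)  # shift b by one so zip yields (x0, x1), (x1, x2), ...
    return all(x <= y for x, y in zip(a, b))


valid = Node(4, Node(2, Node(1), Node(3)), Node(6, Node(5), Node(7)))
invalid = Node(4, Node(2, Node(1), Node(5)), Node(6))
print(is_sorted(in_order_data(valid)), is_sorted(in_order_data(invalid)))  # True False
```

Since generators, `tee`, `zip`, and `all` are all lazy, traversal of the invalid tree stops as soon as the pair (5, 4) is seen.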