How to implement a binary search tree in Python?

Question:

This is what I’ve got so far but it is not working:

class Node:
    rChild,lChild,data = None,None,None

    def __init__(self,key):
        self.rChild = None
        self.lChild = None
        self.data = key

class Tree:
    root,size = None,0
    def __init__(self):
        self.root = None
        self.size = 0

    def insert(self,node,someNumber):
        if node is None:
            node = Node(someNumber)
        else:
            if node.data > someNumber:
                self.insert(node.rchild,someNumber)
            else:
                self.insert(node.rchild, someNumber)
        return

def main():
    t = Tree()
    t.root = Node(4)
    t.root.rchild = Node(5)
    print t.root.data #this works
    print t.root.rchild.data #this works too
    t = Tree()
    t.insert(t.root,4)
    t.insert(t.root,5)
    print t.root.data #this fails
    print t.root.rchild.data #this fails too

if __name__ == '__main__':
     main()
Asked By: chochim

||

Answers:

The OP’s Tree.insert method qualifies for the “Gross Misnomer of the Week” award: it doesn’t insert anything. It creates a node which is not attached to any other node (not that there are any nodes to attach it to), and the created node is then discarded when the method returns.
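
The effect can be reproduced in a few lines. This is a minimal sketch in Python 3, not the OP's code: rebind and attach_right are hypothetical helper names chosen for illustration. Assigning to a parameter only rebinds the local name, while setting an attribute on an object the caller also holds remains visible after the call.

```python
class Node:
    def __init__(self, key):
        self.data = key
        self.lChild = None
        self.rChild = None


def rebind(node, key):
    # Hypothetical helper mirroring the OP's insert: this assignment only
    # changes what the local name `node` refers to inside this call.
    if node is None:
        node = Node(key)


def attach_right(parent, key):
    # Hypothetical helper: mutating an object the caller also references
    # is visible after the call returns.
    parent.rChild = Node(key)


root = None
rebind(root, 4)
print(root)               # None: the new node was never attached to anything

root = Node(4)
attach_right(root, 5)
print(root.rChild.data)   # 5
```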

For the edification of @Hugh Bothwell:

>>> class Foo(object):
...    bar = None
...
>>> a = Foo()
>>> b = Foo()
>>> a.bar
>>> a.bar = 42
>>> b.bar
>>> b.bar = 666
>>> a.bar
42
>>> b.bar
666
>>>
Answered By: John Machin
class Node: 
    rChild,lChild,data = None,None,None

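This is misleading: it defines class attributes, which every instance of Node shares until an instance assigns its own attribute of the same name. Because __init__ assigns all three names on self, each node already gets its own values and the class-level line is redundant (it would only bite if you mutated a shared value in place, for example a list). Drop it and keep just the initializer: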

class Node: 
    def __init__(self, key):
        self.rChild = None
        self.lChild = None
        self.data = key

Now each node has its own set of variables. The same applies to your definition of Tree:

class Tree:
    root,size = None,0    # <- lose this line!
    def __init__(self):
        self.root = None
        self.size = 0
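
As a rough illustration of when class-level attributes are actually shared, the sketch below (Python 3, with a hypothetical Shared class that is not part of the question) contrasts rebinding an attribute on one instance with mutating a class-level list in place.

```python
class Shared:
    items = []   # class attribute: one list object seen by every instance
    count = 0    # class attribute: assigning on an instance shadows it


a, b = Shared(), Shared()

a.count = 1           # creates an instance attribute on `a` only
a.items.append("x")   # mutates the single list stored on the class

print(b.count)        # 0
print(b.items)        # ['x']
```

Only the in-place mutation leaks between instances, which is why assigning fresh values in __init__ is the safe default.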

Further, in Python 2 each class should be a “new-style” class derived from the “object” class and should chain back to object.__init__() (in Python 3 every class is new-style, so this is optional):

class Node(object): 
    def __init__(self, data, rChild=None, lChild=None):
        super(Node,self).__init__()
        self.data   = data
        self.rChild = rChild
        self.lChild = lChild

class Tree(object):
    def __init__(self):
        super(Tree,self).__init__()
        self.root = None
        self.size = 0

Also, main() is indented too far – as shown, it is a method of Tree which is uncallable because it does not accept a self argument.

Also, you are modifying the object’s data directly (t.root = Node(4)) which kind of destroys encapsulation (the whole point of having classes in the first place); you should be doing something more like

def main():
    t = Tree()
    t.add(4)    # <- let the tree create a data Node and insert it
    t.add(5)
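
One way such an add method might look is sketched below. The answer only names add; the body here is an assumption (an iterative descent that sends equal values to the right), written for Python 3.

```python
class Node(object):
    def __init__(self, data):
        self.data = data
        self.lChild = None
        self.rChild = None


class Tree(object):
    def __init__(self):
        self.root = None
        self.size = 0

    def add(self, value):
        # Assumed behaviour: the tree creates the node itself and walks
        # down from the root until it finds an empty child slot.
        new_node = Node(value)
        self.size += 1
        if self.root is None:
            self.root = new_node
            return
        node = self.root
        while True:
            if value < node.data:
                if node.lChild is None:
                    node.lChild = new_node
                    return
                node = node.lChild
            else:
                if node.rChild is None:
                    node.rChild = new_node
                    return
                node = node.rChild


t = Tree()
t.add(4)
t.add(5)
t.add(2)
print(t.root.data, t.root.lChild.data, t.root.rChild.data)   # 4 2 5
```
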
Answered By: Hugh Bothwell

Just something to help you to start on.

A simple binary tree search could be implemented in Python along these lines:

def search(node, key):
    if node is None: return None  # key not found
    if key< node.key: return search(node.left, key)
    elif key> node.key: return search(node.right, key)
    else: return node.value  # found key

Now you just need to implement the scaffolding (tree creation and value inserts) and you are done.

Answered By: eat

Here is a quick example of a binary insert:

class Node:
    def __init__(self, val):
        self.l_child = None
        self.r_child = None
        self.data = val

def binary_insert(root, node):
    if root is None:
        root = node
    else:
        if root.data > node.data:
            if root.l_child is None:
                root.l_child = node
            else:
                binary_insert(root.l_child, node)
        else:
            if root.r_child is None:
                root.r_child = node
            else:
                binary_insert(root.r_child, node)

def in_order_print(root):
    if not root:
        return
    in_order_print(root.l_child)
    print root.data
    in_order_print(root.r_child)

def pre_order_print(root):
    if not root:
        return        
    print root.data
    pre_order_print(root.l_child)
    pre_order_print(root.r_child)    

r = Node(3)
binary_insert(r, Node(7))
binary_insert(r, Node(1))
binary_insert(r, Node(5))

     3
    / \
   1   7
      /
     5

print "in order:"
in_order_print(r)

print "pre order"
pre_order_print(r)

in order:
1
3
5
7
pre order
3
1
7
5
Answered By: dting

class Node:
    rChild,lChild,parent,data = None,None,None,0

    def __init__(self,key):
        self.rChild = None
        self.lChild = None
        self.parent = None
        self.data = key

class Tree:
    root,size = None,0
    def __init__(self):
        self.root = None
        self.size = 0
    def insert(self,someNumber):
        self.size = self.size+1
        if self.root is None:
            self.root = Node(someNumber)
        else:
            self.insertWithNode(self.root, someNumber)    

    def insertWithNode(self,node,someNumber):
        if node.lChild is None and node.rChild is None:#external node
            if someNumber > node.data:
                newNode = Node(someNumber)
                node.rChild = newNode
                newNode.parent = node
            else:
                newNode = Node(someNumber)
                node.lChild = newNode
                newNode.parent = node
        else: #not external
            if someNumber > node.data:
                if node.rChild is not None:
                    self.insertWithNode(node.rChild, someNumber)
                else: #if empty node
                    newNode = Node(someNumber)
                    node.rChild = newNode
                    newNode.parent = node 
            else:
                if node.lChild is not None:
                    self.insertWithNode(node.lChild, someNumber)
                else:
                    newNode = Node(someNumber)
                    node.lChild = newNode
                    newNode.parent = node                    

    def printTree(self,someNode):
        if someNode is None:
            pass
        else:
            self.printTree(someNode.lChild)
            print someNode.data
            self.printTree(someNode.rChild)

def main():  
    t = Tree()
    t.insert(5)  
    t.insert(3)
    t.insert(7)
    t.insert(4)
    t.insert(2)
    t.insert(1)
    t.insert(6)
    t.printTree(t.root)

if __name__ == '__main__':
    main()

My solution.

Answered By: chochim
class BST:
    def __init__(self, val=None):
        self.left = None
        self.right = None
        self.val = val

    def __str__(self):
        return "[%s, %s, %s]" % (self.left, str(self.val), self.right)

    def isEmpty(self):
        return self.left == self.right == self.val == None

    def insert(self, val):
        if self.isEmpty():
            self.val = val
        elif val < self.val:
            if self.left is None:
                self.left = BST(val)
            else:
                self.left.insert(val)
        else:
            if self.right is None:
                self.right = BST(val)
            else:
                self.right.insert(val)

a = BST(1)
a.insert(2)
a.insert(3)
a.insert(0)
print a
Answered By: Aram Kocharyan

Another Python BST with sort key (defaulting to value)

LEFT = 0
RIGHT = 1
VALUE = 2
SORT_KEY = -1

class BinarySearchTree(object):

    def __init__(self, sort_key=None):
        self._root = []  
        self._sort_key = sort_key
        self._len = 0  

    def insert(self, val):
        if self._sort_key is None:
            sort_key = val  # if no sort key, sort key is value
        else:
            sort_key = self._sort_key(val)

        # Walk down until an empty list (an unused child slot) is reached.
        node = self._root
        while node:
            if sort_key < node[SORT_KEY]:
                node = node[LEFT]
            else:
                node = node[RIGHT]

        # Fill the empty slot in place: [left, right, value(, sort_key)].
        if sort_key is val:
            node[:] = [[], [], val]
        else:
            node[:] = [[], [], val, sort_key]
        self._len += 1

    def minimum(self):
        return self._extreme_node(LEFT)[VALUE]

    def maximum(self):
        return self._extreme_node(RIGHT)[VALUE]

    def find(self, sort_key):
        return self._find(sort_key)[VALUE]

    def _extreme_node(self, side):
        if not self._root:
            raise IndexError('Empty')
        node = self._root
        while node[side]:
            node = node[side]
        return node

    def _find(self, sort_key):
        node = self._root
        while node:
            node_key = node[SORT_KEY]
            if sort_key < node_key:
                node = node[LEFT]
            elif sort_key > node_key:
                node = node[RIGHT]
            else:
                return node
        raise KeyError("%r not found" % sort_key)
Answered By: kiriloff

Here is a compact, object oriented, recursive implementation:

class BTreeNode(object):
    def __init__(self, data):
        self.data = data
        self.rChild = None
        self.lChild = None

    def __str__(self):
        return (self.lChild.__str__() + '<-' if self.lChild != None else '') + self.data.__str__() + ('->' + self.rChild.__str__() if self.rChild != None else '')

    def insert(self, btreeNode):
        if self.data > btreeNode.data: #insert left
            if self.lChild == None:
                self.lChild = btreeNode
            else:
                self.lChild.insert(btreeNode)
        else: #insert right
            if self.rChild == None:
                self.rChild = btreeNode
            else:
                self.rChild.insert(btreeNode)


def main():
    btreeRoot = BTreeNode(5)
    print 'inserted %s:' %5, btreeRoot

    btreeRoot.insert(BTreeNode(7))
    print 'inserted %s:' %7, btreeRoot

    btreeRoot.insert(BTreeNode(3))
    print 'inserted %s:' %3, btreeRoot

    btreeRoot.insert(BTreeNode(1))
    print 'inserted %s:' %1, btreeRoot

    btreeRoot.insert(BTreeNode(2))
    print 'inserted %s:' %2, btreeRoot

    btreeRoot.insert(BTreeNode(4))
    print 'inserted %s:' %4, btreeRoot

    btreeRoot.insert(BTreeNode(6))
    print 'inserted %s:' %6, btreeRoot

The output of the above main() is:

inserted 5: 5
inserted 7: 5->7
inserted 3: 3<-5->7
inserted 1: 1<-3<-5->7
inserted 2: 1->2<-3<-5->7
inserted 4: 1->2<-3->4<-5->7
inserted 6: 1->2<-3->4<-5->6<-7
Answered By: Pejvan

I find the solutions a bit clumsy on the insert part. You could return the root reference and simplify it a bit:

def binary_insert(root, node):
    if root is None:
        return node
    if root.data > node.data:
        root.l_child = binary_insert(root.l_child, node)
    else:
        root.r_child = binary_insert(root.r_child, node)
    return root
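
A short usage sketch in Python 3, assuming the Node class from DTing's answer (repeated here so the block runs on its own) and a hypothetical in_order helper that returns a list instead of printing. Because the function returns the (possibly new) subtree root, it also works when the tree starts out empty:

```python
class Node:
    def __init__(self, val):
        self.l_child = None
        self.r_child = None
        self.data = val


def binary_insert(root, node):
    if root is None:
        return node
    if root.data > node.data:
        root.l_child = binary_insert(root.l_child, node)
    else:
        root.r_child = binary_insert(root.r_child, node)
    return root


def in_order(root):
    # Hypothetical helper: collect values as left subtree, node, right subtree.
    if root is None:
        return []
    return in_order(root.l_child) + [root.data] + in_order(root.r_child)


root = None                      # start from an empty tree
for value in (3, 7, 1, 5):
    root = binary_insert(root, Node(value))

print(in_order(root))            # [1, 3, 5, 7]
```
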
Answered By: jlhonora

The following code is based on @DTing’s answer and what I learned in class; it uses a while loop to insert (indicated in the code).

class Node:
    def __init__(self, val):
        self.l_child = None
        self.r_child = None
        self.data = val


def binary_insert(root, node):
    y = None
    x = root
    z = node
    #while loop here
    while x is not None:
        y = x
        if z.data < x.data:
            x = x.l_child
        else:
            x = x.r_child
    z.parent = y
    if y == None:
        root = z
    elif z.data < y.data:
        y.l_child = z
    else:
        y.r_child = z


def in_order_print(root):
    if not root:
        return
    in_order_print(root.l_child)
    print(root.data)
    in_order_print(root.r_child)


r = Node(3)
binary_insert(r, Node(7))
binary_insert(r, Node(1))
binary_insert(r, Node(5))

in_order_print(r)
Answered By: Liam

Here is a working solution.

class BST:
    def __init__(self,data):
        self.root = data
        self.left = None
        self.right = None

    def insert(self,data):
        if self.root == None:
            self.root = data
        elif data > self.root:
            if self.right == None:
                self.right = BST(data)
            else:
                self.right.insert(data)
        elif data < self.root:
            if self.left == None:
                self.left = BST(data)
            else:
                self.left.insert(data)

    def inordertraversal(self):
        if self.left != None:
            self.left.inordertraversal()
        print (self.root),
        if self.right != None:
            self.right.inordertraversal()

t = BST(4)
t.insert(1)
t.insert(7)
t.insert(3)
t.insert(6)
t.insert(2)
t.insert(5)
t.inordertraversal()
Answered By: Sdhir

It’s easy to implement a BST using two classes: Node and Tree.
The Tree class is just the user interface; the actual methods are implemented in the Node class.

class Node():

    def __init__(self,val):
        self.value = val
        self.left = None
        self.right = None


    def _insert(self,data):
        if data == self.value:
            return False
        elif data < self.value:
            if self.left:
                return self.left._insert(data)
            else:
                self.left = Node(data)
                return True
        else:
            if self.right:
                return self.right._insert(data)
            else:
                self.right = Node(data)
                return True

    def _inorder(self):
        if self:
            if self.left:
                self.left._inorder()
            print(self.value)
            if self.right:
                self.right._inorder()



class Tree():

    def __init__(self):
        self.root = None

    def insert(self,data):
        if self.root:
            return self.root._insert(data)
        else:
            self.root = Node(data)
            return True
    def inorder(self):
        if self.root is not None:
            return self.root._inorder()
        else:
            return False




if __name__=="__main__":
    a = Tree()
    a.insert(16)
    a.insert(8)
    a.insert(24)
    a.insert(6)
    a.insert(12)
    a.insert(19)
    a.insert(29)
    a.inorder()

The inorder function is there to check whether the BST is implemented properly: an in-order walk of a valid BST prints the values in ascending order.
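
That check can also be automated instead of read off the printed output. The following is a sketch in Python 3 under stated assumptions: inorder_values and looks_like_bst are hypothetical helpers, the Node class is a stripped-down copy with the same attribute names, and the comparison is strict because this answer's _insert rejects duplicates.

```python
class Node:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


def inorder_values(node):
    # Hypothetical helper: return the in-order sequence as a list.
    if node is None:
        return []
    return inorder_values(node.left) + [node.value] + inorder_values(node.right)


def looks_like_bst(root):
    # A tree is a valid BST (without duplicates) exactly when its in-order
    # sequence is strictly increasing.
    values = inorder_values(root)
    return all(a < b for a, b in zip(values, values[1:]))


# Hand-built tree so the check does not depend on any insert method.
root = Node(16)
root.left, root.right = Node(8), Node(24)
root.left.left, root.left.right = Node(6), Node(12)

print(inorder_values(root))   # [6, 8, 12, 16, 24]
print(looks_like_bst(root))   # True

root.left.right.value = 20    # 20 sits in the left subtree of 16: invalid
print(looks_like_bst(root))   # False
```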

Answered By: nirav bharadiya

The accepted answer neglects to set a parent attribute for each node inserted, without which one cannot implement a successor method which finds the successor in an in-order tree walk in O(h) time, where h is the height of the tree (as opposed to the O(n) time needed for the walk).

Here is an implementation based on the pseudocode given in Cormen et al., Introduction to Algorithms, including assignment of a parent attribute and a successor method:

class Node(object):
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.parent = None


class Tree(object):
    def __init__(self, root=None):
        self.root = root

    def insert(self, z):
        y = None
        x = self.root
        while x is not None:
            y = x
            if z.key < x.key:
                x = x.left
            else:
                x = x.right
        z.parent = y
        if y is None:
            self.root = z       # Tree was empty
        elif z.key < y.key:
            y.left = z
        else:
            y.right = z

    @staticmethod
    def minimum(x):
        while x.left is not None:
            x = x.left
        return x

    @staticmethod
    def successor(x):
        if x.right is not None:
            return Tree.minimum(x.right)
        y = x.parent
        while y is not None and x == y.right:
            x = y
            y = y.parent
        return y

Here are some tests to show that the tree behaves as expected for the example given by DTing:

import pytest

@pytest.fixture
def tree():
    t = Tree()
    t.insert(Node(3))
    t.insert(Node(1))
    t.insert(Node(7))
    t.insert(Node(5))
    return t

def test_tree_insert(tree):
    assert tree.root.key == 3
    assert tree.root.left.key == 1
    assert tree.root.right.key == 7
    assert tree.root.right.left.key == 5

def test_tree_successor(tree):
    assert Tree.successor(tree.root.left).key == 3
    assert Tree.successor(tree.root.right.left).key == 7

if __name__ == "__main__":
    pytest.main([__file__])
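
As a further illustration of what the parent pointers buy you, the sketch below (Python 3) walks the whole tree in sorted order by starting at the minimum and repeatedly calling successor, with no recursion and no explicit stack. The classes are condensed copies of the ones in this answer so that the block runs on its own; sorted_keys is a hypothetical helper added for the example.

```python
class Node(object):
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.parent = None


class Tree(object):
    def __init__(self, root=None):
        self.root = root

    def insert(self, z):
        y = None
        x = self.root
        while x is not None:
            y = x
            x = x.left if z.key < x.key else x.right
        z.parent = y
        if y is None:
            self.root = z       # tree was empty
        elif z.key < y.key:
            y.left = z
        else:
            y.right = z

    @staticmethod
    def minimum(x):
        while x.left is not None:
            x = x.left
        return x

    @staticmethod
    def successor(x):
        if x.right is not None:
            return Tree.minimum(x.right)
        y = x.parent
        while y is not None and x is y.right:
            x = y
            y = y.parent
        return y


def sorted_keys(tree):
    # Hypothetical helper: follow successor links from the smallest node.
    keys = []
    node = Tree.minimum(tree.root) if tree.root is not None else None
    while node is not None:
        keys.append(node.key)
        node = Tree.successor(node)
    return keys


t = Tree()
for key in (3, 1, 7, 5):
    t.insert(Node(key))
print(sorted_keys(t))   # [1, 3, 5, 7]
```
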
Answered By: Kurt Peek

The problem, or at least one problem, with your code is here:

def insert(self,node,someNumber):
    if node is None:
        node = Node(someNumber)
    else:
        if node.data > someNumber:
            self.insert(node.rchild,someNumber)
        else:
            self.insert(node.rchild, someNumber)
    return

The statement “if node.data > someNumber:” and the associated “else:” both have the same code after them, i.e. you do the same thing whether the condition is true or false.

I’d suggest you probably intended to do different things here; perhaps one of these should say self.insert(node.lchild, someNumber)?

Answered By: Michael

Another Python BST solution

class Node(object):
    def __init__(self, value):
        self.left_node = None
        self.right_node = None
        self.value = value

    def __str__(self):
        return "[%s, %s, %s]" % (self.left_node, self.value, self.right_node)

    def insertValue(self, new_value):
        """
        1. if current Node doesn't have a value then assign it to self
        2. if new_value is lower than current Node's value then go left
        3. if new_value is greater than current Node's value then go right
        :return:
        """
        if self.value:
            if new_value < self.value:
                # add to left
                if self.left_node is None:  # reached start add value to start
                    self.left_node = Node(new_value)
                else:
                    self.left_node.insertValue(new_value)  # search
            elif new_value > self.value:
                # add to right
                if self.right_node is None:  # reached end add value to end
                    self.right_node = Node(new_value)
                else:
                    self.right_node.insertValue(new_value)  # search
        else:
            self.value = new_value

    def findValue(self, value_to_find):
        """
        1. value_to_find is equal to current Node's value then found
        2. if value_to_find is lower than Node's value then go to left
        3. if value_to_find is greater than Node's value then go to right
        """
        if value_to_find == self.value:
            return "Found"
        elif value_to_find < self.value and self.left_node:
            return self.left_node.findValue(value_to_find)
        elif value_to_find > self.value and self.right_node:
            return self.right_node.findValue(value_to_find)
        return "Not Found"

    def printTree(self):
        """
        Nodes will be in sequence
        1. Print LHS items
        2. Print value of node
        3. Print RHS items
        """
        if self.left_node:
            self.left_node.printTree()
        print(self.value),
        if self.right_node:
            self.right_node.printTree()

    def isEmpty(self):
        return self.left_node == self.right_node == self.value == None


def main():
    root_node = Node(12)
    root_node.insertValue(6)
    root_node.insertValue(3)
    root_node.insertValue(7)

    # should return 3 6 7 12
    root_node.printTree()

    # should return found
    root_node.findValue(7)
    # should return found
    root_node.findValue(3)
    # should return Not found
    root_node.findValue(24)

if __name__ == '__main__':
    main()
Answered By: Umesh

def BinaryST(list1,key):
    start = 0
    end = len(list1)
    print("Length of List: ",end)

    for i in range(end):
        for j in range(0, end-i-1):
            if(list1[j] > list1[j+1]):
                temp = list1[j]
                list1[j] = list1[j+1]
                list1[j+1] = temp

    print("Order List: ",list1)

    mid = int((start+end)/2)
    print("Mid Index: ",mid)

    if(key == list1[mid]):
        print(key," is on ",mid," Index")

    elif(key > list1[mid]):
        for rindex in range(mid+1,end):
            if(key == list1[rindex]):
                print(key," is on ",rindex," Index")
                break
            elif(rindex == end-1):
                print("Given key: ",key," is not in List")
                break
            else:
                continue

    elif(key < list1[mid]):
        for lindex in range(0,mid):
            if(key == list1[lindex]):
                print(key," is on ",lindex," Index")
                break
            elif(lindex == mid-1):
                print("Given key: ",key," is not in List")
                break
            else:
                continue


size = int(input("Enter Size of List: "))
list1 = []
for e in range(size):
    ele = int(input("Enter Element in List: "))
    list1.append(ele)

key = int(input("\nEnter Key for Search: "))

print("\nUnorder List: ",list1)
BinaryST(list1,key)
Answered By: Aezaz Desai
class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left = None
        self.right = None


class BinaryTree:
    def __init__(self, root=None):
        self.root = root

    def add_node(self, node, value):
        """
        Insert value into the left subtree if value < node.value, otherwise
        into the right subtree (so duplicate values end up on the right).
        """
        if node is not None:
            if value < node.value:
                if node.left is None:
                    node.left = TreeNode(value)
                else:
                    self.add_node(node.left, value)
            else:
                if node.right is None:
                    node.right = TreeNode(value)
                else:
                    self.add_node(node.right, value)
        else:
            self.root = TreeNode(value)

    def search(self, value):
        """
        Value will be to the left of node if node > value; right otherwise.
        """
        node = self.root
        while node is not None:
            if node.value == value:
                return True     # node.value
            if node.value > value:
                node = node.left
            else:
                node = node.right
        return False

    def traverse_inorder(self, node):
        """
        Traverse the left subtree of a node, then visit the node itself,
        then traverse the right subtree (values print in ascending order).
        """
        if node is not None:
            self.traverse_inorder(node.left)
            print(node.value)
            self.traverse_inorder(node.right)


def main():
    binary_tree = BinaryTree()
    binary_tree.add_node(binary_tree.root, 200)
    binary_tree.add_node(binary_tree.root, 300)
    binary_tree.add_node(binary_tree.root, 100)
    binary_tree.add_node(binary_tree.root, 30)
    binary_tree.traverse_inorder(binary_tree.root)
    print(binary_tree.search(200))


if __name__ == '__main__':
    main()
Answered By: user9652688

A simple, recursive method that builds a balanced BST with only 1 function, using an array of values (the array must already be sorted):

class TreeNode(object):

    def __init__(self, value: int, left=None, right=None):
        super().__init__()
        self.value = value
        self.left = left
        self.right = right

    def __str__(self):
        return str(self.value)


def create_node(values, lower, upper) -> TreeNode:
    if lower > upper:
        return None

    index = (lower + upper) // 2

    value = values[index]
    node = TreeNode(value=value)
    node.left = create_node(values, lower, index - 1)
    node.right = create_node(values, index + 1, upper)

    return node


def print_bst(node: TreeNode):
    if node:
        # Simple pre-order traversal when printing the tree
        print("node: {}".format(node))
        print_bst(node.left)
        print_bst(node.right)



if __name__ == '__main__':
    vals = [0, 1, 2, 3, 4, 5, 6]
    bst = create_node(vals, lower=0, upper=len(vals) - 1)
    print_bst(bst)

As you can see, we really only need 1 method, which is recursive: create_node. We pass in the full values array in each create_node call; however, we update the lower and upper index values every time we make the recursive call.

Then, using the lower and upper index values, we calculate the middle index and read values[index] into value. This is the value for the current node, which we use to create a node.

From there, we set the values of left and right by recursively calling the function, until we reach the end state of the recursion call when lower is greater than upper.

Important: we update the value of upper when creating the left side of the tree. Conversely, we update the value of lower when creating the right side of the tree.
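
To check the result, the sketch below (Python 3) rebuilds the same tree and inspects it. The in_order and height functions are hypothetical helpers added for the example, and TreeNode and create_node are condensed copies of the code above so that the block is self-contained. With sorted input, the in-order sequence should reproduce the array and the height should stay close to log2 of its length.

```python
class TreeNode:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def create_node(values, lower, upper):
    if lower > upper:
        return None
    index = (lower + upper) // 2           # middle element becomes the root
    node = TreeNode(values[index])
    node.left = create_node(values, lower, index - 1)    # upper shrinks
    node.right = create_node(values, index + 1, upper)   # lower grows
    return node


def in_order(node):
    # Hypothetical helper: values as left subtree, node, right subtree.
    if node is None:
        return []
    return in_order(node.left) + [node.value] + in_order(node.right)


def height(node):
    # Hypothetical helper: number of nodes on the longest root-to-leaf path.
    if node is None:
        return 0
    return 1 + max(height(node.left), height(node.right))


vals = [0, 1, 2, 3, 4, 5, 6]
bst = create_node(vals, 0, len(vals) - 1)

print(bst.value)       # 3
print(in_order(bst))   # [0, 1, 2, 3, 4, 5, 6]
print(height(bst))     # 3
```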

Hopefully this helps!

Answered By: Virat Singh