module Main where

-- | A binary search tree whose nodes carry an AVL balance factor.
--
-- The stored balance factor is the height of the right subtree minus the
-- height of the left subtree; this is exactly what 'check' verifies.
data AVL e = E           -- empty tree
           | N           -- non-empty tree
               Int           -- balance factor (right height - left height)
               (AVL e)       -- left subtree
               e             -- value
               (AVL e)       -- right subtree
  deriving Show

-- | Membership test: does the value occur in the tree?
--
-- Walks a single root-to-leaf path, going left or right according to
-- 'compare' against each node's value. The search-tree ordering is
-- assumed, not checked.
avlLookup :: Ord e => e -> AVL e -> Bool
avlLookup needle = go
  where
    go E = False
    go (N _ left pivot right) =
      case needle `compare` pivot of
        LT -> go left
        GT -> go right
        EQ -> True

-- | In-order traversal: the tree's elements from leftmost to rightmost.
-- 'check' requires this list to be strictly increasing for a valid tree.
--
-- An accumulator is threaded through the traversal so that each element is
-- consed exactly once. This costs O(n) even for degenerate (e.g. left-spine)
-- trees, where the naive @l ++ [x] ++ r@ formulation costs O(n * depth).
-- The result is still produced lazily.
avlToList :: AVL e -> [e]
avlToList tree = go tree []
  where
    -- go node rest = in-order elements of node, followed by rest
    go E rest = rest
    go (N _ l x r) rest = go l (x : go r rest)

-- | True when each element is strictly less than the one after it, i.e. the
-- list is strictly ascending. Duplicates make it False. Empty and singleton
-- lists are trivially increasing.
increasing :: Ord e => [e] -> Bool
increasing xs = and (zipWith (<) xs (drop 1 xs))

-- | Number of levels in the tree. The empty tree has height 0 and a single
-- node has height 1. The height is recomputed by visiting every node; the
-- stored balance factors are ignored.
height :: AVL e -> Int
height tree =
  case tree of
    E                -> 0
    N _ left _ right -> max (height left) (height right) + 1

-- | Validate the AVL invariants of a tree and print a diagnostic for the
-- first violation found. Nothing is printed for a valid tree.
--
-- At each node the conditions are tested in this order:
--
--   1. the stored balance factor equals @height right - height left@;
--   2. the in-order traversal of the whole node is strictly increasing;
--   3. the balance factor lies in [-1, 1].
--
-- The first failing condition is reported and checking stops there.
-- Only if the node passes all three are both subtrees checked recursively.
check :: (Ord e, Show e) => AVL e -> IO ()
check E =
  return ()
-- The as-pattern must be written without spaces around '@': GHC >= 9.0
-- rejects the spaced form @t @ (...)@ as a parse error.
check t@(N bf t1 _ t2)
  | bf /= expectedBf =
      putStr (show t ++ " has incorrect balance factor (should be "
              ++ show expectedBf ++ ") at the root\n")
  | not (increasing (avlToList t)) =
      putStr (show t ++ " is not correctly sorted\n")
  | bf < -1 || bf > 1 =
      putStr (show t ++ " is unbalanced\n")
  | otherwise =
      do check t1
         check t2
  where
    -- Computed once; used both for the comparison and in the message.
    expectedBf = height t2 - height t1
    
-- | Run 'check' on a few hand-built trees. Valid trees print nothing; each
-- invalid tree prints a single diagnostic line.
main :: IO ()
main =
  do putStr "Running tests...\n"
     check (E :: AVL Int)
     -- Valid: a single leaf.
     check (N 0 E 5 E)
     -- Valid: right-leaning tree whose stored balance factor (1) is correct.
     -- (Replaces a verbatim duplicate of the previous case.)
     check (N 1 E 5 (N 0 E 6 E))
     -- Invalid: stored balance factor 0 but the actual value is -1. The
     -- tree is also out of order, but the balance-factor check fires first.
     check (N 0 (N 0 E 6 E) 5 E)
     -- Invalid: stored balance factor 0 but the actual value is -1.
     check (N 0 (N 0 E 4 E) 5 E)
