module Main where

import Test.QuickCheck
import Data.List (transpose)

-----------------------
-- Signals

-- | A signal is a list of booleans, read as one sample per time step
-- (position 0 is the first sample).  By convention, all signals are
-- infinite; code in this file that applies 'head', 'tail' or '!!' to the
-- underlying list relies on that convention.
newtype Signal = Sig [Bool]

-- | Constant signal: every sample equals the given boolean.
lift0 :: Bool -> Signal
lift0 = Sig . repeat

-- | Apply a one-input boolean function to every sample of a signal.
lift1 :: (Bool->Bool) -> Signal -> Signal
lift1 f (Sig samples) = Sig [f b | b <- samples]

-- | Combine two signals sample by sample with a two-input boolean function.
lift2 :: (Bool->Bool->Bool) -> (Signal, Signal) -> Signal
lift2 f (Sig xs, Sig ys) = Sig [f x y | (x, y) <- zip xs ys]

-- | Combine two signals sample by sample with a function that yields two
-- outputs, returning one signal per output.
lift22 :: (Bool->Bool->(Bool,Bool)) -> (Signal, Signal) -> (Signal,Signal)
lift22 f (Sig xs, Sig ys) = (Sig firsts, Sig seconds)
  where (firsts, seconds) = unzip (zipWith f xs ys)

-- | Combine three signals sample by sample with a three-input boolean function.
lift3 :: (Bool->Bool->Bool->Bool) -> (Signal, Signal, Signal) -> Signal
lift3 f (Sig xs, Sig ys, Sig zs) =
  Sig [f x y z | (x, y, z) <- zip3 xs ys zs]

------------------------
-- Simulation

-- | Number of leading samples kept when a signal is printed or compared.
truncatedSignalSize :: Int
truncatedSignalSize = 20

-- | Truncate a signal's sample list to a short prefix, for testing or
-- printing.  Lists shorter than the prefix are returned unchanged.
truncateSig :: [a] -> [a]
truncateSig = take truncatedSignalSize

-- | Signals are shown as their truncated prefix followed by an ellipsis.
instance Show Signal where
  show (Sig samples) = shows (truncateSig samples) "..."

-- | Print a table of signal values: a header line with the signal names,
-- then @count@ rows, where row k shows sample k of every signal as 0 or 1.
-- Each value is right-aligned under its name by padding it with
-- (length of the name - 1) spaces.
trace :: [(String,Signal)] -> Int -> IO ()
trace desc count =
  do mapM_ (\label -> putStr (label ++ " ")) labels
     putStr "\n"
     mapM_ printRow rows
  where
    labels  = map fst desc
    streams = [samples | (_, Sig samples) <- desc]
    -- Successive states of the streams, each with one more sample dropped;
    -- only the first 'count' states are ever built.
    rows    = map (map head) (take count (iterate (map tail) streams))
    printRow vals =
      do mapM_ printCell (zip labels vals)
         putStr "\n"
    printCell (label, bit) =
      do putStr (replicate (length label - 1) ' ')
         putStr (if bit then "1" else "0")
         putStr " "

-- | Print only the first sample of each named signal.
probe :: [(String,Signal)] -> IO ()
probe = flip trace 1

-- | Print the first 20 samples of each named signal.
simulate :: [(String,Signal)] -> IO ()
simulate = flip trace 20

------------------------
-- Testing support (QuickCheck helpers)

-- | Random signals for property tests.  'coarbitrary' is defined here as a
-- method of 'Arbitrary', so this instance targets a QuickCheck API in which
-- that is the case.
instance Arbitrary Signal where
  -- Draw one random sample and prepend it to another random signal.  The
  -- recursion has no base case, so the generated list never ends, matching
  -- the "signals are infinite" convention.
  -- NOTE(review): presumably relies on the generator monad's bind being lazy
  -- so the unbounded recursion is only unfolded on demand -- confirm against
  -- the QuickCheck version in use.
  arbitrary = 
    do x <- arbitrary
       (Sig xs) <- arbitrary
       return $ Sig (x : xs)
  -- Generating functions that take a Signal is not supported; any use fails
  -- at run time with this error.
  coarbitrary = error "Not implemented"

-- | Generate a list of exactly @n@ arbitrary elements.  The recursion only
-- stops at 0, so @n@ is expected to be non-negative.
arbitraryListOfSize 0 = return []
arbitraryListOfSize n =
  arbitrary >>= \firstElem ->
  arbitraryListOfSize (n-1) >>= \restElems ->
  return (firstElem : restElems)

-- | Values that can be compared for agreement.  This is weaker than
-- equality: the 'Signal' instance in this file only compares a truncated
-- prefix, which makes it usable on infinite signals.
class Agreeable b where
  -- | @x `agreesWith` y@ is True when the two values agree according to
  -- the instance for their type.
  agreesWith :: b -> b -> Bool

-- | Two signals agree when their truncated prefixes match position by
-- position (compared over the shorter of the two prefixes).
instance Agreeable Signal where
  agreesWith (Sig ls) (Sig rs) =
    and [l == r | (l, r) <- zip (truncateSig ls) (truncateSig rs)]

-- | Pairs agree when both components agree.
instance (Agreeable a, Agreeable b) => Agreeable (a,b) where
  agreesWith (l1,l2) (r1,r2) = agreesWith l1 r1 && agreesWith l2 r2

-- | Lists agree when corresponding elements agree; elements past the end of
-- the shorter list are not compared.
instance Agreeable a => Agreeable [a] where
  agreesWith ls rs = and (zipWith agreesWith ls rs)

-- | Sample number @n@ (0-based) of a signal.
sampleAt :: Int -> Signal -> Bool
sampleAt n (Sig samples) = samples !! n

-- | Sample number @n@ of every signal in a list.
sampleAtN :: Int -> [Signal] -> [Bool]
sampleAtN n signals = [sampleAt n s | s <- signals]

-- | First sample of a signal.
sample :: Signal -> Bool
sample = sampleAt 0

-- | First sample of every signal in a list.
sampleN :: [Signal] -> [Bool]
sampleN = sampleAtN 0

-- | Interpret a list of bits as a number, least significant bit first.
binary :: Num a => [Bool] -> a
binary bits = foldr accumulate 0 bits
  where
    -- The current bit contributes its own value; everything after it is
    -- worth twice as much.
    accumulate bit higher = (if bit then 1 else 0) + (2 * higher)

-- | Run a property with the default configuration, except that the
-- per-test progress string is replaced by the empty string.
qc :: Testable a => a -> IO ()
qc = check quietConfig
  where quietConfig = defaultConfig { configEvery = \_ _ -> "" }

------------------------
-- Basic gates

-- | Constant signals: always True ('high') and always False ('low').
high, low :: Signal
high = lift0 True
low  = lift0 False

-- | Exclusive-or gate: the output is True exactly when the two inputs differ.
xor2 :: (Signal, Signal) -> Signal
xor2 = lift2 (/=)

-- | And gate: the output is True exactly when both inputs are True.
and2 :: (Signal, Signal) -> Signal
and2 = lift2 (&&)

-- | Signal that starts with the given samples and is False forever after.
explicit :: [Bool] -> Signal
explicit prefix = Sig (prefix ++ repeat False)

-- | Build a signal from a string: the character '1' means True, any other
-- character means False; the signal is False after the end of the string.
str :: String -> Signal
str digits = explicit (map (== '1') digits)

-- | Delay a signal by one sample: the output starts with the given initial
-- value and then repeats the input.
delay :: Bool -> Signal -> Signal
delay initial (Sig samples) = Sig (initial : samples)

-- | Multiplexer over @(b1, b2, select)@: outputs @b1@ where @select@ is
-- True and @b2@ where it is False.
mux :: (Signal, Signal, Signal) -> Signal
mux = lift3 pick
  where
    pick whenSet whenClear select
      | select    = whenSet
      | otherwise = whenClear

-- | Demultiplexer over @(input, select)@: the input is routed to the first
-- output where @select@ is True and to the second where it is False; the
-- other output is False.
demux :: (Signal, Signal) -> (Signal, Signal)
demux = lift22 route
  where
    route i select
      | select    = (i, False)
      | otherwise = (False, i)

-- | Multi-bit multiplexer: applies 'mux' to corresponding bits of the two
-- buses, all sharing the same select signal.
muxN :: ([Signal], [Signal], Signal) -> [Signal]
muxN (b1,b2,sel) = zipWith (\bit1 bit2 -> mux (bit1,bit2,sel)) b1 b2

-- | Multi-bit demultiplexer: applies 'demux' to every bit of the bus with
-- the same select signal and collects the two output buses.
demuxN :: ([Signal], Signal) -> ([Signal], [Signal])
demuxN (b,sel) = unzip [demux (bit,sel) | bit <- b]

------------------------------------------------------------------------------
-- Combinational circuits

-- | Half adder: returns @(sum, carry)@ of two one-bit inputs, where the sum
-- is their exclusive-or and the carry is their conjunction.
halfadd :: (Signal, Signal) -> (Signal, Signal)
halfadd (x,y) = (xor2 (x,y), and2 (x,y))
              
-- | The half adder gives the same result when its inputs are swapped.
prop_halfadd_commut :: Bool -> Bool -> Bool
prop_halfadd_commut b1 b2 = forward `agreesWith` swapped
  where
    forward = halfadd (lift0 b1, lift0 b2)
    swapped = halfadd (lift0 b2, lift0 b1)

-- | Entry point: checks the half-adder commutativity property.
main :: IO ()
main = qc prop_halfadd_commut

-- | Full adder built from two half adders: takes @(cin, x, y)@ and returns
-- @(sum, cout)@.
fulladd :: (Signal, Signal, Signal) -> (Signal, Signal)
fulladd (cin,x,y) =
  let (partial, carryXY) = halfadd (x,y)
      (total, carryCin)  = halfadd (cin, partial)
      -- The two carries are never high together (carryXY high forces
      -- partial low, hence carryCin low), so xor acts as an or here.
  in (total, xor2 (carryXY, carryCin))

-- | Probe the full adder with cin = 1, x = 0, y = 1.
test1a :: IO ()
test1a =
  let cin = high
      x   = low
      y   = high
      (total, carry) = fulladd (cin, x, y)
  in probe
       [ ("cin",cin)
       , ("x",x)
       , ("y",y)
       , ("  sum",total)
       , ("cout",carry)
       ]

-- | Add a single carry-in bit to a multi-bit number (least significant bit
-- first) with a chain of half adders.  Returns the sum bits and the final
-- carry.
bitAdder :: (Signal, [Signal]) -> ([Signal], Signal)
bitAdder (cin, [])   = ([], cin)
bitAdder (cin, x:xs) =
  let (bitSum, carry)       = halfadd (cin,x)
      (restSums, carryOut)  = bitAdder (carry,xs)
  in (bitSum : restSums, carryOut)

-- | Probe the 4-bit 'bitAdder' with carry-in 1 and inputs 1,1,0,1.
test1 :: IO ()
test1 =
  let cin = high
      inputs@[in1,in2,in3,in4] = [high, high, low, high]
      ([s1,s2,s3,s4], c) = bitAdder (cin, inputs)
  in probe
       [ ("cin",cin)
       , ("in1",in1), ("in2",in2), ("in3",in3), ("in4",in4)
       , ("  s1",s1), ("s2",s2), ("s3",s3), ("s4",s4)
       , ("c",c)
       ]

-- | At the first sample, 'bitAdder' computes the numeric value of its input
-- bits plus the carry-in (with the carry-out as the most significant bit).
prop_bitAdder_Correct :: Signal -> [Bool] -> Bool
prop_bitAdder_Correct cin xs = actual == expected
  where
    (out,cout) = bitAdder (cin, map lift0 xs)
    actual, expected :: Integer
    actual   = binary (sampleN out ++ [sample cout])
    expected = binary xs + (if sample cin then 1 else 0)

-- main = qc prop_bitAdder_Correct

-- | Ripple-carry adder for two numbers given least significant bit first.
-- The shorter operand is padded with 'low' bits; the result is the sum bits
-- followed by the final carry.
adder :: ([Signal], [Signal]) -> [Signal]
adder (xs, ys) = sums ++ [cout]
  where
    (sums,cout) = ripple (low, xs, ys)

    ripple (cin, [], [])         = ([], cin)
    ripple (cin, a:restA, b:restB) =
      let (bitSum, carry)      = fulladd (cin,a,b)
          (restSums, carryOut) = ripple (carry,restA,restB)
      in (bitSum : restSums, carryOut)
    -- One operand ran out: pad it with a single low bit and continue.
    ripple (cin, [], restB)      = ripple (cin, [low], restB)
    ripple (cin, restA, [])      = ripple (cin, restA, [low])

-- | Probe the 4-bit 'adder' with x = 1,1,0,0 and y = 1,0,0,0 (LSB first).
test2 :: IO ()
test2 =
  probe
    [ ("x1",x1), ("x2",x2), ("x3",x3), ("x4",x4)
    , (" y1",y1), ("y2",y2), ("y3",y3), ("y4",y4)
    , (" s1",s1), ("s2",s2), ("s3",s3), ("s4",s4), (" c",c)
    ]
  where
    xs@[x1,x2,x3,x4] = [high,high,low,low]
    ys@[y1,y2,y3,y4] = [high,low,low,low]
    [s1,s2,s3,s4,c]  = adder (xs, ys)

-- | At the first sample, 'adder' computes the numeric sum of its operands.
prop_Adder_Correct :: [Bool] -> [Bool] -> Bool
prop_Adder_Correct l1 l2 = binary (sampleN total) == binary l1 + binary l2
  where total = adder (map lift0 l1, map lift0 l2)

-- main = qc prop_Adder_Correct

------------------------
-- First circuit with memory: A simple toggle 

-- | Toggle circuit: the output flips whenever @change@ is True and holds its
-- value otherwise.  The output is fed back through a one-sample delay that
-- starts at False.
toggle :: (Signal) -> (Signal)
toggle change = out
  where out = xor2 (change, delay False out)

-- | Simulate the toggle circuit on the change pattern 1011100.
test3a :: IO ()
test3a =
  simulate
    [ ("change",change)
    , ("out",out)
    ]
  where
    change = str "1011100"
    out    = toggle change

-- | Reference specification for 'toggle': sample n of the output is True
-- exactly when an odd number of the first n+1 samples of @change@ are True.
toggleSpec :: (Signal) -> (Signal)
toggleSpec (Sig change) = Sig [oddTogglesUpTo n | n <- [0 ..]]
  where
    oddTogglesUpTo n = odd (length (filter id (take (n+1) change)))

-- | The toggle circuit agrees with its specification.
prop_Toggle_Correct :: Signal -> Bool
prop_Toggle_Correct change = agreesWith (toggleSpec change) (toggle change)

-- main = qc prop_Toggle_Correct

------------------------
-- A counter

-- | Delay each signal of a bus by one sample, using the matching element of
-- the first list as that signal's initial value.  The result has one signal
-- per initial value.
--
-- The second argument is matched with a lazy pattern (~), so the shape of
-- the result is determined by the list of initial values alone, without
-- inspecting the bus.  The feedback loops in 'test3e' and 'counter' depend
-- on this: there the bus is itself computed from the result of 'delayN'.
delayN []     xs      = []
delayN (i:is) ~(x:xs) = (delay i x) : (delayN is xs)

-- | Simulate a 4-bit counter: the delayed output is fed back into a
-- 'bitAdder' whose carry-in is always high.
test3e :: IO ()
test3e =
  simulate
    [ ("cin",cin)
    , (" out1",out1)
    , ("out2",out2)
    , ("out3",out3)
    , ("out4",out4)
    , (" cout",cout)
    ]
  where
    cin                   = high
    out                   = delayN [False,False,False,False] out'
    (out',cout)           = bitAdder (cin, out)
    [out1,out2,out3,out4] = out'

-- | An @n@-bit counter: the stored value (initially all False) is fed
-- through a 'bitAdder' with carry-in high; the adder's sum bits are
-- returned and its carry-out is discarded.
counter :: Int -> [Signal]
counter n = next
  where
    current   = delayN (replicate n False) next
    (next, _) = bitAdder (high, current)

-- | Simulate a 3-bit counter.
test4 :: IO ()
test4 =
  simulate
    [ ("out1",out1)
    , ("out2",out2)
    , ("out3",out3)
    ]
  where
    [out1,out2,out3] = counter 3

-------------------
-- Various memory circuits 

-- | A 1-bit register.  The stored value starts at False; where @set@ is
-- True the register takes @input@ for the next sample, otherwise it keeps
-- its current value.
reg1 :: Signal -> Signal -> Signal
reg1 set input = output
  where output = delay False (mux (input, output, set))

-- | Simulate the 1-bit register.
test5 :: IO ()
test5 =
  simulate
    [ ("input",input)
    , ("set",set)
    , ("   out",out)
    ]
  where
    input = str "1010001"
    set   = str "0011011"
    out   = reg1 set input

-- | An n-bit register: one 'reg1' per input bit, all sharing the same
-- @set@ signal.
reg :: Signal -> [Signal] -> [Signal]
reg set inputs = map (reg1 set) inputs

-- | Simulate a 2-bit register.
test6 :: IO ()
test6 =
  simulate
    [ ("i1",i1)
    , ("i2",i2)
    , ("set",set)
    , ("   out1",out1)
    , ("out2",out2)
    ]
  where
    i1  = str "1000001"
    i2  = str "1110101"
    set = str "0011111"
    [out1,out2] = reg set [i1,i2]

-- | A 1-bit, a-wide memory: a binary tree of 1-bit registers selected by
-- the address signals.
--
-- Arguments: the list of address bits, the data input @i@, and the
-- write-enable @set@.  With no address bits the memory is a single 'reg1'.
-- Otherwise the first address bit @a@ selects between two sub-memories:
-- the output is read from the first sub-memory where @a@ is True and from
-- the second where it is False.
memory1 :: [Signal] -> Signal -> Signal -> Signal
memory1 [] i set     = reg1 set i
memory1 (a:as) i set = 
  mux (out0, out1, a)
  where out0 = memory1 as i set0
        out1 = memory1 as i set1
        -- Route the write-enable (not the data input) to the addressed
        -- half, so only that half can be written; this matches 'memory'.
        (set0,set1) = demux (set, a)

-- | Simulate a 1-bit memory with two address bits.
test7 :: IO ()
test7 =
  simulate
    [ ("i1",i1)
    , ("a1",a1)
    , ("a2",a2)
    , ("set",set)
    , ("out",out)
    ]
  where
    a1  = str "111"
    a2  = str "101"
    i1  = str "1000001"
    set = str "0011111"
    out = memory1 [a1,a2] i1 set

-- | A multi-bit, a-wide memory.  Takes @(address bits, input bits, set)@.
-- With no address bits it is a plain n-bit register; otherwise the first
-- address bit routes the set signal to one of two sub-memories and selects
-- which of their outputs is read.
memory :: ([Signal], [Signal], Signal) -> [Signal]
memory ([], is, set)   = reg set is
memory (a:as, is, set) = muxN (bank0, bank1, a)
  where
    (set0,set1) = demux (set, a)
    bank0       = memory (as, is, set0)
    bank1       = memory (as, is, set1)

-- | Simulate a 1-bit-wide memory with one address bit.
test8 :: IO ()
test8 =
  simulate
    [ ("i",i)
    , ("a",a)
    , ("set",set)
    , ("out",out)
    ]
  where
    a     = str "10111111"
    i     = str "11100"
    set   = str "00100"
    [out] = memory ([a], [i], set)

-- | Simulate a 1-bit-wide memory with two address bits that carry the same
-- signal.
test9 :: IO ()
test9 =
  simulate
    [ ("i1",i1)
    , ("a1",a1)
    , ("a2",a2)
    , ("set",set)
    , ("out",out)
    ]
  where
    a1    = str "10111"
    a2    = a1
    i1    = str "11100"
    set   = str "00100"
    [out] = memory ([a1,a2], [i1], set)

