urbit/pkg/king/lib/Noun/Jam.hs
2019-08-21 16:49:08 -07:00

363 lines
9.4 KiB
Haskell

{-# OPTIONS_GHC -O2 #-}
module Noun.Jam (jam, jamBS) where
import ClassyPrelude hiding (hash)
import Noun.Atom
import Noun.Core
import Control.Lens (from, view)
import Data.Bits (clearBit, setBit, shiftL, shiftR, (.|.))
import Data.Vector.Primitive ((!))
import Foreign.Marshal.Alloc (callocBytes, free)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (poke)
import GHC.Int (Int(I#))
import GHC.Integer.GMP.Internals (BigNat)
import GHC.Natural (Natural(NatJ#, NatS#))
import GHC.Prim (Word#, plusWord#, word2Int#)
import GHC.Word (Word(W#))
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString.Unsafe as BS
import qualified Data.HashTable.IO as H
import qualified Data.Vector.Primitive as VP
-- Exports ---------------------------------------------------------------------
-- | Serialize a noun to its jammed 'ByteString'.
--
-- 'compress' runs first and yields the size of the encoding in bits together
-- with the back-reference table; 'doPut' then runs 'writeNoun' against a
-- buffer of that size.
jamBS :: Noun -> ByteString
jamBS noun =
  let (bitCount, backRefs) = unsafePerformIO (compress noun)
  in  doPut backRefs bitCount (writeNoun noun)
-- | Serialize a noun to an 'Atom': the bytes produced by 'jamBS' are viewed
-- through 'atomBytes' in the reverse direction.
jam :: Noun -> Atom
jam noun = view (from atomBytes) (jamBS noun)
-- Types -----------------------------------------------------------------------
{-|
    The encoder state.

    - ptr: Pointer into the output buffer. `flush` stores `reg` here and
           advances the pointer by 8 bytes.
    - reg: Next 64 bits of output, partially written.
    - off: Number of bits already written into `reg`. The writers keep it
           reduced modulo 64.
    - pos: Total number of bits written. `getRef` uses it as the key into
           the back-reference table.

    Every field is strict and unpacked.
-}
data S = S
  { ptr :: {-# UNPACK #-} !(Ptr Word)
  , reg :: {-# UNPACK #-} !Word
  , off :: {-# UNPACK #-} !Int
  , pos :: {-# UNPACK #-} !Word
  } deriving (Show,Eq,Ord)
-- | Outcome of running one 'Put' action: the encoder state after the action,
-- paired with the value the action produced. Both fields are strict, and the
-- state is unpacked into the constructor.
data PutResult a = PutResult {-# UNPACK #-} !S !a
  deriving Functor
-- | An encoder action. It receives the back-reference table (keyed by bit
-- position, see 'getRef') and the current encoder state, runs in 'IO', and
-- returns the updated state together with a result.
newtype Put a = Put
  { runPut :: H.CuckooHashTable Word Word
           -> S
           -> IO (PutResult a)
  }
--------------------------------------------------------------------------------
-- | Look up the current bit position in the back-reference table. The
-- encoder state is returned unchanged.
{-# INLINE getRef #-}
getRef :: Put (Maybe Word)
getRef = Put $ \backRefs st -> do
  found <- H.lookup backRefs (pos st)
  pure (PutResult st found)
{-
    Store the register at the output pointer, then advance the pointer by
    8 bytes. The register, bit offset and bit position are left as they are;
    callers reset them as needed.
-}
{-# INLINE flush #-}
flush :: Put ()
flush = Put $ \_ st@S{} -> do
  poke (ptr st) (reg st)
  pure (PutResult (st { ptr = ptr st `plusPtr` 8 }) ())
-- | Replace the encoder state with the result of applying a pure function
-- to it. Performs no I/O and ignores the back-reference table.
{-# INLINE update #-}
update :: (S -> S) -> Put ()
update step = Put $ \_ st@S{} -> pure (PutResult (step st) ())
-- | Overwrite both the register and the bit offset, leaving the pointer and
-- the bit position alone.
{-# INLINE setRegOff #-}
setRegOff :: Word -> Int -> Put ()
setRegOff newReg newOff = update install
  where
    install st = st { reg = newReg, off = newOff }
-- | Overwrite the register only.
{-# INLINE setReg #-}
setReg :: Word -> Put ()
setReg newReg = update install
  where
    install st = st { reg = newReg }
-- | Return the current encoder state without changing it.
{-# INLINE getS #-}
getS :: Put S
getS = Put (const snapshot)
  where
    snapshot st = pure (PutResult st st)
-- | Discard the current encoder state and install the given one.
{-# INLINE putS #-}
putS :: S -> Put ()
putS replacement = Put (\_ _ -> pure (PutResult replacement ()))
{-
    Write a single bit.

    The bit at index `off` of the register is set or cleared, `off` advances
    by one modulo 64 and `pos` by one. When the bit just written was bit 63,
    the register is full: it is flushed to the buffer and then both the
    register and the offset are reset to zero.
-}
{-# INLINE writeBit #-}
writeBit :: Bool -> Put ()
writeBit bit = Put $ \tbl st ->
  let cur   = off st
      tweak = if bit then setBit else clearBit
      st'   = st { reg = tweak (reg st) cur
                 , off = (cur + 1) `mod` 64
                 , pos = pos st + 1
                 }
  in  if cur == 63
        then runPut (flush >> setRegOff 0 0) tbl st'
        else pure (PutResult st' ())
{-
    Write a full 64-bit word.

    The word is OR-ed into the register shifted left by the current offset,
    and the register is flushed. The bits shifted out (the word shifted
    right by `64 - off`) become the new register. The offset is unchanged
    and `pos` grows by 64.
-}
{-# INLINE writeWord #-}
writeWord :: Word -> Put ()
writeWord wor = do
  before <- getS
  let shift = off before
  setReg (reg before .|. shiftL wor shift)
  flush
  update $ \st -> st { pos = 64 + pos before
                     , reg = shiftR wor (64 - shift)
                     }
{-
    Write `wid` bits (fewer than 64) taken from a word:

    | bits   = takeBitsWord(wid, wor)
    | reg    = reg | (bits << off)
    | newOff = (off + wid) % 64
    | pos    = pos + wid
    |
    | if (off + wid >= 64):        -- register filled up
    |   flush
    |   reg = bits >> (wid - newOff)

    In the overflow case `wid - newOff` equals `64 - off`, i.e. the number
    of bits that fit into the old register, so the new register holds the
    part that did not fit.
-}
{-# INLINE writeBitsFromWord #-}
writeBitsFromWord :: Int -> Word -> Put ()
writeBitsFromWord wid raw = do
  before <- getS
  let bits  = takeBitsWord wid raw
      start = off before
      end   = (start + wid) `mod` 64
  putS (before { reg = reg before .|. shiftL bits start
               , off = end
               , pos = fromIntegral wid + pos before
               })
  when (wid + start >= 64) $ do
    flush
    setReg (shiftR bits (wid - end))
{-
    Write all of the significant bits of a direct atom: the bit width of the
    word is computed with `wordBitWidth#` and that many bits are written.
-}
{-# INLINE writeAtomWord# #-}
writeAtomWord# :: Word# -> Put ()
writeAtomWord# w = writeBitsFromWord width (W# w)
  where
    width = I# (word2Int# (wordBitWidth# w))
-- | Boxed wrapper around 'writeAtomWord#'.
{-# INLINE writeAtomWord #-}
writeAtomWord :: Word -> Put ()
writeAtomWord wor = case wor of
  W# raw -> writeAtomWord# raw
{-
    Write all of the significant bits of an indirect atom.

    Every word except the last is written in full with `writeWord`; for the
    last word only its significant bits are written.

    TODO Use memcpy when the bit-offset of the output is divisible by 8.
-}
{-# INLINE writeAtomBigNat #-}
writeAtomBigNat :: BigNat -> Put ()
writeAtomBigNat !(view bigNatWords -> ws) = do
  for_ [0 .. final - 1] (writeWord . (ws !))
  writeAtomWord (ws ! final)
  where
    final = VP.length ws - 1
-- | Write the significant bits of an atom, dispatching on the 'Natural'
-- representation: a single machine word or a 'BigNat'.
{-# INLINE writeAtomBits #-}
writeAtomBits :: Atom -> Put ()
writeAtomBits (NatS# small) = writeAtomWord# small
writeAtomBits (NatJ# big)   = writeAtomBigNat big
-- Put Instances ---------------------------------------------------------------
-- | Map over the result of an encoder action; the state threading is
-- untouched.
instance Functor Put where
  fmap f (Put run) = Put $ \tbl st ->
    run tbl st >>= \(PutResult st' a) -> pure (PutResult st' (f a))
  {-# INLINE fmap #-}
-- | Actions are sequenced left to right, each one receiving the state left
-- behind by the previous one and the same back-reference table.
instance Applicative Put where
  pure x = Put $ \_ st -> pure (PutResult st x)
  {-# INLINE pure #-}

  Put runF <*> Put runX = Put $ \tbl st0 ->
    runF tbl st0 >>= \(PutResult st1 f) ->
      runX tbl st1 >>= \(PutResult st2 x) ->
        pure (PutResult st2 (f x))
  {-# INLINE (<*>) #-}

  Put runA *> Put runB = Put $ \tbl st0 ->
    runA tbl st0 >>= \(PutResult st1 _) -> runB tbl st1
  {-# INLINE (*>) #-}
-- | Bind runs the first action, then feeds its result and resulting state
-- to the continuation.
instance Monad Put where
  return = pure
  {-# INLINE return #-}

  (>>) = (*>)
  {-# INLINE (>>) #-}

  Put run >>= k = Put $ \tbl st ->
    run tbl st >>= \(PutResult st' a) -> runPut (k a) tbl st'
  {-# INLINE (>>=) #-}
--------------------------------------------------------------------------------
{-
    Run an encoder action against a freshly allocated, zero-filled buffer and
    package the bytes it wrote as a `ByteString`.

    - backRefs: back-reference table handed to the action.
    - bitCount: size of the output in bits. The buffer holds `bitCount`
                rounded up to whole 64-bit words; the resulting string is
                `bitCount` rounded up to whole bytes.
    - action:   the encoder to run. One final `flush` is appended when the
                register still holds bits (`off /= 0`).

    The buffer is released by the finalizer attached to the `ByteString`.

    NOTE(review): nothing guards the allocation, so the buffer is not freed
    if the action throws before the `ByteString` is built.
-}
doPut :: H.CuckooHashTable Word Word -> Word -> Put () -> ByteString
doPut !backRefs !bitCount action = unsafePerformIO $ do
  buf <- callocBytes (fromIntegral (wordSz * 8))
  _   <- runPut (action >> finish) backRefs (S buf 0 0 0)
  BS.unsafePackCStringFinalizer (castPtr buf) byteSz (free buf)
  where
    !wordSz = fromIntegral (roundUpDiv bitCount 64)
    !byteSz = fromIntegral (roundUpDiv bitCount 8)

    -- Integer division rounding towards positive infinity.
    roundUpDiv :: Word -> Word -> Word
    roundUpDiv x y = case x `divMod` y of
      (q, 0) -> q
      (q, _) -> q + 1

    -- Flush the register only if it holds bits that were not written out.
    finish :: Put ()
    finish = getS >>= \st -> when (off st /= 0) flush
--------------------------------------------------------------------------------
{-
    Write a noun at the current output position.

    If the back-reference table has an entry for the current bit position,
    a back reference to that entry is written instead of the noun itself.
    Otherwise the noun is written as an atom or as a cell.
-}
writeNoun :: Noun -> Put ()
writeNoun !noun = getRef >>= maybe fresh writeBackRef
  where
    fresh = case noun of
      Atom a   -> writeAtom a
      Cell h t -> writeCell h t
{-
    Write the length-prefixed encoding of an atom.

    Zero is a single 1 bit. Otherwise, with `atmWid` the bit width of the
    atom and `preWid` the bit width of `atmWid`:

    | write preWid+1 bits of (1 << preWid)
    | write preWid-1 bits of atmWid
    | write the significant bits of the atom
-}
{-# INLINE writeMat #-}
writeMat :: Atom -> Put ()
writeMat atm
  | atm == 0  = writeBit True
  | otherwise = do
      let atmWid = bitWidth atm
          preWid = fromIntegral (wordBitWidth atmWid)
      writeBitsFromWord (preWid + 1) (shiftL 1 preWid)
      writeBitsFromWord (preWid - 1) atmWid
      writeAtomBits atm
-- | Write a cell: the tag bits 1 then 0, followed by the head and then the
-- tail.
{-# INLINE writeCell #-}
writeCell :: Noun -> Noun -> Put ()
writeCell !hed !tal =
  writeBit True *> writeBit False *> writeNoun hed *> writeNoun tal
-- | Write an atom: a single 0 tag bit followed by its 'writeMat' encoding.
{-# INLINE writeAtom #-}
writeAtom :: Atom -> Put ()
writeAtom !atm = writeBit False >> writeMat atm
-- | Write a back reference: the tag bits 1 then 1, followed by the
-- 'writeMat' encoding of the referenced value converted to an 'Atom'.
--
-- The argument is the value stored in the back-reference table for the
-- current position (see 'writeNoun').
{-# INLINE writeBackRef #-}
writeBackRef :: Word -> Put ()
writeBackRef !a = do
  writeBit True
  writeBit True
  writeMat (fromIntegral a)
-- Calculate Jam Size and Backrefs ---------------------------------------------
-- | Boxed wrapper around 'matSz#'.
{-# INLINE matSz #-}
matSz :: Atom -> Word
matSz !atm = W# bitCount
  where
    bitCount = matSz# atm
-- | Number of bits 'writeMat' emits for an atom: 1 for zero, otherwise
-- twice the bit width of the atom's bit width, plus the atom's bit width.
{-# INLINE matSz# #-}
matSz# :: Atom -> Word#
matSz# atm
  | atm == 0  = 1##
  | otherwise =
      let atmW = atomBitWidth# atm
          preW = wordBitWidth# atmW
      in  plusWord# (plusWord# preW preW) atmW
-- | Size in bits of an atom as written by 'writeAtom': one tag bit plus
-- 'matSz'.
{-# INLINE atomSz #-}
atomSz :: Atom -> Word
atomSz !atm = matSz atm + 1
-- | Size in bits of a back reference to the given position: one bit more
-- than 'jamWordSz' of that position.
{-# INLINE refSz #-}
refSz :: Word -> Word
refSz !target = jamWordSz target + 1
-- | Size in bits of a machine word written as an atom: 2 for zero,
-- otherwise one bit, plus twice the bit width of the word's bit width, plus
-- the word's bit width.
{-# INLINE jamWordSz #-}
jamWordSz :: Word -> Word
jamWordSz wor
  | wor == 0  = 2
  | otherwise = case wor of
      W# raw ->
        let atmW = wordBitWidth# raw
            preW = wordBitWidth# atmW
        in  1 + 2 * W# preW + W# atmW
{-
    Walk a noun once, computing the size of its encoding in bits and
    recording where back references are to be written.

    Returns the total size together with a table mapping the bit position of
    each back reference to the bit position at which the referenced noun was
    first seen.

    Both hash tables start out sized at `2 * 10^floor(log_600(nounSize top))`,
    clamped to the range 50 .. 10_000_000.
-}
compress :: Noun -> IO (Word, H.CuckooHashTable Word Word)
compress !top = do
  let magnitude = floor (logBase 600 (fromIntegral (nounSize top)))
      sz        = max 50 (min 10_000_000 (2 * 10 ^ magnitude))

  nodes :: H.BasicHashTable Noun Word <- H.newSized sz
  backs :: H.CuckooHashTable Word Word <- H.newSized sz

  let -- Size of a noun written out in full, starting at bit position `at`.
      -- A cell costs two tag bits plus its head and tail.
      measure :: Word -> Noun -> IO Word
      measure !at = \case
        Atom a   -> pure (atomSz a)
        Cell h t -> do
          !hSz <- visit (at + 2) h
          !tSz <- visit (at + 2 + hSz) t
          pure (2 + hSz + tSz)

      -- Size of a noun at bit position `at`, using a back reference when
      -- the noun was seen before and the reference pays off: always for a
      -- cell, and for an atom only when the reference is strictly smaller
      -- than the atom itself.
      visit :: Word -> Noun -> IO Word
      visit !at !noun =
        H.lookup nodes noun >>= \case
          Nothing -> do
            H.insert nodes noun at
            measure at noun
          Just seen ->
            let refBits = refSz seen
                worthIt = case noun of
                  Cell _ _ -> True
                  Atom a | refBits < atomSz (fromIntegral a) -> True
                  _ -> False
            in  if worthIt
                  then H.insert backs at seen $> refBits
                  else measure at noun

  total <- visit 0 top
  pure (total, backs)