shrub/pkg/hs-urbit/lib/Noun/Jam/Fast.hs
Benjamin Summers 0d057747cc Jam/Cue: Tuning
2019-07-04 15:40:36 -07:00

388 lines
10 KiB
Haskell

{-# LANGUAGE MagicHash #-}
{-# OPTIONS_GHC -fwarn-unused-binds -fwarn-unused-imports #-}
module Noun.Jam.Fast (jam, jamBS, jamFat, jamFatBS) where
import ClassyPrelude hiding (hash)
import Control.Lens (view, to, from)
import Data.Bits (shiftL, shiftR, setBit, clearBit, xor, (.|.))
import Noun.Atom (Atom(MkAtom), toAtom, bitWidth, takeBitsWord)
import Noun.Atom (wordBitWidth, wordBitWidth# , atomBitWidth#)
import Noun (Noun(Atom, Cell))
import Noun.Fat
import Noun.Pill (bigNatWords, atomBS)
import Data.Vector.Primitive ((!))
import Foreign.Marshal.Alloc (callocBytes, free)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (poke)
import GHC.Integer.GMP.Internals (BigNat)
import GHC.Int (Int(I#))
import GHC.Natural (Natural(NatS#, NatJ#))
import GHC.Prim (Word#, plusWord#, word2Int#)
import GHC.Word (Word(W#))
import System.IO.Unsafe (unsafePerformIO)
import qualified Control.Exception as Exception
import qualified Data.ByteString.Unsafe as BS
import qualified Data.Hashable as Hash
import qualified Data.HashTable.IO as H
import qualified Data.Vector.Primitive as VP
-- Exports ---------------------------------------------------------------------
{-
    Serialize a `FatNoun` to a `ByteString`.

    `compress` is run first to obtain the total size of the encoding (in
    bits) and the back-reference table; `doPut` then runs `writeNoun`
    against a buffer sized from that bit count.
-}
jamFatBS :: FatNoun -> ByteString
jamFatBS noun =
    let (bitCount, backRefs) = unsafePerformIO (compress noun)
    in  doPut backRefs bitCount (writeNoun noun)
{-
    Serialize a `FatNoun` to an `Atom`: the bytes produced by `jamFatBS`,
    viewed through the `atomBS` iso in reverse.
-}
jamFat :: FatNoun -> Atom
jamFat fat = view (from atomBS) (jamFatBS fat)
{-
    Serialize a `Noun` to a `ByteString` by converting it with
    `toFatNoun` and handing the result to `jamFatBS`.
-}
jamBS :: Noun -> ByteString
jamBS noun = jamFatBS (toFatNoun noun)
{-
    Serialize a `Noun` to an `Atom` by converting it with `toFatNoun`
    and handing the result to `jamFat`.
-}
jam :: Noun -> Atom
jam noun = jamFat (toFatNoun noun)
-- Types -----------------------------------------------------------------------
{-|
The encoder state threaded through every `Put` action.
- ptr: Pointer to the word of the output buffer that `flush` writes next.
- reg: The output word currently being assembled (64 bits, partially written).
- off: Number of bits already written into `reg`.
- pos: Total number of bits written so far.
-}
data S = S
{ ptr :: {-# UNPACK #-} !(Ptr Word)
, reg :: {-# UNPACK #-} !Word
, off :: {-# UNPACK #-} !Int
, pos :: {-# UNPACK #-} !Word
} deriving (Show,Eq,Ord)
-- | The outcome of running a `Put` action: the updated encoder state
-- together with the action's result. Both fields are strict.
data PutResult a = PutResult {-# UNPACK #-} !S !a
deriving Functor
-- | An encoder action. It is run with the back-reference table (consulted
-- by `getRef`, keyed on the current bit position) and the current encoder
-- state, and produces the new state and a result in `IO`.
newtype Put a = Put
{ runPut :: H.CuckooHashTable Word Word
-> S
-> IO (PutResult a)
}
--------------------------------------------------------------------------------
{-
    Look up the current bit position (`pos`) in the back-reference
    table. Yields the table's entry for that position, if there is one;
    the encoder state is left unchanged.
-}
{-# INLINE getRef #-}
getRef :: Put (Maybe Word)
getRef = Put $ \backRefs st -> do
    hit <- H.lookup backRefs (pos st)
    pure (PutResult st hit)
{-
    Write the register to the output buffer and advance the output
    pointer by 8 bytes (one word). `reg`, `off` and `pos` are left as
    they are; resetting the register is up to the caller.
-}
{-# INLINE flush #-}
flush :: Put ()
flush = Put $ \_ st -> do
    poke (ptr st) (reg st)
    pure (PutResult (st { ptr = ptr st `plusPtr` 8 }) ())
{-
    Apply a pure transformation to the encoder state.
-}
{-# INLINE update #-}
update :: (S -> S) -> Put ()
update step = Put $ \_ !st -> pure (PutResult (step st) ())
{-
    Overwrite both the register and the bit offset within it.
-}
{-# INLINE setRegOff #-}
setRegOff :: Word -> Int -> Put ()
setRegOff newReg newOff =
    update (\st -> st { reg = newReg, off = newOff })
{-
    Overwrite the register, leaving the rest of the state alone.
-}
{-# INLINE setReg #-}
setReg :: Word -> Put ()
setReg newReg = update (\st -> st { reg = newReg })
{-
    Read the current encoder state.
-}
{-# INLINE getS #-}
getS :: Put S
getS = Put (\_ st -> return (PutResult st st))
{-
    Replace the encoder state wholesale; the previous state is ignored.
-}
{-# INLINE putS #-}
putS :: S -> Put ()
putS st = Put (\_ _ -> return (PutResult st ()))
{-
    Append a single bit to the output:

    | reg[off] <- bit
    | off      <- (off + 1) % 64
    | pos      <- pos + 1
    | if the register just filled up (the old off was 63):
    |     flush
    |     reg <- 0
    |     off <- 0
-}
{-# INLINE writeBit #-}
writeBit :: Bool -> Put ()
writeBit bit = Put $ \tbl st ->
    let marked
            | bit       = setBit   (reg st) (off st)
            | otherwise = clearBit (reg st) (off st)
        next = st { reg = marked
                  , off = (off st + 1) `mod` 64
                  , pos = pos st + 1
                  }
    in  if off st == 63
            then runPut (flush >> setRegOff 0 0) tbl next
            else pure (PutResult next ())
{-
    Append a full 64-bit word:

    | reg <- reg | (w << off)     (low bits of `w` fill out the register)
    | flush
    | reg <- w >> (64 - off)      (the remaining high bits of `w`)
    | pos <- pos + 64

    `off` is not changed.
-}
{-# INLINE writeWord #-}
writeWord :: Word -> Put ()
writeWord wor = do
    before <- getS
    let shift = off before
    setReg (reg before .|. (wor `shiftL` shift))
    flush
    update $ \st -> st { reg = wor `shiftR` (64 - shift)
                       , pos = 64 + pos before
                       }
{-
    Append `wid` bits taken from a word. The bits written are
    `takeBitsWord wid wor` (presumably the low `wid` bits of `wor` --
    confirm against Noun.Atom).

    | bits <- takeBits(wid, wor)
    | reg  <- reg | (bits << off)
    | off' <- (off + wid) % 64
    | pos  <- pos + wid
    |
    | if (off + wid >= 64):                (the register filled up)
    |     flush
    |     reg <- bits >> (wid - off')      (the bits that did not fit)
-}
{-# INLINE writeBitsFromWord #-}
writeBitsFromWord :: Int -> Word -> Put ()
writeBitsFromWord wid wor = do
    st <- getS
    let bits  = takeBitsWord wid wor
        used  = off st
        total = used + wid
        st'   = st { reg = reg st .|. (bits `shiftL` used)
                   , off = total `mod` 64
                   , pos = fromIntegral wid + pos st
                   }
    putS st'
    when (total >= 64) $ do
        flush
        setReg (bits `shiftR` (wid - off st'))
{-
    Write all of the significant bits of a direct atom. The number of
    bits written is given by `wordBitWidth#`.
-}
{-# INLINE writeAtomWord# #-}
writeAtomWord# :: Word# -> Put ()
writeAtomWord# w = writeBitsFromWord width (W# w)
  where
    width = I# (word2Int# (wordBitWidth# w))
{-
    Boxed wrapper around `writeAtomWord#`.
-}
{-# INLINE writeAtomWord #-}
writeAtomWord :: Word -> Put ()
writeAtomWord wor = case wor of
    W# w -> writeAtomWord# w
{-
    Write all of the significant bits of an indirect atom: every word
    but the last is written in full with `writeWord`; the last word is
    written with `writeAtomWord`, i.e. only its significant bits.

    TODO Use memcpy when the bit-offset of the output is divisible by 8.
-}
{-# INLINE writeAtomBigNat #-}
writeAtomBigNat :: BigNat -> Put ()
writeAtomBigNat big = do
    for_ fullIdxs (\ix -> writeWord (limbs ! ix))
    writeAtomWord (limbs ! topIdx)
  where
    limbs    = view bigNatWords big
    topIdx   = VP.length limbs - 1
    fullIdxs = [0 .. topIdx - 1]
{-
    Write the significant bits of an atom, dispatching on whether it is
    stored as a single machine word (`NatS#`) or as a `BigNat` (`NatJ#`).
-}
{-# INLINE writeAtomBits #-}
writeAtomBits :: Atom -> Put ()
writeAtomBits (MkAtom (NatS# small)) = writeAtomWord# small
writeAtomBits (MkAtom (NatJ# big))   = writeAtomBigNat big
-- Put Instances ---------------------------------------------------------------
-- Map over the result of an encoder action; the state threading is
-- untouched.
instance Functor Put where
    fmap f (Put run) = Put $ \tbl st ->
        run tbl st >>= \(PutResult st' x) ->
        pure (PutResult st' (f x))
    {-# INLINE fmap #-}
-- Sequencing of encoder actions: the state produced by the left action
-- is fed to the right one; the back-reference table is shared.
instance Applicative Put where
    pure x = Put $ \_ st -> pure (PutResult st x)
    {-# INLINE pure #-}

    pf <*> px = Put $ \tbl st0 ->
        runPut pf tbl st0 >>= \(PutResult st1 fn) ->
        runPut px tbl st1 >>= \(PutResult st2 arg) ->
        pure (PutResult st2 (fn arg))
    {-# INLINE (<*>) #-}

    pa *> pb = Put $ \tbl st0 ->
        runPut pa tbl st0 >>= \(PutResult st1 _) ->
        runPut pb tbl st1
    {-# INLINE (*>) #-}
-- Monadic bind for encoder actions: run the first action, then run the
-- continuation on its result with the updated state.
instance Monad Put where
    return x = pure x
    {-# INLINE return #-}

    ma >> mb = ma *> mb
    {-# INLINE (>>) #-}

    ma >>= k = Put $ \tbl st ->
        runPut ma tbl st >>= \(PutResult st' x) ->
        runPut (k x) tbl st'
    {-# INLINE (>>=) #-}
--------------------------------------------------------------------------------
{-
    Run an encoder against a freshly allocated, zero-filled buffer and
    package the result as a `ByteString`.

    - tbl: Back-reference table handed to every `Put` action.
    - sz:  Total size of the encoding, in bits. The buffer is allocated
           in whole 8-byte words (rounded up); the resulting `ByteString`
           covers `sz` bits rounded up to whole bytes.
    - m:   The encoder to run. It is followed by `mbFlush`, which writes
           out a partially filled register.

    Ownership of the buffer passes to the returned `ByteString`, whose
    finalizer frees it. If anything throws before that `ByteString`
    exists, `bracketOnError` frees the buffer so that it is not leaked.

    NOTE(review): `unsafePerformIO` is presumably justified because the
    only writes go to the private buffer -- this assumes `tbl` is not
    mutated while the encoder runs; confirm against callers.
-}
doPut :: H.CuckooHashTable Word Word -> Word -> Put () -> ByteString
doPut tbl sz m =
    unsafePerformIO $
        Exception.bracketOnError (callocBytes allocBytes) free $ \buf -> do
            _ <- runPut (m >> mbFlush) tbl (S buf 0 0 0)
            BS.unsafePackCStringFinalizer (castPtr buf) byteSz (free buf)
  where
    -- Size of the allocation: whole 64-bit words, expressed in bytes.
    allocBytes :: Int
    allocBytes = 8 * fromIntegral (sz `divUp` 64)

    -- Length of the resulting ByteString: whole bytes.
    byteSz :: Int
    byteSz = fromIntegral (sz `divUp` 8)

    -- Integer division, rounding up.
    divUp :: Word -> Word -> Word
    divUp x y = (x `div` y) + (if x `mod` y == 0 then 0 else 1)

    -- Write out the last register if any bits are pending in it.
    mbFlush :: Put ()
    mbFlush = do
        shouldFlush <- (/= 0) . off <$> getS
        when shouldFlush flush
--------------------------------------------------------------------------------
{-
    Write a noun at the current output position.

    If the back-reference table has an entry for the current bit
    position, a back-reference to that entry is written instead of the
    noun itself. Otherwise the noun is written in full: atoms through
    `writeAtom`, cells through `writeCell`.
-}
writeNoun :: FatNoun -> Put ()
writeNoun noun = do
    backRef <- getRef
    case backRef of
        Just target -> writeBackRef target
        Nothing     -> writeFull
  where
    writeFull = case noun of
        FatAtom _ big   -> writeAtom (MkAtom (NatJ# big))
        FatWord (W# w)  -> writeAtom (MkAtom (NatS# w))
        FatCell _ _ h t -> writeCell h t
{-
    Write the length-prefixed encoding of an atom.

    Zero is written as a single 1 bit. For any other atom, with `width`
    the bit width of the atom and `lenBits` the bit width of `width`:

    | lenBits zero bits followed by a single 1 bit
    | the low (lenBits - 1) bits of `width`
    | the significant bits of the atom itself
-}
{-# INLINE writeMat #-}
writeMat :: Atom -> Put ()
writeMat atm
    | atm == 0  = writeBit True
    | otherwise = do
        writeBitsFromWord (lenBits + 1) (1 `shiftL` lenBits)
        writeBitsFromWord (lenBits - 1) width
        writeAtomBits atm
  where
    width   = bitWidth atm
    lenBits = fromIntegral (wordBitWidth width)
{-
    Write a cell: the tag bits 1 then 0, followed by the head and then
    the tail.
-}
{-# INLINE writeCell #-}
writeCell :: FatNoun -> FatNoun -> Put ()
writeCell hed tel =
       writeBit True
    >> writeBit False
    >> writeNoun hed
    >> writeNoun tel
{-
    Write an atom: a single 0 tag bit followed by its `writeMat`
    encoding.
-}
{-# INLINE writeAtom #-}
writeAtom :: Atom -> Put ()
writeAtom atm = writeBit False >> writeMat atm
{-
    Write a back-reference: the tag bits 1 then 1, followed by the
    `writeMat` encoding of the referenced position.

    - a: The value to write as the reference (in `writeNoun` this is the
         entry found in the back-reference table).
-}
{-# INLINE writeBackRef #-}
writeBackRef :: Word -> Put ()
writeBackRef a = do
    writeBit True
    writeBit True
    writeMat (toAtom a)
-- Calculate Jam Size and Backrefs ---------------------------------------------
{-
    Boxed wrapper around `matSz#`: the size in bits of an atom's
    length-prefixed encoding.
-}
{-# INLINE matSz #-}
matSz :: Atom -> Word
matSz atm = case matSz# atm of
    bits -> W# bits
{-
    Size in bits of an atom's length-prefixed encoding (see `writeMat`).

    Zero takes one bit. Any other atom takes `2 * lenW + width` bits,
    where `width` is the bit width of the atom and `lenW` is the bit
    width of `width`.
-}
{-# INLINE matSz# #-}
matSz# :: Atom -> Word#
matSz# atm = case atm of
    0 -> 1##
    _ -> let width = atomBitWidth# atm
             lenW  = wordBitWidth# width
         in  (lenW `plusWord#` lenW) `plusWord#` width
{-
    Size in bits of an atom written in full: one tag bit plus `matSz`.
-}
{-# INLINE atomSz #-}
atomSz :: Atom -> Word
atomSz atm = 1 + matSz atm
{-
    Size in bits of a back-reference to the given position: one bit more
    than `jamWordSz` of that position.
-}
{-# INLINE refSz #-}
refSz :: Word -> Word
refSz target = 1 + jamWordSz target
{-
    Size in bits of a word-sized atom written in full.

    Zero takes two bits. Any other word takes `1 + 2 * lenW + width`
    bits, where `width` is the bit width of the word and `lenW` is the
    bit width of `width`.
-}
{-# INLINE jamWordSz #-}
jamWordSz :: Word -> Word
jamWordSz wor = case wor of
    0    -> 2
    W# w ->
        let width = wordBitWidth# w
            lenW  = wordBitWidth# width
        in  1 + 2 * W# lenW + W# width
{-
    First pass of the encoder: walk the noun, computing the size in bits
    of its encoding and deciding where back-references go.

    Returns the total size in bits, together with a table mapping the
    bit position at which each back-reference is emitted to the bit
    position it refers to.

    Every node is looked up in a table of nodes already seen:

    - A node seen for the first time is recorded with its bit position
      and sized in full.
    - A node seen before becomes a back-reference if it is a cell, or
      if it is an atom whose reference is strictly smaller than the atom
      written in full. Otherwise it is sized in full again (and the
      recorded position is kept).

    Both tables are pre-sized with a heuristic derived from `fatSize`,
    clamped to the range 50 .. 10_000_000.
-}
compress :: FatNoun -> IO (Word, H.CuckooHashTable Word Word)
compress top = do
    seen <- H.newSized tableSize :: IO (H.BasicHashTable FatNoun Word)
    refs <- H.newSized tableSize :: IO (H.CuckooHashTable Word Word)

    let -- Size of a node written in full, starting at bit `here`.
        fullSize :: Word -> FatNoun -> IO Word
        fullSize here noun = case noun of
            FatAtom _ big   -> pure (atomSz (MkAtom (NatJ# big)))
            FatWord w       -> pure (jamWordSz w)
            FatCell _ _ h t -> do
                -- Two tag bits, then the head, then the tail.
                !headBits <- visit (here + 2) h
                !tailBits <- visit (here + 2 + headBits) t
                pure (2 + headBits + tailBits)

        -- Size of a node at bit `here`, as a back-reference if it pays.
        visit :: Word -> FatNoun -> IO Word
        visit here noun = do
            prior <- H.lookup seen noun
            case prior of
                Nothing -> do
                    H.insert seen noun here
                    fullSize here noun
                Just target -> do
                    let refBits = refSz target
                        refPays = case noun of
                            FatCell{}   -> True
                            FatWord w   -> refBits < atomSz (fromIntegral w)
                            FatAtom _ a -> refBits < atomSz (MkAtom (NatJ# a))
                    if refPays
                        then H.insert refs here target $> refBits
                        else fullSize here noun

    total <- visit 0 top
    pure (total, refs)
  where
    tableSize :: Int
    tableSize = max 50 (min 10_000_000 (2 * (10 ^ magnitude)))

    magnitude :: Integer
    magnitude = floor (logBase 600 (fromIntegral (fatSize top) :: Double))
-- Stolen from Hashable Library ------------------------------------------------
{-
    Mix two hash values: multiply the first by 16777619 and xor the
    result with the second.
-}
{-# INLINE combine #-}
combine :: Int -> Int -> Int
combine h1 h2 = xor (h1 * 16777619) h2
{-
    Hash a value with a salt: the salt is `combine`d with the value's
    `Hash.hash`.
-}
{-# INLINE defaultHashWithSalt #-}
defaultHashWithSalt :: Hashable a => Int -> a -> Int
defaultHashWithSalt salt x = combine salt (Hash.hash x)