urbit/pkg/hs-urbit/lib/Data/Noun/Jam/Fast.hs

396 lines
11 KiB
Haskell
Raw Normal View History

{-# LANGUAGE MagicHash #-}
2019-07-02 05:43:10 +03:00
{-# OPTIONS_GHC -fwarn-unused-binds -fwarn-unused-imports #-}
module Data.Noun.Jam.Fast (jam, jamBS) where
import ClassyPrelude hiding (hash)
import Control.Lens (view, to, from)
import Data.Bits (shiftL, shiftR, setBit, clearBit, xor, (.|.))
import Data.Noun.Atom (Atom(MkAtom), toAtom, bitWidth, takeBitsWord)
import Data.Noun.Atom (wordBitWidth, wordBitWidth# , atomBitWidth#)
import Data.Noun (Noun(Atom, Cell))
import Data.Noun.Pill (bigNatWords, atomBS)
import Data.Vector.Primitive ((!))
import Foreign.Marshal.Alloc (callocBytes, free)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (poke)
import GHC.Integer.GMP.Internals (BigNat)
import GHC.Int (Int(I#))
import GHC.Natural (Natural(NatS#, NatJ#))
import GHC.Prim (Word#, plusWord#, word2Int#)
import GHC.Word (Word(W#))
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString.Unsafe as BS
import qualified Data.Hashable as Hash
import qualified Data.HashTable.IO as H
import qualified Data.Vector.Primitive as VP
-- Exports ---------------------------------------------------------------------
-- | Serialize a noun to a strict 'ByteString'.
--
-- Runs in two passes: 'compress' computes the total output size in bits
-- and the table of positions that should be written as back references;
-- 'doPut' then allocates the buffer and runs 'writeNoun' against that
-- table.
jamBS :: Noun -> ByteString
jamBS noun = doPut backRefs bitSize (writeNoun noun)
  where
    (bitSize, backRefs) = unsafePerformIO (compress (toBigNoun noun))
2019-07-02 05:43:10 +03:00
-- | Serialize a noun to an 'Atom': the output of 'jamBS' viewed back
-- through the 'atomBS' iso.
jam :: Noun -> Atom
jam noun = view (to jamBS . from atomBS) noun
2019-07-02 05:43:10 +03:00
-- Types -----------------------------------------------------------------------
2019-07-02 05:43:10 +03:00
{-|
    The encoder state.

    - ptr: Pointer into the output buffer.
    - reg: Next 64 bits of output, partially written.
    - off: Number of bits already written into `reg`.
    - pos: Total number of bits written.
-}
data S = S
    { ptr :: {-# UNPACK #-} !(Ptr Word)  -- ^ where the next flushed word goes
    , reg :: {-# UNPACK #-} !Word        -- ^ pending, partially-filled word
    , off :: {-# UNPACK #-} !Int         -- ^ bits of `reg` already in use
    , pos :: {-# UNPACK #-} !Word        -- ^ running total of bits written
    }
  deriving (Show, Eq, Ord)
-- | What one encoder step produces: the updated encoder state together
-- with the step's result value. Both fields are strict.
data PutResult a
    = PutResult {-# UNPACK #-} !S !a
  deriving Functor
-- | The encoder monad: a state-passing computation over 'S' in 'IO' that
-- also receives a hash table keyed and valued by 'Word'. 'jamBS' passes
-- the table built by 'compress', and 'getRef' queries it by bit position.
newtype Put a = Put
    { runPut
        :: H.LinearHashTable Word Word
        -> S
        -> IO (PutResult a)
    }
2019-07-02 05:43:10 +03:00
--------------------------------------------------------------------------------
2019-07-02 05:43:10 +03:00
-- | Look up the current output bit position (`pos`) in the table carried
-- by 'Put'. Returns the stored value if there is an entry for this
-- position; the encoder state is left unchanged.
{-# INLINE getRef #-}
getRef :: Put (Maybe Word)
getRef = Put $ \tbl st -> do
    found <- H.lookup tbl (pos st)
    pure (PutResult st found)
2019-07-02 05:43:10 +03:00
{-
    Write the register to the output buffer at `ptr`, then advance `ptr`
    by 8 bytes. `reg`, `off` and `pos` are left untouched; callers reset
    them as needed.
-}
{-# INLINE flush #-}
flush :: Put ()
flush = Put $ \_ st -> do
    poke (ptr st) (reg st)
    pure (PutResult (st { ptr = ptr st `plusPtr` 8 }) ())
2019-07-02 05:43:10 +03:00
-- | Apply a pure function to the encoder state.
{-# INLINE update #-}
update :: (S -> S) -> Put ()
update f = Put $ \_ st -> pure (PutResult (f st) ())
2019-07-02 05:43:10 +03:00
-- | Overwrite both the pending register and its bit offset.
{-# INLINE setRegOff #-}
setRegOff :: Word -> Int -> Put ()
setRegOff newReg newOff = update (\st -> st { reg = newReg, off = newOff })
2019-07-02 05:43:10 +03:00
-- | Overwrite the pending register, leaving the rest of the state alone.
{-# INLINE setReg #-}
setReg :: Word -> Put ()
setReg newReg = update (\st -> st { reg = newReg })
2019-07-02 05:43:10 +03:00
-- | Read the current encoder state.
{-# INLINE getS #-}
getS :: Put S
getS = Put $ \_ st -> pure (PutResult st st)
2019-07-02 05:43:10 +03:00
-- | Replace the encoder state wholesale.
{-# INLINE putS #-}
putS :: S -> Put ()
putS newSt = Put $ \_ _ -> pure (PutResult newSt ())
2019-07-02 05:43:10 +03:00
{-
    To write a bit:

    | set (or clear) bit `off` of reg
    | off <- (off + 1) % 64
    | pos <- pos + 1
    | if the bit just written was bit 63:
    |     flush reg to the buffer
    |     reg <- 0, off <- 0
-}
{-# INLINE writeBit #-}
writeBit :: Bool -> Put ()
writeBit b = Put $ \tbl st ->
    let bitIx = off st
        reg' | b         = setBit   (reg st) bitIx
             | otherwise = clearBit (reg st) bitIx
        st'  = st { reg = reg'
                  , off = (bitIx + 1) `mod` 64
                  , pos = pos st + 1
                  }
    in if bitIx == 63
         -- The register is now full: emit it and start a fresh one.
         then runPut (flush >> setRegOff 0 0) tbl st'
         else pure (PutResult st' ())
{-
    To write a full 64-bit word:

    | reg |= w << off
    | flush reg to the buffer
    | reg  = w >> (64 - off)     -- the bits that did not fit
    | pos += 64

    `off` does not change: exactly 64 bits went out, so the register is
    as full as it was before.
-}
{-# INLINE writeWord #-}
writeWord :: Word -> Put ()
writeWord wor = do
    st <- getS
    let bitOff = off st
    setReg (reg st .|. shiftL wor bitOff)
    flush
    update $ \s -> s { pos = 64 + pos st
                     , reg = shiftR wor (64 - bitOff)
                     }
2019-07-02 05:43:10 +03:00
{-
    To write the low `wid` bits of a word:

    | bits = takeBitsWord(wid, wor)
    | reg |= bits << off
    | newOff = (off + wid) % 64
    | pos += wid
    |
    | if (off + wid >= 64):            -- register filled up
    |     flush reg to the buffer
    |     reg = bits >> (wid - newOff) -- the bits that did not fit
-}
{-# INLINE writeBitsFromWord #-}
writeBitsFromWord :: Int -> Word -> Put ()
writeBitsFromWord wid wor = do
    let bits = takeBitsWord wid wor
    oldSt <- getS

    let filled = off oldSt + wid
        newOff = filled `mod` 64

    putS oldSt { reg = reg oldSt .|. shiftL bits (off oldSt)
               , off = newOff
               , pos = fromIntegral wid + pos oldSt
               }

    when (filled >= 64) $ do
        flush
        setReg (shiftR bits (wid - newOff))
{-
    Write all of the significant bits of a direct atom: the low
    `wordBitWidth# w` bits of the word.
-}
{-# INLINE writeAtomWord# #-}
writeAtomWord# :: Word# -> Put ()
writeAtomWord# w = writeBitsFromWord width (W# w)
  where
    width = I# (word2Int# (wordBitWidth# w))
-- | Boxed wrapper around 'writeAtomWord#'.
{-# INLINE writeAtomWord #-}
writeAtomWord :: Word -> Put ()
writeAtomWord boxed = case boxed of
    W# raw -> writeAtomWord# raw
{-
    Write all of the significant bits of an indirect atom: every word
    except the last is written in full, and the last word contributes
    only its significant bits.

    TODO Use memcpy when the bit-offset of the output is divisible by 8.
-}
{-# INLINE writeAtomBigNat #-}
writeAtomBigNat :: BigNat -> Put ()
writeAtomBigNat (view bigNatWords -> ws) = do
    VP.mapM_ writeWord (VP.take lastIdx ws)
    writeAtomWord (ws ! lastIdx)
  where
    lastIdx = VP.length ws - 1
-- | Write the significant bits of an atom, dispatching on whether the
-- underlying 'Natural' is a single machine word or a 'BigNat'.
{-# INLINE writeAtomBits #-}
writeAtomBits :: Atom -> Put ()
writeAtomBits (MkAtom nat) = case nat of
    NatS# wd -> writeAtomWord# wd
    NatJ# bn -> writeAtomBigNat bn
2019-07-02 05:43:10 +03:00
-- Put Instances ---------------------------------------------------------------
2019-07-02 05:43:10 +03:00
-- | Map over the result of an encoder step; the state threading is
-- untouched.
instance Functor Put where
    fmap f step = Put $ \tbl st ->
        runPut step tbl st >>= \(PutResult st' a) ->
            pure (PutResult st' (f a))
    {-# INLINE fmap #-}
2019-07-02 05:43:10 +03:00
-- | Sequential composition: the state produced by the left step is fed
-- to the right step, and the same table is passed to both.
instance Applicative Put where
    pure val = Put $ \_ st -> pure (PutResult st val)
    {-# INLINE pure #-}

    pf <*> px = Put $ \tbl st0 -> do
        PutResult st1 fn  <- runPut pf tbl st0
        PutResult st2 arg <- runPut px tbl st1
        pure (PutResult st2 (fn arg))
    {-# INLINE (<*>) #-}

    pa *> pb = Put $ \tbl st0 -> do
        PutResult st1 _ <- runPut pa tbl st0
        runPut pb tbl st1
    {-# INLINE (*>) #-}
-- | Monadic bind threads the state from the first step into the
-- continuation; `return` and `>>` reuse the 'Applicative' definitions.
instance Monad Put where
    return = pure
    {-# INLINE return #-}

    (>>) = (*>)
    {-# INLINE (>>) #-}

    step >>= k = Put $ \tbl st ->
        runPut step tbl st >>= \(PutResult st' a) ->
            runPut (k a) tbl st'
    {-# INLINE (>>=) #-}
--------------------------------------------------------------------------------
-- | Run an encoder against a freshly allocated, zero-filled buffer and
-- package the output as a 'ByteString'.
--
-- `sz` is the number of bits the encoder will write. The buffer is
-- allocated in whole 64-bit words (the encoder only ever stores whole
-- words), while the resulting 'ByteString' exposes just the
-- `ceil(sz / 8)` bytes that carry data. The buffer is released by the
-- finalizer attached to the 'ByteString'.
doPut :: H.LinearHashTable Word Word -> Word -> Put () -> ByteString
doPut tbl sz m = unsafePerformIO $ do
    buf <- callocBytes bufBytes
    _   <- runPut (m >> flushPartial) tbl (S buf 0 0 0)
    BS.unsafePackCStringFinalizer (castPtr buf) outBytes (free buf)
  where
    -- Integer division, rounding up.
    ceilDiv :: Word -> Word -> Word
    ceilDiv x y = case x `divMod` y of
        (q, 0) -> q
        (q, _) -> q + 1

    bufBytes, outBytes :: Int
    bufBytes = 8 * fromIntegral (sz `ceilDiv` 64)
    outBytes = fromIntegral (sz `ceilDiv` 8)

    -- Emit the final, partially filled register if it holds any bits.
    flushPartial :: Put ()
    flushPartial = do
        st <- getS
        when (off st /= 0) flush
--------------------------------------------------------------------------------
{-
    Write one noun at the current position.

    If the table has an entry for the current bit position, a back
    reference to the recorded value is written instead of the noun
    itself. Otherwise the noun is written in full as an atom or a cell.
-}
writeNoun :: Noun -> Put ()
writeNoun n = getRef >>= maybe writeFresh writeBackRef
  where
    writeFresh = case n of
        Atom a   -> writeAtom a
        Cell h t -> writeCell h t
-- | Write a length-prefixed atom.
--
-- Zero is a single 1 bit. Any other atom, with bit width `width` and
-- `lenBits = wordBitWidth width`, is written as:
--
--   1. `lenBits` zero bits followed by a 1 bit,
--   2. the low `lenBits - 1` bits of `width`,
--   3. the significant bits of the atom itself.
{-# INLINE writeMat #-}
writeMat :: Atom -> Put ()
writeMat 0   = writeBit True
writeMat atm = do
    let width   = bitWidth atm
        lenBits = fromIntegral (wordBitWidth width)
    writeBitsFromWord (lenBits + 1) (shiftL 1 lenBits)
    writeBitsFromWord (lenBits - 1) width
    writeAtomBits atm
-- | Write a cell: the tag bits 1 then 0, followed by the head and then
-- the tail.
{-# INLINE writeCell #-}
writeCell :: Noun -> Noun -> Put ()
writeCell h t =
    writeBit True *> writeBit False *> writeNoun h *> writeNoun t
-- | Write an atom: a single 0 tag bit followed by the length-prefixed
-- atom.
{-# INLINE writeAtom #-}
writeAtom :: Atom -> Put ()
writeAtom a = writeBit False *> writeMat a
-- | Write a back reference: the tag bits 1 then 1, followed by the
-- referenced value as a length-prefixed atom.
--
-- The argument is the value stored in the table for the current
-- position; 'compress' stores there the bit position at which the same
-- noun was first seen.
{-# INLINE writeBackRef #-}
writeBackRef :: Word -> Put ()
writeBackRef a = do
    writeBit True
    writeBit True
    writeMat (toAtom a)
-- Compute Hashes and Jam Size (with no backrefs) ------------------------------
-- | A noun annotated at every node with a precomputed size and hash, as
-- built by 'toBigNoun'.
--
-- - bSize: size in bits of the node's encoding when no back references
--   are used.
-- - bHash: hash of the node, cached so hashing a node is O(1).
data BigNoun
    = BigCell
        { bSize :: {-# UNPACK #-} !Word
        , bHash :: {-# UNPACK #-} !Int
        , bHead :: BigNoun
        , bTail :: BigNoun
        }
    | BigAtom
        { bSize :: {-# UNPACK #-} !Word
        , bHash :: {-# UNPACK #-} !Int
        , bAtom :: {-# UNPACK #-} !Atom
        }
  deriving (Show)
-- | Hashing a 'BigNoun' just returns the hash cached in the node.
instance Hashable BigNoun where
    hash big = bHash big
    {-# INLINE hash #-}

    hashWithSalt salt big = defaultHashWithSalt salt big
    {-# INLINE hashWithSalt #-}
-- | Structural equality on annotated nouns.
--
-- The cached hash and size are compared before any child is inspected.
-- Both are pure functions of the noun's structure (see 'toBigNoun'), so
-- equal nouns always agree on them and the result is unchanged; unequal
-- nouns that happen to share a size are now rejected without a
-- recursive descent. This matters because the hash table in 'compress'
-- calls `==` on every probe.
instance Eq BigNoun where
    BigAtom s1 x1 a1 == BigAtom s2 x2 a2 =
        x1 == x2 && s1 == s2 && a1 == a2
    BigCell s1 x1 h1 t1 == BigCell s2 x2 h2 t2 =
        x1 == x2 && s1 == s2 && h1 == h2 && t1 == t2
    _ == _ = False
    {-# INLINE (==) #-}
-- | Annotate every node of a noun with its encoded size (ignoring back
-- references) and a hash.
--
-- An atom costs one tag bit plus 'matSz'. A cell costs two tag bits plus
-- the sizes of both children; its hash mixes its size with the
-- children's hashes.
{-# INLINE toBigNoun #-}
toBigNoun :: Noun -> BigNoun
toBigNoun = \case
    Atom atm ->
        BigAtom (1 + matSz atm) (Hash.hash atm) atm
    Cell l r ->
        let bigL  = toBigNoun l
            bigR  = toBigNoun r
            total = 2 + bSize bigL + bSize bigR
            mixed = combine (combine (fromIntegral total) (bHash bigL))
                            (bHash bigR)
        in BigCell total mixed bigL bigR
-- Calculate Jam Size and Backrefs ---------------------------------------------
-- | Boxed wrapper around 'matSz#': the number of bits 'writeMat' emits
-- for the atom.
{-# INLINE matSz #-}
matSz :: Atom -> Word
matSz atm = W# (matSz# atm)
-- | Size in bits of a length-prefixed atom. Zero takes one bit; any
-- other atom takes twice the bit width of its bit width, plus its bit
-- width.
{-# INLINE matSz# #-}
matSz# :: Atom -> Word#
matSz# 0   = 1##
matSz# atm =
    case atomBitWidth# atm of
        widthBits -> case wordBitWidth# widthBits of
            prefixBits ->
                plusWord# (plusWord# prefixBits prefixBits) widthBits
-- | Size in bits of a back reference to the given value: two tag bits
-- plus the value as a length-prefixed atom.
{-# INLINE refSz# #-}
refSz# :: Word# -> Word#
refSz# ref = plusWord# 2## (matSz# (MkAtom (NatS# ref)))
-- | Plan the encoding of a noun.
--
-- Walks the noun in output order, tracking the bit position at which
-- each node would be written. Returns the total encoded size in bits and
-- a table mapping each position that should hold a back reference to
-- the position where the same noun was first seen.
--
-- A repeated noun becomes a back reference only when the reference is
-- strictly smaller than the noun's full (reference-free) encoding;
-- otherwise it is written out again.
compress :: BigNoun -> IO (Word, H.LinearHashTable Word Word)
compress top = do
    seen     :: H.LinearHashTable BigNoun Word <- H.new
    backRefs :: H.LinearHashTable Word Word    <- H.new

    let -- Size of a noun written in full starting at bit `at`.
        emit :: Word -> BigNoun -> IO Word
        emit at = \case
            BigAtom _ _ atm ->
                pure (1 + matSz atm)
            BigCell _ _ hed tel -> do
                hedSz <- visit (at + 2) hed
                telSz <- visit (at + 2 + hedSz) tel
                pure (2 + hedSz + telSz)

        -- Size of a noun at bit `at`, as a back reference when possible.
        visit :: Word -> BigNoun -> IO Word
        visit at noun =
            H.lookup seen noun >>= \case
                Nothing -> do
                    -- First sighting: remember where it lives.
                    H.insert seen noun at
                    emit at noun
                Just prev@(W# prev#) -> do
                    let refBits = W# (refSz# prev#)
                    if refBits < bSize noun
                        then do
                            H.insert backRefs at prev
                            pure refBits
                        else emit at noun

    total <- visit 0 top
    pure (total, backRefs)
-- Stolen from Hashable Library ------------------------------------------------
-- | Mix two hash values: multiply the first by 16777619 and xor the
-- product with the second.
{-# INLINE combine #-}
combine :: Int -> Int -> Int
combine h1 h2 = xor (16777619 * h1) h2
-- | Salted hash built from an unsalted one: 'combine' the salt with the
-- value's plain hash.
{-# INLINE defaultHashWithSalt #-}
defaultHashWithSalt :: Hashable a => Int -> a -> Int
defaultHashWithSalt salt = combine salt . Hash.hash