mirror of
https://github.com/unisonweb/unison.git
synced 2024-09-20 23:07:13 +03:00
simplify throwing decoding error ceremony a bit
This commit is contained in:
parent
4adc35016c
commit
ef3282b381
@ -17,7 +17,7 @@ module U.Codebase.Sqlite.Operations
|
||||
saveDeclComponent,
|
||||
loadDeclComponent,
|
||||
loadDeclByReference,
|
||||
getDeclTypeById,
|
||||
expectDeclTypeById,
|
||||
|
||||
-- * terms/decls
|
||||
getCycleLen,
|
||||
@ -65,10 +65,6 @@ module U.Codebase.Sqlite.Operations
|
||||
saveBranchObject,
|
||||
saveDbPatch,
|
||||
|
||||
-- * Error types
|
||||
Error (..),
|
||||
DecodeError (..),
|
||||
|
||||
-- * somewhat unexpectedly unused definitions
|
||||
c2sReferenceId,
|
||||
c2sReferentId,
|
||||
@ -84,38 +80,23 @@ where
|
||||
|
||||
import Control.Lens (Lens')
|
||||
import qualified Control.Lens as Lens
|
||||
import Control.Monad (join, unless, when, (<=<))
|
||||
import Control.Monad.Except (MonadIO (liftIO))
|
||||
import qualified Control.Monad.Extra as Monad
|
||||
import Control.Monad.State (MonadState, evalStateT)
|
||||
import Control.Monad.Trans (MonadTrans (lift))
|
||||
import Control.Monad.Trans.Maybe (MaybeT (MaybeT), runMaybeT)
|
||||
import Control.Monad.Writer (MonadWriter, runWriterT)
|
||||
import qualified Control.Monad.Writer as Writer
|
||||
import Data.Bifunctor (Bifunctor (bimap))
|
||||
import Data.Bitraversable (Bitraversable (bitraverse))
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Bytes.Get (runGetS)
|
||||
import qualified Data.Bytes.Get as Get
|
||||
import Data.Foldable (traverse_)
|
||||
import qualified Data.Foldable as Foldable
|
||||
import Data.Functor (void, (<&>))
|
||||
import Data.Functor.Identity (Identity)
|
||||
import Data.Map (Map)
|
||||
import qualified Data.Map as Map
|
||||
import qualified Data.Map.Merge.Lazy as Map
|
||||
import Data.Maybe (isJust)
|
||||
import Data.Sequence (Seq)
|
||||
import qualified Data.Sequence as Seq
|
||||
import Data.Set (Set)
|
||||
import qualified Data.Set as Set
|
||||
import Data.Text (Text)
|
||||
import qualified Data.Text as Text
|
||||
import Data.Traversable (for)
|
||||
import Data.Tuple.Extra (uncurry3)
|
||||
import qualified Data.Vector as Vector
|
||||
import Data.Word (Word64)
|
||||
import Debug.Trace
|
||||
import qualified U.Codebase.Branch as C.Branch
|
||||
import qualified U.Codebase.Causal as C
|
||||
import U.Codebase.Decl (ConstructorId)
|
||||
@ -180,6 +161,7 @@ import qualified U.Util.Monoid as Monoid
|
||||
import U.Util.Serialization (Get)
|
||||
import qualified U.Util.Serialization as S
|
||||
import qualified U.Util.Term as TermUtil
|
||||
import Unison.Prelude
|
||||
import Unison.Sqlite
|
||||
import qualified Unison.Util.Map as Map
|
||||
import qualified Unison.Util.Set as Set
|
||||
@ -189,24 +171,10 @@ import qualified Unison.Util.Set as Set
|
||||
debug :: Bool
|
||||
debug = False
|
||||
|
||||
type ErrString = String
|
||||
|
||||
data DecodeError
|
||||
= ErrTermFormat
|
||||
| ErrDeclFormat
|
||||
| ErrTermElement Word64
|
||||
| ErrDeclElement Word64
|
||||
| ErrFramedArrayLen
|
||||
| ErrTypeOfTerm C.Reference.Id
|
||||
| ErrWatch WatchKind C.Reference.Id
|
||||
| ErrBranch Db.BranchObjectId
|
||||
| ErrPatch Db.PatchObjectId
|
||||
| ErrObjectDependencies OT.ObjectType Db.ObjectId
|
||||
deriving (Show)
|
||||
|
||||
-- TODO rename
|
||||
data Error
|
||||
= DecodeError DecodeError ByteString ErrString
|
||||
data DecodeError = DecodeError
|
||||
{ decoder :: Text, -- the name of the decoder
|
||||
err :: String -- the error message
|
||||
}
|
||||
deriving stock (Show)
|
||||
deriving anyclass (SqliteExceptionReason)
|
||||
|
||||
@ -215,9 +183,9 @@ newtype NeedTypeForBuiltinMetadata
|
||||
deriving stock (Show)
|
||||
deriving anyclass (SqliteExceptionReason)
|
||||
|
||||
getFromBytesOr :: DecodeError -> Get a -> ByteString -> Either Error a
|
||||
getFromBytesOr e get bs = case runGetS get bs of
|
||||
Left err -> Left (DecodeError e bs err)
|
||||
getFromBytesOr :: Text -> Get a -> ByteString -> Either DecodeError a
|
||||
getFromBytesOr decoder get bs = case runGetS get bs of
|
||||
Left err -> Left (DecodeError decoder err)
|
||||
Right a -> Right a
|
||||
|
||||
-- * Database lookups
|
||||
@ -354,26 +322,34 @@ diffPatch (S.Patch fullTerms fullTypes) (S.Patch refTerms refTypes) =
|
||||
|
||||
-- * Deserialization helpers
|
||||
|
||||
decodeTermFormat :: ByteString -> Either Error S.Term.TermFormat
|
||||
decodeTermFormat = getFromBytesOr ErrTermFormat S.getTermFormat
|
||||
decodeBranchFormat :: ByteString -> Either DecodeError S.BranchFormat.BranchFormat
|
||||
decodeBranchFormat = getFromBytesOr "getBranchFormat" S.getBranchFormat
|
||||
|
||||
decodeComponentLengthOnly :: ByteString -> Either Error Word64
|
||||
decodeComponentLengthOnly = getFromBytesOr ErrFramedArrayLen (Get.skip 1 >> S.lengthFramedArray)
|
||||
decodePatchFormat :: ByteString -> Either DecodeError S.Patch.Format.PatchFormat
|
||||
decodePatchFormat = getFromBytesOr "getPatchFormat" S.getPatchFormat
|
||||
|
||||
decodeTermElementWithType :: C.Reference.Pos -> ByteString -> Either Error (LocalIds, S.Term.Term, S.Term.Type)
|
||||
decodeTermElementWithType i = getFromBytesOr (ErrTermElement i) (S.lookupTermElement i)
|
||||
decodeTermFormat :: ByteString -> Either DecodeError S.Term.TermFormat
|
||||
decodeTermFormat = getFromBytesOr "getTermFormat" S.getTermFormat
|
||||
|
||||
decodeTermElementDiscardingTerm :: C.Reference.Pos -> ByteString -> Either Error (LocalIds, S.Term.Type)
|
||||
decodeTermElementDiscardingTerm i = getFromBytesOr (ErrTermElement i) (S.lookupTermElementDiscardingTerm i)
|
||||
decodeComponentLengthOnly :: ByteString -> Either DecodeError Word64
|
||||
decodeComponentLengthOnly = getFromBytesOr "lengthFramedArray" (Get.skip 1 >> S.lengthFramedArray)
|
||||
|
||||
decodeTermElementDiscardingType :: C.Reference.Pos -> ByteString -> Either Error (LocalIds, S.Term.Term)
|
||||
decodeTermElementDiscardingType i = getFromBytesOr (ErrTermElement i) (S.lookupTermElementDiscardingType i)
|
||||
decodeTermElementWithType :: C.Reference.Pos -> ByteString -> Either DecodeError (LocalIds, S.Term.Term, S.Term.Type)
|
||||
decodeTermElementWithType i = getFromBytesOr ("lookupTermElement" <> tShow i) (S.lookupTermElement i)
|
||||
|
||||
decodeDeclFormat :: ByteString -> Either Error S.Decl.DeclFormat
|
||||
decodeDeclFormat = getFromBytesOr ErrDeclFormat S.getDeclFormat
|
||||
decodeTermElementDiscardingTerm :: C.Reference.Pos -> ByteString -> Either DecodeError (LocalIds, S.Term.Type)
|
||||
decodeTermElementDiscardingTerm i =
|
||||
getFromBytesOr ("lookupTermElementDiscardingTerm " <> tShow i) (S.lookupTermElementDiscardingTerm i)
|
||||
|
||||
decodeDeclElement :: Word64 -> ByteString -> Either Error (LocalIds, S.Decl.Decl Symbol)
|
||||
decodeDeclElement i = getFromBytesOr (ErrDeclElement i) (S.lookupDeclElement i)
|
||||
decodeTermElementDiscardingType :: C.Reference.Pos -> ByteString -> Either DecodeError (LocalIds, S.Term.Term)
|
||||
decodeTermElementDiscardingType i =
|
||||
getFromBytesOr ("lookupTermElementDiscardingType " <> tShow i) (S.lookupTermElementDiscardingType i)
|
||||
|
||||
decodeDeclFormat :: ByteString -> Either DecodeError S.Decl.DeclFormat
|
||||
decodeDeclFormat = getFromBytesOr "getDeclFormat" S.getDeclFormat
|
||||
|
||||
decodeDeclElement :: Word64 -> ByteString -> Either DecodeError (LocalIds, S.Decl.Decl Symbol)
|
||||
decodeDeclElement i = getFromBytesOr ("lookupDeclElement " <> tShow i) (S.lookupDeclElement i)
|
||||
|
||||
getCycleLen :: DB m => H.Hash -> m (Maybe Word64)
|
||||
getCycleLen h = do
|
||||
@ -389,9 +365,8 @@ getCycleLen h = do
|
||||
Q.expectObject oid decodeComponentLengthOnly
|
||||
|
||||
-- | Get the 'C.DeclType.DeclType' of a 'C.Reference.Id'.
|
||||
-- TODO rename to expectDeclTypeById
|
||||
getDeclTypeById :: DB m => C.Reference.Id -> m C.Decl.DeclType
|
||||
getDeclTypeById =
|
||||
expectDeclTypeById :: DB m => C.Reference.Id -> m C.Decl.DeclType
|
||||
expectDeclTypeById =
|
||||
fmap C.Decl.declType . expectDeclByReference
|
||||
|
||||
componentByObjectId :: DB m => Db.ObjectId -> m [S.Reference.Id]
|
||||
@ -681,7 +656,7 @@ listWatches k = Q.loadWatchesByWatchKind k >>= traverse h2cReferenceId
|
||||
loadWatch :: DB m => WatchKind -> C.Reference.Id -> MaybeT m (C.Term Symbol)
|
||||
loadWatch k r = do
|
||||
r' <- C.Reference.idH Q.saveHashHash r
|
||||
S.Term.WatchResult wlids t <- MaybeT (Q.loadWatch k r' (getFromBytesOr (ErrWatch k r) S.getWatchResultFormat))
|
||||
S.Term.WatchResult wlids t <- MaybeT (Q.loadWatch k r' (getFromBytesOr "getWatchResultFormat" S.getWatchResultFormat))
|
||||
w2cTerm wlids t
|
||||
|
||||
saveWatch :: DB m => WatchKind -> C.Reference.Id -> C.Term Symbol -> m ()
|
||||
@ -987,7 +962,7 @@ expectDbBranch id =
|
||||
deserializeBranchObject :: DB m => Db.BranchObjectId -> m S.BranchFormat
|
||||
deserializeBranchObject id = do
|
||||
when debug $ traceM $ "deserializeBranchObject " ++ show id
|
||||
Q.expectNamespaceObject (Db.unBranchObjectId id) (getFromBytesOr (ErrBranch id) S.getBranchFormat)
|
||||
Q.expectNamespaceObject (Db.unBranchObjectId id) decodeBranchFormat
|
||||
|
||||
doDiff :: DB m => Db.BranchObjectId -> [S.Branch.Diff] -> m S.DbBranch
|
||||
doDiff ref ds =
|
||||
@ -1126,7 +1101,7 @@ s2cPatch (S.Patch termEdits typeEdits) =
|
||||
deserializePatchObject :: DB m => Db.PatchObjectId -> m S.PatchFormat
|
||||
deserializePatchObject id = do
|
||||
when debug $ traceM $ "Operations.deserializePatchObject " ++ show id
|
||||
Q.expectPatchObject (Db.unPatchObjectId id) (getFromBytesOr (ErrPatch id) S.getPatchFormat)
|
||||
Q.expectPatchObject (Db.unPatchObjectId id) decodePatchFormat
|
||||
|
||||
lca :: DB m => CausalHash -> CausalHash -> Connection -> Connection -> m (Maybe CausalHash)
|
||||
lca h1 h2 c1 c2 = runMaybeT do
|
||||
|
@ -401,7 +401,6 @@ expectObjectIdForPrimaryHash h = do
|
||||
hashId <- expectHashIdByHash h
|
||||
expectObjectIdForPrimaryHashId hashId
|
||||
|
||||
-- FIXME this doesn't check that the object is actually a patch
|
||||
loadPatchObjectIdForPrimaryHash :: DB m => PatchHash -> m (Maybe PatchObjectId)
|
||||
loadPatchObjectIdForPrimaryHash =
|
||||
(fmap . fmap) PatchObjectId . loadObjectIdForPrimaryHash . unPatchHash
|
||||
|
@ -282,10 +282,10 @@ sqliteCodebase debugName root localOrRemote action = do
|
||||
++ ", but I've been asked for it's ConstructorType."
|
||||
in pure . fromMaybe err $
|
||||
Map.lookup (Reference.Builtin t) Builtins.builtinConstructorType
|
||||
C.Reference.ReferenceDerived i -> getDeclTypeById i
|
||||
C.Reference.ReferenceDerived i -> expectDeclTypeById i
|
||||
|
||||
getDeclTypeById :: forall m. DB m => C.Reference.Id -> m CT.ConstructorType
|
||||
getDeclTypeById = fmap Cv.decltype2to1 . Ops.getDeclTypeById
|
||||
expectDeclTypeById :: forall m. DB m => C.Reference.Id -> m CT.ConstructorType
|
||||
expectDeclTypeById = fmap Cv.decltype2to1 . Ops.expectDeclTypeById
|
||||
|
||||
getTypeOfTermImpl :: MonadIO m => Reference.Id -> m (Maybe (Type Symbol Ann))
|
||||
getTypeOfTermImpl id | debug && trace ("getTypeOfTermImpl " ++ show id) False = undefined
|
||||
|
Loading…
Reference in New Issue
Block a user