Better recursive ref search

This commit is contained in:
Chris Penner 2024-07-15 11:15:54 -07:00
parent 212a232736
commit ba78b69ed8

View File

@ -23,8 +23,10 @@ where
import Control.Concurrent.STM as STM
import Control.Exception (throwIO)
import Control.Monad
import Data.Binary.Get (runGetOrFail)
-- import Data.Bits (shiftL)
import Control.Monad.State
import Data.Binary.Get (runGetOrFail)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as BL
import Data.Bytes.Get (MonadGet, getWord8, runGetS)
@ -133,6 +135,7 @@ import Unison.Syntax.NamePrinter (prettyHashQualified)
import Unison.Syntax.TermPrinter
import Unison.Term qualified as Tm
import Unison.Util.EnumContainers as EC
import Unison.Util.Monoid (foldMapM)
import Unison.Util.Pretty as P
import UnliftIO qualified
import UnliftIO.Concurrent qualified as UnliftIO
@ -195,23 +198,26 @@ allocType ctx r cons =
pure $ ctx {dspec = Map.insert r cons $ dspec ctx}
recursiveDeclDeps ::
Set RF.LabeledDependency ->
CodeLookup Symbol IO () ->
Decl Symbol () ->
-- (type deps, term deps)
IO (Set Reference, Set Reference)
recursiveDeclDeps seen0 cl d = do
rec <- for (toList newDeps) $ \case
RF.DerivedId i ->
getTypeDeclaration cl i >>= \case
Just d -> recursiveDeclDeps seen cl d
Nothing -> pure mempty
_ -> pure mempty
pure $ (deps, mempty) <> fold rec
StateT (Set RF.LabeledDependency) IO (Set Reference, Set Reference)
recursiveDeclDeps cl d = do
seen0 <- get
let seen = seen0 <> Set.map RF.typeRef deps
put seen
let newDeps = Set.filter (\r -> notMember (RF.typeRef r) seen0) deps
rec <-
(toList newDeps) & foldMapM \r -> do
case r of
RF.DerivedId i ->
lift (getTypeDeclaration cl i) >>= \case
Just d -> recursiveDeclDeps cl d
Nothing -> pure mempty
_ -> pure mempty
pure $ (deps, mempty) <> rec
where
deps = declTypeDependencies d
newDeps = Set.filter (\r -> notMember (RF.typeRef r) seen0) deps
seen = seen0 <> Set.map RF.typeRef deps
categorize :: RF.LabeledDependency -> (Set Reference, Set Reference)
categorize =
@ -221,37 +227,39 @@ categorize =
RF.TermReference ref -> (mempty, Set.singleton ref)
recursiveTermDeps ::
Set RF.LabeledDependency ->
CodeLookup Symbol IO () ->
Term Symbol ->
-- (type deps, term deps)
IO (Set Reference, Set Reference)
recursiveTermDeps seen0 cl tm = do
rec <- for (toList (deps \\ seen0)) $ \case
RF.ConReference (RF.ConstructorReference (RF.DerivedId refId) _conId) _conType -> handleTypeReferenceId refId
RF.TypeReference (RF.DerivedId refId) -> handleTypeReferenceId refId
RF.TermReference r -> recursiveRefDeps seen cl r
_ -> pure mempty
pure $ foldMap categorize deps <> fold rec
StateT (Set RF.LabeledDependency) IO (Set Reference, Set Reference)
recursiveTermDeps cl tm = do
seen0 <- get
let seen = seen0 <> deps
put seen
rec <-
(toList (deps \\ seen0)) & foldMapM \r ->
case r of
RF.ConReference (RF.ConstructorReference (RF.DerivedId refId) _conId) _conType -> handleTypeReferenceId refId
RF.TypeReference (RF.DerivedId refId) -> handleTypeReferenceId refId
RF.TermReference r -> recursiveRefDeps cl r
_ -> pure mempty
pure $ foldMap categorize deps <> rec
where
handleTypeReferenceId :: RF.Id -> IO (Set Reference, Set Reference)
handleTypeReferenceId :: RF.Id -> StateT (Set RF.LabeledDependency) IO (Set Reference, Set Reference)
handleTypeReferenceId refId =
getTypeDeclaration cl refId >>= \case
Just d -> recursiveDeclDeps seen cl d
lift (getTypeDeclaration cl refId) >>= \case
Just d -> recursiveDeclDeps cl d
Nothing -> pure mempty
deps = Tm.labeledDependencies tm
seen = seen0 <> deps
recursiveRefDeps ::
Set RF.LabeledDependency ->
CodeLookup Symbol IO () ->
Reference ->
IO (Set Reference, Set Reference)
recursiveRefDeps seen cl (RF.DerivedId i) =
getTerm cl i >>= \case
Just tm -> recursiveTermDeps seen cl tm
StateT (Set RF.LabeledDependency) IO (Set Reference, Set Reference)
recursiveRefDeps cl (RF.DerivedId i) =
lift (getTerm cl i) >>= \case
Just tm -> recursiveTermDeps cl tm
Nothing -> pure mempty
recursiveRefDeps _ _ _ = pure mempty
recursiveRefDeps _ _ = pure mempty
recursiveIRefDeps ::
Map.Map Reference (SuperGroup Symbol) ->
@ -289,8 +297,8 @@ collectDeps ::
Term Symbol ->
IO ([(Reference, Either [Int] [Int])], [Reference])
collectDeps cl tm = do
(tys, tms) <- recursiveTermDeps mempty cl tm
(,toList tms) <$> traverse getDecl (toList tys)
(tys, tms) <- evalStateT (recursiveTermDeps cl tm) mempty
(,toList tms) <$> (traverse getDecl (toList tys))
where
getDecl ty@(RF.DerivedId i) =
(ty,) . maybe (Right []) declFields