mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-21 21:59:30 +03:00
Merge pull request #49 from HuwCampbell/topic/exact-hotmap
Topic/exact hotmap
This commit is contained in:
commit
fe8a5ff37f
@ -41,6 +41,9 @@ import System.IO.Unsafe ( unsafeInterleaveIO )
|
||||
-- This network is able to learn and generate simple words in
|
||||
-- about an hour.
|
||||
--
|
||||
-- Grab the input from
|
||||
-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
|
||||
--
|
||||
-- This is a first class recurrent net.
|
||||
--
|
||||
-- The F and R types are tagging types to ensure that the runner and
|
||||
@ -69,7 +72,7 @@ loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Ve
|
||||
loadShakespeare path = do
|
||||
contents <- lift $ readFile path
|
||||
let annotated = annotateCapitals contents
|
||||
(m,cs) <- ExceptT . return . note "Couldn't fit data in hotMap" $ hotMap (Proxy :: Proxy 40) annotated
|
||||
(m,cs) <- ExceptT . return $ hotMap (Proxy :: Proxy 40) annotated
|
||||
hot <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated
|
||||
return (V.fromList hot, m, cs)
|
||||
|
||||
|
@ -42,12 +42,9 @@ library
|
||||
, containers >= 0.5 && < 0.6
|
||||
, cereal >= 0.5 && < 0.6
|
||||
, deepseq >= 1.4 && < 1.5
|
||||
, exceptions == 0.8.*
|
||||
, hmatrix == 0.18.*
|
||||
, MonadRandom >= 0.4 && < 0.6
|
||||
, mtl >= 2.2.1 && < 2.3
|
||||
, primitive >= 0.6 && < 0.7
|
||||
, text == 1.2.*
|
||||
, singletons >= 2.1 && < 2.4
|
||||
, vector >= 0.11 && < 0.13
|
||||
|
||||
|
@ -52,15 +52,16 @@ oneHot hot =
|
||||
|
||||
-- | Create a one hot map from any enumerable.
|
||||
-- Returns a map, and the ordered list for the reverse transformation
|
||||
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
|
||||
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Either String (Map a Int, Vector a)
|
||||
hotMap n as =
|
||||
let len = fromIntegral $ natVal n
|
||||
uniq = [ c | (c:_) <- group $ sort as]
|
||||
hotl = length uniq
|
||||
in if hotl <= len
|
||||
in if hotl == len
|
||||
then
|
||||
Just (M.fromList $ zip uniq [0..], V.fromList uniq)
|
||||
else Nothing
|
||||
Right (M.fromList $ zip uniq [0..], V.fromList uniq)
|
||||
else
|
||||
Left ("Couldn't create hotMap of size " ++ show len ++ " from vector with " ++ show hotl ++ " unique characters")
|
||||
|
||||
-- | From a map and value, create a 1D Shape
|
||||
-- with one index hot (1) with the rest 0.
|
||||
|
Loading…
Reference in New Issue
Block a user