Merge pull request #49 from HuwCampbell/topic/exact-hotmap

Topic/exact hotmap
This commit is contained in:
Huw Campbell 2018-01-04 21:35:13 +11:00 committed by GitHub
commit fe8a5ff37f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 8 deletions

View File

@ -41,6 +41,9 @@ import System.IO.Unsafe ( unsafeInterleaveIO )
-- This network is able to learn and generate simple words in
-- about an hour.
--
-- Grab the input from
-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
--
-- This is a first class recurrent net.
--
-- The F and R types are tagging types to ensure that the runner and
@ -69,7 +72,7 @@ loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Ve
loadShakespeare path = do
contents <- lift $ readFile path
let annotated = annotateCapitals contents
(m,cs) <- ExceptT . return . note "Couldn't fit data in hotMap" $ hotMap (Proxy :: Proxy 40) annotated
(m,cs) <- ExceptT . return $ hotMap (Proxy :: Proxy 40) annotated
hot <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated
return (V.fromList hot, m, cs)

View File

@ -42,12 +42,9 @@ library
, containers >= 0.5 && < 0.6
, cereal >= 0.5 && < 0.6
, deepseq >= 1.4 && < 1.5
, exceptions == 0.8.*
, hmatrix == 0.18.*
, MonadRandom >= 0.4 && < 0.6
, mtl >= 2.2.1 && < 2.3
, primitive >= 0.6 && < 0.7
, text == 1.2.*
, singletons >= 2.1 && < 2.4
, vector >= 0.11 && < 0.13

View File

@ -52,15 +52,16 @@ oneHot hot =
-- | Create a one-hot map from a list of any orderable values.
-- Returns a map from each distinct value to its index, and the ordered
-- vector of distinct values for the reverse transformation.
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Either String (Map a Int, Vector a)
hotMap n as =
let len = fromIntegral $ natVal n
uniq = [ c | (c:_) <- group $ sort as]
hotl = length uniq
in if hotl <= len
in if hotl == len
then
Just (M.fromList $ zip uniq [0..], V.fromList uniq)
else Nothing
Right (M.fromList $ zip uniq [0..], V.fromList uniq)
else
Left ("Couldn't create hotMap of size " ++ show len ++ " from vector with " ++ show hotl ++ " unique characters")
-- | From a map and value, create a 1D Shape
-- with one index hot (1) and the rest 0.