mirror of
https://github.com/HuwCampbell/grenade.git
synced 2024-11-22 06:55:13 +03:00
Merge pull request #49 from HuwCampbell/topic/exact-hotmap
Topic/exact hotmap
This commit is contained in:
commit
fe8a5ff37f
@ -41,6 +41,9 @@ import System.IO.Unsafe ( unsafeInterleaveIO )
|
|||||||
-- This network is able to learn and generate simple words in
|
-- This network is able to learn and generate simple words in
|
||||||
-- about an hour.
|
-- about an hour.
|
||||||
--
|
--
|
||||||
|
-- Grab the input from
|
||||||
|
-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
|
||||||
|
--
|
||||||
-- This is a first class recurrent net.
|
-- This is a first class recurrent net.
|
||||||
--
|
--
|
||||||
-- The F and R types are tagging types to ensure that the runner and
|
-- The F and R types are tagging types to ensure that the runner and
|
||||||
@ -69,7 +72,7 @@ loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Ve
|
|||||||
loadShakespeare path = do
|
loadShakespeare path = do
|
||||||
contents <- lift $ readFile path
|
contents <- lift $ readFile path
|
||||||
let annotated = annotateCapitals contents
|
let annotated = annotateCapitals contents
|
||||||
(m,cs) <- ExceptT . return . note "Couldn't fit data in hotMap" $ hotMap (Proxy :: Proxy 40) annotated
|
(m,cs) <- ExceptT . return $ hotMap (Proxy :: Proxy 40) annotated
|
||||||
hot <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated
|
hot <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated
|
||||||
return (V.fromList hot, m, cs)
|
return (V.fromList hot, m, cs)
|
||||||
|
|
||||||
|
@ -42,12 +42,9 @@ library
|
|||||||
, containers >= 0.5 && < 0.6
|
, containers >= 0.5 && < 0.6
|
||||||
, cereal >= 0.5 && < 0.6
|
, cereal >= 0.5 && < 0.6
|
||||||
, deepseq >= 1.4 && < 1.5
|
, deepseq >= 1.4 && < 1.5
|
||||||
, exceptions == 0.8.*
|
|
||||||
, hmatrix == 0.18.*
|
, hmatrix == 0.18.*
|
||||||
, MonadRandom >= 0.4 && < 0.6
|
, MonadRandom >= 0.4 && < 0.6
|
||||||
, mtl >= 2.2.1 && < 2.3
|
|
||||||
, primitive >= 0.6 && < 0.7
|
, primitive >= 0.6 && < 0.7
|
||||||
, text == 1.2.*
|
|
||||||
, singletons >= 2.1 && < 2.4
|
, singletons >= 2.1 && < 2.4
|
||||||
, vector >= 0.11 && < 0.13
|
, vector >= 0.11 && < 0.13
|
||||||
|
|
||||||
|
@ -52,15 +52,16 @@ oneHot hot =
|
|||||||
|
|
||||||
-- | Create a one hot map from any enumerable.
|
-- | Create a one hot map from any enumerable.
|
||||||
-- Returns a map, and the ordered list for the reverse transformation
|
-- Returns a map, and the ordered list for the reverse transformation
|
||||||
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
|
hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Either String (Map a Int, Vector a)
|
||||||
hotMap n as =
|
hotMap n as =
|
||||||
let len = fromIntegral $ natVal n
|
let len = fromIntegral $ natVal n
|
||||||
uniq = [ c | (c:_) <- group $ sort as]
|
uniq = [ c | (c:_) <- group $ sort as]
|
||||||
hotl = length uniq
|
hotl = length uniq
|
||||||
in if hotl <= len
|
in if hotl == len
|
||||||
then
|
then
|
||||||
Just (M.fromList $ zip uniq [0..], V.fromList uniq)
|
Right (M.fromList $ zip uniq [0..], V.fromList uniq)
|
||||||
else Nothing
|
else
|
||||||
|
Left ("Couldn't create hotMap of size " ++ show len ++ " from vector with " ++ show hotl ++ " unique characters")
|
||||||
|
|
||||||
-- | From a map and value, create a 1D Shape
|
-- | From a map and value, create a 1D Shape
|
||||||
-- with one index hot (1) with the rest 0.
|
-- with one index hot (1) with the rest 0.
|
||||||
|
Loading…
Reference in New Issue
Block a user