-- grenade/main/shakespeare.hs

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE LambdaCase #-}
import Data.Char ( isUpper, toUpper, toLower )
import Data.List ( foldl' )
import Data.Maybe ( fromMaybe )
import Data.Proxy ( Proxy (..) )
import Data.Semigroup ( (<>) )
import GHC.TypeLits
import System.Exit ( die )
import System.IO.Unsafe ( unsafeInterleaveIO )

import Control.Monad.Random
import Control.Monad.Trans.Except
import qualified Data.ByteString as B
import qualified Data.Map as M
import Data.Serialize
import Data.Singletons.Prelude
import qualified Data.Vector as V
import Data.Vector ( Vector )
import Numeric.LinearAlgebra.Static ( konst )
import Options.Applicative

import Grenade
import Grenade.Recurrent
import Grenade.Utils.OneHot
-- The definition for our natural language recurrent network.
-- This network is able to learn and generate simple words in
-- about an hour.
--
-- This is a first class recurrent net.
--
-- The F and R types are tagging types to ensure that the runner and
-- creation function know how to treat the layers.
--
-- As an example, here's a short sequence it generated:
--
-- > KING RICHARD III:
-- > And as the heaven her his words, we the son, I show sand stape but the lament to shall were the sons with a strend
-- | Tag a layer as feed-forward.
type F = FeedForward

-- | Tag a layer as recurrent.
type R = Recurrent

-- | The layer stack of our network. It is shared by the network type and its
-- recurrent-input type so the two can never drift apart.
type ShakespeareLayers = '[ R (LSTM 40 80), R (LSTM 80 40), F (FullyConnected 40 40), F Logit ]

-- | The definition of our network.
type Shakespeare = RecurrentNetwork ShakespeareLayers
                                    '[ 'D1 40, 'D1 80, 'D1 40, 'D1 40, 'D1 40 ]

-- | The definition of the "sideways" input, which the network is fed recurrently.
type Shakespearian = RecurrentInputs ShakespeareLayers
-- | A freshly initialised 'Shakespeare' network paired with its initial
-- recurrent inputs, both produced by 'randomRecurrent'.
randomNet :: MonadRandom m => m (Shakespeare, Shakespearian)
randomNet = randomRecurrent
-- | Load the training corpus and one-hot encode it.
--
-- Capital letters are first rewritten with 'annotateCapitals'. 'hotMap' then
-- builds the character-to-'Int' map and the 'Vector Char' dictionary for
-- 40 classes. The dictionary is what 'generateParagraph' later decodes with.
--
-- Returns the encoded corpus, the map, and the dictionary. Fails with a
-- message if 'hotMap' rejects the text (presumably too many distinct
-- characters for 40 classes; not verified here) or a character is missing
-- from the map.
loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Vector Char)
loadShakespeare path = do
  annotated <- annotateCapitals <$> lift (readFile path)
  (charMap, dictionary) <-
    ExceptT (pure (note "Couldn't fit data in hotMap" (hotMap (Proxy :: Proxy 40) annotated)))
  codes <-
    ExceptT (pure (note "Couldn't generate hot values" (traverse (`M.lookup` charMap) annotated)))
  pure (V.fromList codes, charMap, dictionary)
-- | Train on one slice of the encoded corpus.
--
-- The @size@ codes starting at @offset@ are one-hot encoded and fed to the
-- network in order. Every character except the last is an input, and only
-- the final input carries a target: the last character of the slice. The
-- earlier characters therefore act purely as context.
--
-- Calls 'error' if the slice has fewer than two characters, or if a code
-- cannot be one-hot encoded.
trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
trainSlice !params !net !recIns input offset size =
  case reverse encoded of
    target : lastInput : earlier ->
      trainRecurrent params net recIns $
        map (\v -> (v, Nothing)) (reverse earlier) ++ [(lastInput, Just target)]
    _ -> error "Not enough input"
  where
    encoded   = map (decodeHot . oneHot) (V.toList (V.slice offset size input))
    decodeHot = fromMaybe (error "Hot variable didn't fit.")
-- | Train the network on the configured corpus.
--
-- Training runs in ten rounds. Each round trains on @iterations `div` 10@
-- random slices, then prints the first 1000 characters generated by the
-- current network. The final network and recurrent inputs are saved if a
-- save path was given.
--
-- Invalid settings are reported as 'Left' instead of crashing mid-training.
runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
runShakespeare ShakespeareOpts {..} = do
  (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile

  -- trainSlice needs at least one input and one target per slice.
  when (sequenceSize < 2) $
    throwE "Sequence length must be at least 2."
  -- The offset range below, (0, length - size - 1), is only non-empty when
  -- the corpus is strictly longer than a slice; otherwise V.slice fails.
  when (V.length shakespeare <= sequenceSize) $
    throwE "Training data is shorter than the sequence length."

  (net0, i0) <- lift $
    case loadPath of
      Just loadFile -> netLoad loadFile
      Nothing       -> randomNet

  let trainRound (!net, !inputs) size = do
        offsets <- take (iterations `div` 10) <$> getRandomRs (0, V.length shakespeare - size - 1)
        let (!trained, !bestInput) = foldl' step (net, inputs) offsets
            step (!n, !i) offset = trainSlice rate n i shakespeare offset size
        results <- take 1000 <$> generateParagraph trained bestInput temperature oneHotMap oneHotDictionary (S1D $ konst 0)
        putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
        putStrLn (unAnnotateCapitals results)
        return (trained, bestInput)

  (trained, bestInput) <- lift $ foldM trainRound (net0, i0) (replicate 10 sequenceSize)

  case savePath of
    Just saveFile -> lift . B.writeFile saveFile $ runPut (put (trained, bestInput))
    Nothing       -> return ()
-- | Lazily generate an endless stream of values from a recurrent network.
--
-- Each step:
--
-- * runs the network once on the current recurrent inputs and input vector;
-- * picks a value from the output with 'sample' (given the temperature and
--   the dictionary);
-- * one-hot encodes that value with the map via 'makeHot' to form the next
--   input.
--
-- The list is infinite, so callers must consume a finite prefix (e.g. with
-- 'take'). Fails in 'IO' if a sampled value has no one-hot encoding in the
-- map.
generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes ~ 'D1 n, KnownNat n, Ord a)
  => RecurrentNetwork layers shapes
  -> RecurrentInputs layers
  -> Double
  -> M.Map a Int
  -> Vector a
  -> S ('D1 n)
  -> IO [a]
generateParagraph net inputs0 temp hotmap hotdict =
  go inputs0
  where
    go inputs current = do
      let (inputs', output) = runRecurrent net inputs current
      value <- sample temp hotdict output
      next  <- maybe (fail "generateParagraph: sampled value has no one-hot encoding")
                     return
                     (makeHot hotmap value)
      -- Defer the rest of the stream until it is demanded. This lets the
      -- infinite generation stop as soon as the caller stops consuming. Each
      -- step depends on the previous one, so the draws still happen one at a
      -- time in list order.
      rest <- unsafeInterleaveIO $ go inputs' next
      return (value : rest)
-- | Command line options for the Shakespeare example.
data ShakespeareOpts = ShakespeareOpts
  { trainingFile :: FilePath            -- ^ corpus to train on
  , iterations   :: Int                 -- ^ total training slices, split over ten rounds
  , rate         :: LearningParameters  -- ^ optimiser settings passed to 'trainSlice'
  , sequenceSize :: Int                 -- ^ number of characters per training slice
  , temperature  :: Double              -- ^ sampling temperature for generated text
  , loadPath     :: Maybe FilePath      -- ^ optional saved network to start from
  , savePath     :: Maybe FilePath      -- ^ optional destination for the trained network
  }
-- | Command line parser for 'ShakespeareOpts'.
--
-- Every option carries help text and shows its default value in @--help@.
-- Option names and defaults are unchanged.
shakespeare' :: Parser ShakespeareOpts
shakespeare' = ShakespeareOpts
  <$> argument str (metavar "TRAIN" <> help "Training corpus (text file)")
  <*> option auto (long "examples" <> short 'e' <> value 1000000 <> showDefault
                   <> help "Total number of training slices, split evenly over ten rounds")
  <*> (LearningParameters
        <$> option auto (long "train_rate" <> short 'r' <> value 0.01 <> showDefault
                         <> help "Learning rate")
        <*> option auto (long "momentum" <> value 0.95 <> showDefault
                         <> help "Momentum")
        <*> option auto (long "l2" <> value 0.000001 <> showDefault
                         <> help "L2 regularisation coefficient")
      )
  <*> option auto (long "sequence-length" <> short 's' <> value 50 <> showDefault
                   <> help "Number of characters in each training slice")
  <*> option auto (long "temperature" <> short 't' <> value 0.4 <> showDefault
                   <> help "Sampling temperature used when generating text")
  <*> optional (strOption (long "load" <> metavar "FILE"
                           <> help "Start from a saved network instead of a random one"))
  <*> optional (strOption (long "save" <> metavar "FILE"
                           <> help "Save the trained network to this file"))
-- | Parse the command line and run training.
--
-- Any error from 'runShakespeare' is printed to stderr and the program exits
-- with a non-zero status, so scripts can detect failure.
main :: IO ()
main = do
  shopts <- execParser (info (shakespeare' <**> helper) idm)
  res    <- runExceptT $ runShakespeare shopts
  either die pure res
-- | Read a network and its recurrent inputs from a file in the format written
-- by @--save@. Fails in 'IO' with the decoder's message if the bytes cannot
-- be decoded.
netLoad :: FilePath -> IO (Shakespeare, Shakespearian)
netLoad modelPath =
  B.readFile modelPath >>= either fail pure . runGet get
-- | Replace each capital letter with a @'^'@ marker followed by its lower case
-- form, so the network only has to learn one case of every letter.
-- See http://fastml.com/one-weird-trick-for-training-char-rnns/
annotateCapitals :: String -> String
annotateCapitals = concatMap annotate
  where
    annotate c
      | isUpper c = ['^', toLower c]
      | otherwise = [c]
-- | Inverse of 'annotateCapitals'.
--
-- A @'^'@ followed by any character becomes that character upper-cased. A
-- trailing @'^'@ with nothing after it is kept as-is.
unAnnotateCapitals :: String -> String
unAnnotateCapitals str = case str of
  '^' : c : cs -> toUpper c : unAnnotateCapitals cs
  c : cs       -> c : unAnnotateCapitals cs
  []           -> []
-- | Turn a 'Maybe' into an 'Either'. 'Nothing' becomes 'Left' with the
-- supplied error value, and 'Just' becomes 'Right'.
note :: a -> Maybe b -> Either a b
note err m = case m of
  Nothing -> Left err
  Just v  -> Right v