]> gitweb @ CieloNegro.org - haskell-dns.git/commitdiff
Introduce Packer monad so that we can compress binary packets.
authorPHO <pho@cielonegro.org>
Fri, 22 May 2009 07:35:49 +0000 (16:35 +0900)
committerPHO <pho@cielonegro.org>
Fri, 22 May 2009 07:35:49 +0000 (16:35 +0900)
Network/DNS/Message.hs
Network/DNS/Packer.hs [new file with mode: 0644]
Network/DNS/Unpacker.hs
dns.cabal

index be0b79a33800a32dcc769c7ff68129f322261bf4..5c537956bd1657284a8c3b15db3dd9e9941cce62 100644 (file)
@@ -34,15 +34,17 @@ import           Control.Monad
 import           Data.Binary
 import           Data.Binary.BitPut as BP
 import           Data.Binary.Get as G
 import           Data.Binary
 import           Data.Binary.BitPut as BP
 import           Data.Binary.Get as G
-import           Data.Binary.Put as P
+import           Data.Binary.Put as P'
 import           Data.Binary.Strict.BitGet as BG
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
 import           Data.Binary.Strict.BitGet as BG
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
-import qualified Data.ByteString.Lazy as LBS
 import           Data.Typeable
 import qualified Data.IntMap as IM
 import           Data.IntMap (IntMap)
 import           Data.Typeable
 import qualified Data.IntMap as IM
 import           Data.IntMap (IntMap)
+import qualified Data.Map as M
+import           Data.Map (Map)
 import           Data.Word
 import           Data.Word
+import           Network.DNS.Packer as P
 import           Network.DNS.Unpacker as U
 import           Network.Socket
 
 import           Network.DNS.Unpacker as U
 import           Network.Socket
 
@@ -108,11 +110,11 @@ data Question
 
 type SomeQT = SomeRT
 
 
 type SomeQT = SomeRT
 
-putQ :: Question -> Put
+putQ :: Question -> Packer CompTable ()
 putQ q
     = do putDomainName $ qName q
          putSomeRT $ qType q
 putQ q
     = do putDomainName $ qName q
          putSomeRT $ qType q
-         put $ qClass q
+         putBinary $ qClass q
 
 getQ :: Unpacker DecompTable Question
 getQ = do nm <- getDomainName
 
 getQ :: Unpacker DecompTable Question
 getQ = do nm <- getDomainName
@@ -125,23 +127,25 @@ getQ = do nm <- getDomainName
                      }
 
 
                      }
 
 
-newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
+newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable)
 type DomainLabel    = BS.ByteString
 
 type DomainLabel    = BS.ByteString
 
-nameToLabels :: DomainName -> [DomainLabel]
-nameToLabels (DN ls) = ls
-
-labelsToName :: [DomainLabel] -> DomainName
-labelsToName = DN
-
 rootName :: DomainName
 rootName = DN [BS.empty]
 
 rootName :: DomainName
 rootName = DN [BS.empty]
 
+isRootName :: DomainName -> Bool
+isRootName (DN [_]) = True
+isRootName _        = False
+
 consLabel :: DomainLabel -> DomainName -> DomainName
 consLabel x (DN ys) = DN (x:ys)
 
 consLabel :: DomainLabel -> DomainName -> DomainName
 consLabel x (DN ys) = DN (x:ys)
 
+unconsLabel :: DomainName -> (DomainLabel, DomainName)
+unconsLabel (DN (x:xs)) = (x, DN xs)
+unconsLabel x           = error ("Illegal use of unconsLabel: " ++ show x)
+
 mkDomainName :: String -> DomainName
 mkDomainName :: String -> DomainName
-mkDomainName = labelsToName . mkLabels [] . notEmpty
+mkDomainName = DN . mkLabels [] . notEmpty
     where
       notEmpty :: String -> String
       notEmpty xs = assert (not $ null xs) xs
     where
       notEmpty :: String -> String
       notEmpty xs = assert (not $ null xs) xs
@@ -173,18 +177,6 @@ data RecordType rt dt => ResourceRecord rt dt
     deriving (Show, Eq, Typeable)
 
 
     deriving (Show, Eq, Typeable)
 
 
-putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
-putRR rr = do putDomainName $ rrName rr
-              putRecordType $ rrType  rr
-              put $ rrClass rr
-              putWord32be $ rrTTL rr
-
-              let dat = runPut $
-                        putRecordData (undefined :: rt) (rrData rr)
-              putWord16be $ fromIntegral $ LBS.length dat
-              putLazyByteString dat
-
-
 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
 
 instance Show SomeRR where
 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
 
 instance Show SomeRR where
@@ -194,8 +186,8 @@ instance Eq SomeRR where
     (SomeRR a) == (SomeRR b) = Just a == cast b
 
 
     (SomeRR a) == (SomeRR b) = Just a == cast b
 
 
-putSomeRR :: SomeRR -> Put
-putSomeRR (SomeRR rr) = putRR rr
+putSomeRR :: SomeRR -> Packer CompTable ()
+putSomeRR (SomeRR rr) = putResourceRecord rr
 
 getSomeRR :: Unpacker DecompTable SomeRR
 getSomeRR = do srt <- U.lookAhead $
 
 getSomeRR :: Unpacker DecompTable SomeRR
 getSomeRR = do srt <- U.lookAhead $
@@ -205,8 +197,9 @@ getSomeRR = do srt <- U.lookAhead $
                  SomeRT rt
                      -> getResourceRecord rt >>= return . SomeRR
 
                  SomeRT rt
                      -> getResourceRecord rt >>= return . SomeRR
 
+type CompTable   = Map DomainName Int
 type DecompTable = IntMap DomainName
 type DecompTable = IntMap DomainName
-type TTL = Word32
+type TTL         = Word32
 
 getDomainName :: Unpacker DecompTable DomainName
 getDomainName = worker
 
 getDomainName :: Unpacker DecompTable DomainName
 getDomainName = worker
@@ -217,7 +210,7 @@ getDomainName = worker
                hdr    <- getLabelHeader
                case hdr of
                  Offset n
                hdr    <- getLabelHeader
                case hdr of
                  Offset n
-                     -> do dt <- getState
+                     -> do dt <- U.getState
                            case IM.lookup n dt of
                              Just name
                                  -> return name
                            case IM.lookup n dt of
                              Just name
                                  -> return name
@@ -229,7 +222,7 @@ getDomainName = worker
                      -> do label <- U.getByteString n
                            rest  <- worker
                            let name = consLabel label rest
                      -> do label <- U.getByteString n
                            rest  <- worker
                            let name = consLabel label rest
-                           modifyState $ IM.insert offset name
+                           U.modifyState $ IM.insert offset name
                            return name
 
       getLabelHeader :: Unpacker s LabelHeader
                            return name
 
       getLabelHeader :: Unpacker s LabelHeader
@@ -262,28 +255,64 @@ getCharString :: Unpacker s BS.ByteString
 getCharString = do len <- U.getWord8
                    U.getByteString (fromIntegral len)
 
 getCharString = do len <- U.getWord8
                    U.getByteString (fromIntegral len)
 
-putCharString :: BS.ByteString -> Put
-putCharString = putDomainLabel
+putCharString :: BS.ByteString -> Packer s ()
+putCharString xs = do P.putWord8 $ fromIntegral $ BS.length xs
+                      P.putByteString xs
 
 data LabelHeader
     = Offset !Int
     | Length !Int
 
 
 data LabelHeader
     = Offset !Int
     | Length !Int
 
-putDomainName :: DomainName -> Put
-putDomainName = mapM_ putDomainLabel . nameToLabels
+putDomainName :: DomainName -> Packer CompTable ()
+putDomainName name
+    = do ct <- P.getState
+         case M.lookup name ct of
+           Just n
+               -> do let ptr = runBitPut $
+                               do putBit True
+                                  putBit True
+                                  putNBits 14 n
+                     P.putLazyByteString ptr
+           Nothing
+               -> do offset <- bytesWrote
+                     P.modifyState $ M.insert name offset
+
+                     let (label, rest) = unconsLabel name
+
+                     putCharString label
+
+                     if isRootName rest then
+                         P.putWord8 0
+                       else
+                         putDomainName rest
 
 
-putDomainLabel :: DomainLabel -> Put
-putDomainLabel l
-    = do putWord8 $ fromIntegral $ BS.length l
-         P.putByteString l
 
 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
     rtToInt       :: rt -> Int
 
 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
     rtToInt       :: rt -> Int
-    putRecordData :: rt -> dt -> Put
+    putRecordData :: rt -> dt -> Packer CompTable ()
     getRecordData :: rt -> Unpacker DecompTable dt
 
     getRecordData :: rt -> Unpacker DecompTable dt
 
-    putRecordType :: rt -> Put
-    putRecordType = putWord16be . fromIntegral . rtToInt
+    putRecordType :: rt -> Packer s ()
+    putRecordType = P.putWord16be . fromIntegral . rtToInt
+
+    putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
+    putResourceRecord rr
+        = do putDomainName $ rrName  rr
+             putRecordType $ rrType  rr
+             putBinary     $ rrClass rr
+             P.putWord32be $ rrTTL   rr
+
+             -- First, write a dummy data length.
+             offset <- bytesWrote
+             P.putWord16be 0
+
+             -- Second, write data.
+             putRecordData (rrType rr) (rrData rr)
+
+             -- Third, rewrite the dummy length to an actual value.
+             offset' <- bytesWrote
+             withOffset offset
+                 $ P.putWord16be (fromIntegral (offset' - offset - 2))
 
     getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
     getResourceRecord rt
 
     getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
     getResourceRecord rt
@@ -309,7 +338,7 @@ instance Show SomeRT where
 instance Eq SomeRT where
     (SomeRT a) == (SomeRT b) = Just a == cast b
 
 instance Eq SomeRT where
     (SomeRT a) == (SomeRT b) = Just a == cast b
 
-putSomeRT :: SomeRT -> Put
+putSomeRT :: SomeRT -> Packer s ()
 putSomeRT (SomeRT rt) = putRecordType rt
 
 getSomeRT :: Unpacker s SomeRT
 putSomeRT (SomeRT rt) = putRecordType rt
 
 getSomeRT :: Unpacker s SomeRT
@@ -323,7 +352,7 @@ getSomeRT = do n <- liftM fromIntegral U.getWord16be
 data A = A deriving (Show, Eq, Typeable)
 instance RecordType A HostAddress where
     rtToInt       _ = 1
 data A = A deriving (Show, Eq, Typeable)
 instance RecordType A HostAddress where
     rtToInt       _ = 1
-    putRecordData _ = putWord32be
+    putRecordData _ = P.putWord32be
     getRecordData _ = U.getWord32be
 
 data NS = NS deriving (Show, Eq, Typeable)
     getRecordData _ = U.getWord32be
 
 data NS = NS deriving (Show, Eq, Typeable)
@@ -376,17 +405,18 @@ data RecordType
 -}
 
 instance Binary Message where
 -}
 
 instance Binary Message where
-    put m = do put $ msgHeader m
-               putWord16be $ fromIntegral $ length $ msgQuestions m
-               putWord16be $ fromIntegral $ length $ msgAnswers m
-               putWord16be $ fromIntegral $ length $ msgAuthorities m
-               putWord16be $ fromIntegral $ length $ msgAdditionals m
-               mapM_ putQ  $ msgQuestions m
+    put m = P.liftToBinary M.empty $
+            do putBinary $ msgHeader m
+               P.putWord16be $ fromIntegral $ length $ msgQuestions m
+               P.putWord16be $ fromIntegral $ length $ msgAnswers m
+               P.putWord16be $ fromIntegral $ length $ msgAuthorities m
+               P.putWord16be $ fromIntegral $ length $ msgAdditionals m
+               mapM_ putQ      $ msgQuestions m
                mapM_ putSomeRR $ msgAnswers m
                mapM_ putSomeRR $ msgAuthorities m
                mapM_ putSomeRR $ msgAdditionals m
 
                mapM_ putSomeRR $ msgAnswers m
                mapM_ putSomeRR $ msgAuthorities m
                mapM_ putSomeRR $ msgAdditionals m
 
-    get = liftToBinary IM.empty $
+    get = U.liftToBinary IM.empty $
           do hdr  <- getBinary
              nQ   <- liftM fromIntegral U.getWord16be
              nAns <- liftM fromIntegral U.getWord16be
           do hdr  <- getBinary
              nQ   <- liftM fromIntegral U.getWord16be
              nAns <- liftM fromIntegral U.getWord16be
@@ -405,8 +435,8 @@ instance Binary Message where
                         }
 
 instance Binary Header where
                         }
 
 instance Binary Header where
-    put h = do putWord16be $ hdMessageID h
-               putLazyByteString flags
+    put h = do P'.putWord16be $ hdMessageID h
+               P'.putLazyByteString flags
         where
           flags = runBitPut $
                   do putNBits 1 $ fromEnum $ hdMessageType h
         where
           flags = runBitPut $
                   do putNBits 1 $ fromEnum $ hdMessageType h
@@ -538,7 +568,7 @@ instance Enum RecordClass where
 
 instance Binary RecordClass where
     get = liftM (toEnum . fromIntegral) G.getWord16be
 
 instance Binary RecordClass where
     get = liftM (toEnum . fromIntegral) G.getWord16be
-    put = putWord16be . fromIntegral . fromEnum
+    put = P'.putWord16be . fromIntegral . fromEnum
 
 
 defaultRTTable :: IntMap SomeRT
 
 
 defaultRTTable :: IntMap SomeRT
diff --git a/Network/DNS/Packer.hs b/Network/DNS/Packer.hs
new file mode 100644 (file)
index 0000000..7f8f895
--- /dev/null
@@ -0,0 +1,154 @@
+module Network.DNS.Packer
+    ( Packer
+
+    , pack
+    , pack'
+
+    , getState
+    , setState
+    , modifyState
+
+    , bytesWrote
+    , withOffset
+
+    , putByteString
+    , putLazyByteString
+
+    , putWord8
+    , putWord16be
+    , putWord32be
+
+    , putBinary
+    , liftToBinary
+    )
+    where
+
+import qualified Data.Binary as Binary
+import qualified Data.Binary.Put as Bin
+import           Data.Bits
+import qualified Data.ByteString as Strict
+import qualified Data.ByteString.Lazy as Lazy
+import           Data.Int
+import           Data.Word
+
+
+data PackingState s
+    = PackingState {
+        stResult     :: !Lazy.ByteString
+      , stBytesWrote :: !Int64
+      , stUserState  :: s
+      }
+
+newtype Packer s a = P { unP :: PackingState s -> (a, PackingState s) }
+
+instance Monad (Packer s) where
+    return a = P (\ s -> (a, s))
+    m >>= k  = P (\ s -> let (a, s') = unP m s
+                         in
+                           unP (k a) s')
+    fail err = do bytes <- get stBytesWrote
+                  P (error (err
+                            ++ ". Failed packing at byte position "
+                            ++ show bytes))
+
+get :: (PackingState s -> a) -> Packer s a
+get f = P (\ s -> (f s, s))
+
+set :: (PackingState s -> PackingState s) -> Packer s ()
+set f = P (\ s -> ((), f s))
+
+mkState :: Lazy.ByteString -> Int64 -> s -> PackingState s
+mkState xs n s
+    = PackingState {
+        stResult     = xs
+      , stBytesWrote = n
+      , stUserState  = s
+      }
+
+pack' :: Packer s a -> s -> (Lazy.ByteString, s, a)
+pack' m s
+    = let (a, s') = unP m (mkState Lazy.empty 0 s)
+      in
+        (stResult s', stUserState s', a)
+
+pack :: Packer s a -> s -> Lazy.ByteString
+pack = (fst' .) . pack'
+    where
+      fst' (xs, _, _) = xs
+
+getState :: Packer s s
+getState = get stUserState
+
+setState :: s -> Packer s ()
+setState = modifyState . const
+
+modifyState :: (s -> s) -> Packer s ()
+modifyState f
+    = set $ \ st -> st { stUserState = f (stUserState st) }
+
+bytesWrote :: Integral i => Packer s i
+bytesWrote = get stBytesWrote >>= return . fromIntegral
+
+withOffset :: Int64 -> Packer s a -> Packer s a
+withOffset n m
+    = P $ \ s -> let (taken, dropped) = Lazy.splitAt n (stResult s)
+                     padded           = Lazy.take n (taken `Lazy.append` Lazy.repeat 0)
+                     tempState        = s {
+                                          stResult     = padded
+                                        , stBytesWrote = stBytesWrote s - Lazy.length dropped
+                                        }
+                     (a, tempState')  = unP m tempState
+                     newState         = tempState {
+                                          stResult     = replaceHead (stResult s) (stResult tempState')
+                                        , stBytesWrote = max (stBytesWrote s) (stBytesWrote tempState')
+                                        }
+                 in
+                   (a, newState)
+      where
+        replaceHead :: Lazy.ByteString -> Lazy.ByteString -> Lazy.ByteString
+        replaceHead world newHead
+            = let rest = Lazy.drop (Lazy.length newHead) world
+              in
+                newHead `Lazy.append` rest
+
+
+putByteString :: Strict.ByteString -> Packer s ()
+putByteString = putLazyByteString . Lazy.fromChunks . (:[])
+
+putLazyByteString :: Lazy.ByteString -> Packer s ()
+putLazyByteString xs
+    = set $ \ st -> st {
+                      stResult     = stResult st `Lazy.append` xs
+                    , stBytesWrote = stBytesWrote st + Lazy.length xs
+                    }
+
+putWord8 :: Word8 -> Packer s ()
+putWord8 w
+    = set $ \ st -> st {
+                      stResult     = stResult st `Lazy.snoc` w
+                    , stBytesWrote = stBytesWrote st + 1
+                    }
+
+putWord16be :: Word16 -> Packer s ()
+putWord16be w
+    = do putWord8 $ fromIntegral $ (w `shiftR`  8) .&. 0xFF
+         putWord8 $ fromIntegral $  w              .&. 0xFF
+
+putWord32be :: Word32 -> Packer s ()
+putWord32be w
+    = do putWord8 $ fromIntegral $ (w `shiftR` 24) .&. 0xFF
+         putWord8 $ fromIntegral $ (w `shiftR` 16) .&. 0xFF
+         putWord8 $ fromIntegral $ (w `shiftR`  8) .&. 0xFF
+         putWord8 $ fromIntegral $  w              .&. 0xFF
+
+
+putBinary :: Binary.Binary a => a -> Packer s ()
+putBinary = putLazyByteString . Binary.encode
+
+
+liftToBinary :: s -> Packer s a -> Bin.PutM a
+liftToBinary s m
+    = do let (a, s') = unP m (mkState Lazy.empty 0 s)
+
+         Bin.putLazyByteString (stResult s')
+         return a
index db349468262015599ccaf996bd98ceda693ff0c1..36106501e2f432a02f427b4830b8ff76a983099f 100644 (file)
@@ -26,9 +26,9 @@ module Network.DNS.Unpacker
 
 import qualified Data.Binary as Binary
 import qualified Data.Binary.Get as Bin
 
 import qualified Data.Binary as Binary
 import qualified Data.Binary.Get as Bin
+import           Data.Bits
 import qualified Data.ByteString as Strict
 import qualified Data.ByteString.Lazy as Lazy
 import qualified Data.ByteString as Strict
 import qualified Data.ByteString.Lazy as Lazy
-import           Data.Bits
 import           Data.Int
 import           Data.Word
 
 import           Data.Int
 import           Data.Word
 
index 3f609129c303318b7b38eab2481378c51a3087a0..b62c8714a7d20dc838cd9eeaff9a50bbfc79fadd 100644 (file)
--- a/dns.cabal
+++ b/dns.cabal
@@ -22,8 +22,7 @@ Library
 
     Exposed-Modules:
         Network.DNS.Message
 
     Exposed-Modules:
         Network.DNS.Message
-
-    Other-Modules:
+        Network.DNS.Packer
         Network.DNS.Unpacker
 
     Extensions:
         Network.DNS.Unpacker
 
     Extensions: