]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Message.hs
Introduce Packer monad so that we can compress binary packets.
[haskell-dns.git] / Network / DNS / Message.hs
index 6144d13766e037d551838526f18804bf0b896451..5c537956bd1657284a8c3b15db3dd9e9941cce62 100644 (file)
@@ -13,14 +13,19 @@ module Network.DNS.Message
     , RecordType
     , RecordClass(..)
 
-    , SomeRR(..)
-    , SomeRT(..)
+    , SomeQT
+    , SomeRR
+    , SomeRT
 
+    , A(..)
+    , NS(..)
     , CNAME(..)
     , HINFO(..)
 
-    , mkQueryType
     , mkDomainName
+    , wrapQueryType
+    , wrapRecordType
+    , wrapRecord
     )
     where
 
@@ -29,24 +34,19 @@ import           Control.Monad
 import           Data.Binary
 import           Data.Binary.BitPut as BP
 import           Data.Binary.Get as G
-import           Data.Binary.Put as P
+import           Data.Binary.Put as P'
 import           Data.Binary.Strict.BitGet as BG
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
-import qualified Data.ByteString.Lazy as LBS
 import           Data.Typeable
 import qualified Data.IntMap as IM
 import           Data.IntMap (IntMap)
+import qualified Data.Map as M
+import           Data.Map (Map)
 import           Data.Word
-
-
-replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a)
-replicateM' = worker []
-    where
-      worker :: Monad m => [b] -> Int -> (a -> m (b, a)) -> a -> m ([b], a)
-      worker soFar 0 _ a = return (reverse soFar, a)
-      worker soFar n f a = do (b, a') <- f a
-                              worker (b : soFar) (n - 1) f a'
+import           Network.DNS.Packer as P
+import           Network.DNS.Unpacker as U
+import           Network.Socket
 
 
 data Message
@@ -110,38 +110,42 @@ data Question
 
 type SomeQT = SomeRT
 
-mkQueryType :: RecordType rt dt => rt -> SomeQT
-mkQueryType = SomeRT
-
-putQ :: Question -> Put
+putQ :: Question -> Packer CompTable ()
 putQ q
     = do putDomainName $ qName q
          putSomeRT $ qType q
-         put $ qClass q
-
-getQ :: DecompTable -> Get (Question, DecompTable)
-getQ dt
-    = do (nm, dt') <- getDomainName dt
-         ty        <- getSomeRT
-         cl        <- get
-         let q = Question {
-                   qName  = nm
-                 , qType  = ty
-                 , qClass = cl
-                 }
-         return (q, dt')
-
-newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
+         putBinary $ qClass q
+
+getQ :: Unpacker DecompTable Question
+getQ = do nm <- getDomainName
+          ty <- getSomeRT
+          cl <- getBinary
+          return Question {
+                       qName  = nm
+                     , qType  = ty
+                     , qClass = cl
+                     }
+
+
+newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable)
 type DomainLabel    = BS.ByteString
 
-nameToLabels :: DomainName -> [DomainLabel]
-nameToLabels (DN ls) = ls
+rootName :: DomainName
+rootName = DN [BS.empty]
+
+isRootName :: DomainName -> Bool
+isRootName (DN [_]) = True
+isRootName _        = False
 
-labelsToName :: [DomainLabel] -> DomainName
-labelsToName = DN
+consLabel :: DomainLabel -> DomainName -> DomainName
+consLabel x (DN ys) = DN (x:ys)
+
+unconsLabel :: DomainName -> (DomainLabel, DomainName)
+unconsLabel (DN (x:xs)) = (x, DN xs)
+unconsLabel x           = error ("Illegal use of unconsLabel: " ++ show x)
 
 mkDomainName :: String -> DomainName
-mkDomainName = labelsToName . mkLabels [] . notEmpty
+mkDomainName = DN . mkLabels [] . notEmpty
     where
       notEmpty :: String -> String
       notEmpty xs = assert (not $ null xs) xs
@@ -173,37 +177,6 @@ data RecordType rt dt => ResourceRecord rt dt
     deriving (Show, Eq, Typeable)
 
 
-putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
-putRR rr = do putDomainName $ rrName rr
-              putRecordType $ rrType  rr
-              put $ rrClass rr
-              putWord32be $ rrTTL rr
-
-              let dat = runPut $
-                        putRecordData (undefined :: rt) (rrData rr)
-              putWord16be $ fromIntegral $ LBS.length dat
-              putLazyByteString dat
-
-
-getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable)
-getRR dt rt
-    = do (nm, dt1)  <- getDomainName dt
-         G.skip 2   -- record type
-         cl         <- get
-         ttl        <- G.getWord32be
-         G.skip 2   -- data length
-         (dat, dt2) <- getRecordData (undefined :: rt) dt1
-
-         let rr = ResourceRecord {
-                    rrName  = nm
-                  , rrType  = rt
-                  , rrClass = cl
-                  , rrTTL   = ttl
-                  , rrData  = dat
-                  }
-         return (rr, dt2)
-
-
 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
 
 instance Show SomeRR where
@@ -213,79 +186,149 @@ instance Eq SomeRR where
     (SomeRR a) == (SomeRR b) = Just a == cast b
 
 
-putSomeRR :: SomeRR -> Put
-putSomeRR (SomeRR rr) = putRR rr
+putSomeRR :: SomeRR -> Packer CompTable ()
+putSomeRR (SomeRR rr) = putResourceRecord rr
 
-getSomeRR :: DecompTable -> Get (SomeRR, DecompTable)
-getSomeRR dt
-    = do srt <- lookAhead $
-                do getDomainName dt -- skip
-                   getSomeRT
-         case srt of
-           SomeRT rt -> getRR dt rt >>= \ (rr, dt') -> return (SomeRR rr, dt')
+getSomeRR :: Unpacker DecompTable SomeRR
+getSomeRR = do srt <- U.lookAhead $
+                      do getDomainName -- skip
+                         getSomeRT
+               case srt of
+                 SomeRT rt
+                     -> getResourceRecord rt >>= return . SomeRR
 
+type CompTable   = Map DomainName Int
+type DecompTable = IntMap DomainName
+type TTL         = Word32
 
-type DecompTable = IntMap BS.ByteString
-type TTL = Word32
-
-getDomainName :: DecompTable -> Get (DomainName, DecompTable)
-getDomainName = flip worker []
+getDomainName :: Unpacker DecompTable DomainName
+getDomainName = worker
     where
-      worker :: DecompTable -> [DomainLabel] -> Get (DomainName, DecompTable)
-      worker dt soFar
-          = do (l, dt') <- getDomainLabel dt
-               case BS.null l of
-                 True  -> return (labelsToName (reverse (l : soFar)), dt')
-                 False -> worker dt' (l : soFar)
-
-getDomainLabel :: DecompTable -> Get (DomainLabel, DecompTable)
-getDomainLabel dt
-    = do header <- getByteString 1
-         let Right h
-                 = runBitGet header $
-                   do a <- getBit
-                      b <- getBit
-                      n <- liftM fromIntegral (getAsWord8 6)
-                      case (a, b) of
-                        ( True,  True) -> return $ Offset n
-                        (False, False) -> return $ Length n
-                        _              -> fail "Illegal label header"
-         case h of
-           Offset n
-               -> do let Just l = IM.lookup n dt
-                     return (l, dt)
-           Length n
-               -> do offset <- liftM fromIntegral bytesRead
-                     label  <- getByteString n
-                     let dt' = IM.insert offset label dt
-                     return (label, dt')
-
-getCharString :: Get BS.ByteString
-getCharString = do len <- G.getWord8
-                   getByteString (fromIntegral len)
-
-putCharString :: BS.ByteString -> Put
-putCharString = putDomainLabel
+      worker :: Unpacker DecompTable DomainName
+      worker
+          = do offset <- U.bytesRead
+               hdr    <- getLabelHeader
+               case hdr of
+                 Offset n
+                     -> do dt <- U.getState
+                           case IM.lookup n dt of
+                             Just name
+                                 -> return name
+                             Nothing
+                                 -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
+                 Length 0
+                     -> return rootName
+                 Length n
+                     -> do label <- U.getByteString n
+                           rest  <- worker
+                           let name = consLabel label rest
+                           U.modifyState $ IM.insert offset name
+                           return name
+
+      getLabelHeader :: Unpacker s LabelHeader
+      getLabelHeader
+          = do header <- U.lookAhead $ U.getByteString 1
+               let Right h
+                       = runBitGet header $
+                         do a <- getBit
+                            b <- getBit
+                            n <- liftM fromIntegral (getAsWord8 6)
+                            case (a, b) of
+                              ( True,  True) -> return $ Offset n
+                              (False, False) -> return $ Length n
+                              _              -> fail "Illegal label header"
+               case h of
+                 Offset _
+                     -> do header' <- U.getByteString 2 -- Pointers have 2 octets.
+                           let Right h'
+                                   = runBitGet header' $
+                                     do BG.skip 2
+                                        n <- liftM fromIntegral (getAsWord16 14)
+                                        return $ Offset n
+                           return h'
+                 len@(Length _)
+                     -> do U.skip 1
+                           return len
+
+
+getCharString :: Unpacker s BS.ByteString
+getCharString = do len <- U.getWord8
+                   U.getByteString (fromIntegral len)
+
+putCharString :: BS.ByteString -> Packer s ()
+putCharString xs = do P.putWord8 $ fromIntegral $ BS.length xs
+                      P.putByteString xs
 
 data LabelHeader
     = Offset !Int
     | Length !Int
 
-putDomainName :: DomainName -> Put
-putDomainName = mapM_ putDomainLabel . nameToLabels
+putDomainName :: DomainName -> Packer CompTable ()
+putDomainName name
+    = do ct <- P.getState
+         case M.lookup name ct of
+           Just n
+               -> do let ptr = runBitPut $
+                               do putBit True
+                                  putBit True
+                                  putNBits 14 n
+                     P.putLazyByteString ptr
+           Nothing
+               -> do offset <- bytesWrote
+                     P.modifyState $ M.insert name offset
+
+                     let (label, rest) = unconsLabel name
+
+                     putCharString label
+
+                     if isRootName rest then
+                         P.putWord8 0
+                       else
+                         putDomainName rest
 
-putDomainLabel :: DomainLabel -> Put
-putDomainLabel l
-    = do putWord8 $ fromIntegral $ BS.length l
-         P.putByteString l
 
 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
     rtToInt       :: rt -> Int
-    putRecordType :: rt -> Put
-    putRecordData :: rt -> dt -> Put
-    getRecordData :: rt -> DecompTable -> Get (dt, DecompTable)
-
-    putRecordType = putWord16be . fromIntegral . rtToInt
+    putRecordData :: rt -> dt -> Packer CompTable ()
+    getRecordData :: rt -> Unpacker DecompTable dt
+
+    putRecordType :: rt -> Packer s ()
+    putRecordType = P.putWord16be . fromIntegral . rtToInt
+
+    putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
+    putResourceRecord rr
+        = do putDomainName $ rrName  rr
+             putRecordType $ rrType  rr
+             putBinary     $ rrClass rr
+             P.putWord32be $ rrTTL   rr
+
+             -- First, write a dummy data length.
+             offset <- bytesWrote
+             P.putWord16be 0
+
+             -- Second, write data.
+             putRecordData (rrType rr) (rrData rr)
+
+             -- Third, rewrite the dummy length to an actual value.
+             offset' <- bytesWrote
+             withOffset offset
+                 $ P.putWord16be (fromIntegral (offset' - offset - 2))
+
+    getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
+    getResourceRecord rt
+        = do name     <- getDomainName
+             U.skip 2 -- record type
+             cl       <- getBinary
+             ttl      <- U.getWord32be
+             U.skip 2 -- data length
+             dat      <- getRecordData rt
+             return $ ResourceRecord {
+                          rrName  = name
+                        , rrType  = rt
+                        , rrClass = cl
+                        , rrTTL   = ttl
+                        , rrData  = dat
+                        }
 
 data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
 
@@ -295,17 +338,29 @@ instance Show SomeRT where
 instance Eq SomeRT where
     (SomeRT a) == (SomeRT b) = Just a == cast b
 
-putSomeRT :: SomeRT -> Put
+putSomeRT :: SomeRT -> Packer s ()
 putSomeRT (SomeRT rt) = putRecordType rt
 
-getSomeRT :: Get SomeRT
-getSomeRT = do n <- liftM fromIntegral G.getWord16be
+getSomeRT :: Unpacker s SomeRT
+getSomeRT = do n <- liftM fromIntegral U.getWord16be
                case IM.lookup n defaultRTTable of
                  Nothing
                      -> fail ("Unknown resource record type: " ++ show n)
                  Just srt
                      -> return srt
 
+data A = A deriving (Show, Eq, Typeable)
+instance RecordType A HostAddress where
+    rtToInt       _ = 1
+    putRecordData _ = P.putWord32be
+    getRecordData _ = U.getWord32be
+
+data NS = NS deriving (Show, Eq, Typeable)
+instance RecordType NS DomainName where
+    rtToInt       _ = 2
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
 data CNAME = CNAME deriving (Show, Eq, Typeable)
 instance RecordType CNAME DomainName where
     rtToInt       _ = 5
@@ -317,9 +372,10 @@ instance RecordType HINFO (BS.ByteString, BS.ByteString) where
     rtToInt       _           = 13
     putRecordData _ (cpu, os) = do putCharString cpu
                                    putCharString os
-    getRecordData _ dt        = do cpu <- getCharString
+    getRecordData _           = do cpu <- getCharString
                                    os  <- getCharString
-                                   return ((cpu, os), dt)
+                                   return (cpu, os)
+
 
 {-
 data RecordType
@@ -349,25 +405,27 @@ data RecordType
 -}
 
 instance Binary Message where
-    put m = do put $ msgHeader m
-               putWord16be $ fromIntegral $ length $ msgQuestions m
-               putWord16be $ fromIntegral $ length $ msgAnswers m
-               putWord16be $ fromIntegral $ length $ msgAuthorities m
-               putWord16be $ fromIntegral $ length $ msgAdditionals m
-               mapM_ putQ  $ msgQuestions m
+    put m = P.liftToBinary M.empty $
+            do putBinary $ msgHeader m
+               P.putWord16be $ fromIntegral $ length $ msgQuestions m
+               P.putWord16be $ fromIntegral $ length $ msgAnswers m
+               P.putWord16be $ fromIntegral $ length $ msgAuthorities m
+               P.putWord16be $ fromIntegral $ length $ msgAdditionals m
+               mapM_ putQ      $ msgQuestions m
                mapM_ putSomeRR $ msgAnswers m
                mapM_ putSomeRR $ msgAuthorities m
                mapM_ putSomeRR $ msgAdditionals m
 
-    get = do hdr  <- get
-             nQ   <- liftM fromIntegral G.getWord16be
-             nAns <- liftM fromIntegral G.getWord16be
-             nAth <- liftM fromIntegral G.getWord16be
-             nAdd <- liftM fromIntegral G.getWord16be
-             (qs  , dt1) <- replicateM' nQ   getQ IM.empty
-             (anss, dt2) <- replicateM' nAns getSomeRR dt1
-             (aths, dt3) <- replicateM' nAth getSomeRR dt2
-             (adds, _  ) <- replicateM' nAdd getSomeRR dt3
+    get = U.liftToBinary IM.empty $
+          do hdr  <- getBinary
+             nQ   <- liftM fromIntegral U.getWord16be
+             nAns <- liftM fromIntegral U.getWord16be
+             nAth <- liftM fromIntegral U.getWord16be
+             nAdd <- liftM fromIntegral U.getWord16be
+             qs   <- replicateM nQ   getQ
+             anss <- replicateM nAns getSomeRR
+             aths <- replicateM nAth getSomeRR
+             adds <- replicateM nAdd getSomeRR
              return Message {
                           msgHeader      = hdr
                         , msgQuestions   = qs
@@ -377,8 +435,8 @@ instance Binary Message where
                         }
 
 instance Binary Header where
-    put h = do putWord16be $ hdMessageID h
-               putLazyByteString flags
+    put h = do P'.putWord16be $ hdMessageID h
+               P'.putLazyByteString flags
         where
           flags = runBitPut $
                   do putNBits 1 $ fromEnum $ hdMessageType h
@@ -391,7 +449,7 @@ instance Binary Header where
                      putNBits 4 $ fromEnum $ hdResponseCode h
 
     get = do mID   <- G.getWord16be
-             flags <- getByteString 2
+             flags <- G.getByteString 2
              let Right hd
                      = runBitGet flags $
                        do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
@@ -510,13 +568,26 @@ instance Enum RecordClass where
 
 instance Binary RecordClass where
     get = liftM (toEnum . fromIntegral) G.getWord16be
-    put = putWord16be . fromIntegral . fromEnum
+    put = P'.putWord16be . fromIntegral . fromEnum
 
 
 defaultRTTable :: IntMap SomeRT
 defaultRTTable = IM.fromList $ map toPair $
-                 [ SomeRT CNAME
+                 [ wrapRecordType A
+                 , wrapRecordType NS
+                 , wrapRecordType CNAME
+                 , wrapRecordType HINFO
                  ]
     where
       toPair :: SomeRT -> (Int, SomeRT)
       toPair srt@(SomeRT rt) = (rtToInt rt, srt)
+
+
+wrapQueryType :: RecordType rt dt => rt -> SomeQT
+wrapQueryType = SomeRT
+
+wrapRecordType :: RecordType rt dt => rt -> SomeRT
+wrapRecordType = SomeRT
+
+wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
+wrapRecord = SomeRR
\ No newline at end of file