]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Message.hs
More record types...
[haskell-dns.git] / Network / DNS / Message.hs
index be0b79a33800a32dcc769c7ff68129f322261bf4..ab1a15426430bb9b7d559561562a1e5fc4b35af8 100644 (file)
@@ -13,14 +13,27 @@ module Network.DNS.Message
     , RecordType
     , RecordClass(..)
 
+    , SOAFields(..)
+
     , SomeQT
     , SomeRR
     , SomeRT
 
     , A(..)
     , NS(..)
+    , MD(..)
+    , MF(..)
     , CNAME(..)
+    , SOA(..)
+    , MB(..)
+    , MG(..)
+    , MR(..)
+    , NULL(..)
+    , PTR(..)
     , HINFO(..)
+    , MINFO(..)
+    , MX(..)
+    , TXT(..)
 
     , mkDomainName
     , wrapQueryType
@@ -34,15 +47,17 @@ import           Control.Monad
 import           Data.Binary
 import           Data.Binary.BitPut as BP
 import           Data.Binary.Get as G
-import           Data.Binary.Put as P
+import           Data.Binary.Put as P'
 import           Data.Binary.Strict.BitGet as BG
 import qualified Data.ByteString as BS
 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
-import qualified Data.ByteString.Lazy as LBS
 import           Data.Typeable
 import qualified Data.IntMap as IM
 import           Data.IntMap (IntMap)
+import qualified Data.Map as M
+import           Data.Map (Map)
 import           Data.Word
+import           Network.DNS.Packer as P
 import           Network.DNS.Unpacker as U
 import           Network.Socket
 
@@ -108,11 +123,11 @@ data Question
 
 type SomeQT = SomeRT
 
-putQ :: Question -> Put
+putQ :: Question -> Packer CompTable ()
 putQ q
     = do putDomainName $ qName q
          putSomeRT $ qType q
-         put $ qClass q
+         putBinary $ qClass q
 
 getQ :: Unpacker DecompTable Question
 getQ = do nm <- getDomainName
@@ -125,23 +140,25 @@ getQ = do nm <- getDomainName
                      }
 
 
-newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
+newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable)
 type DomainLabel    = BS.ByteString
 
-nameToLabels :: DomainName -> [DomainLabel]
-nameToLabels (DN ls) = ls
-
-labelsToName :: [DomainLabel] -> DomainName
-labelsToName = DN
-
 rootName :: DomainName
 rootName = DN [BS.empty]
 
+isRootName :: DomainName -> Bool
+isRootName (DN [_]) = True
+isRootName _        = False
+
 consLabel :: DomainLabel -> DomainName -> DomainName
 consLabel x (DN ys) = DN (x:ys)
 
+unconsLabel :: DomainName -> (DomainLabel, DomainName)
+unconsLabel (DN (x:xs)) = (x, DN xs)
+unconsLabel x           = error ("Illegal use of unconsLabel: " ++ show x)
+
 mkDomainName :: String -> DomainName
-mkDomainName = labelsToName . mkLabels [] . notEmpty
+mkDomainName = DN . mkLabels [] . notEmpty
     where
       notEmpty :: String -> String
       notEmpty xs = assert (not $ null xs) xs
@@ -173,18 +190,6 @@ data RecordType rt dt => ResourceRecord rt dt
     deriving (Show, Eq, Typeable)
 
 
-putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
-putRR rr = do putDomainName $ rrName rr
-              putRecordType $ rrType  rr
-              put $ rrClass rr
-              putWord32be $ rrTTL rr
-
-              let dat = runPut $
-                        putRecordData (undefined :: rt) (rrData rr)
-              putWord16be $ fromIntegral $ LBS.length dat
-              putLazyByteString dat
-
-
 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
 
 instance Show SomeRR where
@@ -194,8 +199,8 @@ instance Eq SomeRR where
     (SomeRR a) == (SomeRR b) = Just a == cast b
 
 
-putSomeRR :: SomeRR -> Put
-putSomeRR (SomeRR rr) = putRR rr
+putSomeRR :: SomeRR -> Packer CompTable ()
+putSomeRR (SomeRR rr) = putResourceRecord rr
 
 getSomeRR :: Unpacker DecompTable SomeRR
 getSomeRR = do srt <- U.lookAhead $
@@ -205,8 +210,9 @@ getSomeRR = do srt <- U.lookAhead $
                  SomeRT rt
                      -> getResourceRecord rt >>= return . SomeRR
 
+type CompTable   = Map DomainName Int
 type DecompTable = IntMap DomainName
-type TTL = Word32
+type TTL         = Word32
 
 getDomainName :: Unpacker DecompTable DomainName
 getDomainName = worker
@@ -217,7 +223,7 @@ getDomainName = worker
                hdr    <- getLabelHeader
                case hdr of
                  Offset n
-                     -> do dt <- getState
+                     -> do dt <- U.getState
                            case IM.lookup n dt of
                              Just name
                                  -> return name
@@ -229,7 +235,7 @@ getDomainName = worker
                      -> do label <- U.getByteString n
                            rest  <- worker
                            let name = consLabel label rest
-                           modifyState $ IM.insert offset name
+                           U.modifyState $ IM.insert offset name
                            return name
 
       getLabelHeader :: Unpacker s LabelHeader
@@ -262,28 +268,81 @@ getCharString :: Unpacker s BS.ByteString
 getCharString = do len <- U.getWord8
                    U.getByteString (fromIntegral len)
 
-putCharString :: BS.ByteString -> Put
-putCharString = putDomainLabel
+putCharString :: BS.ByteString -> Packer s ()
+putCharString xs = do P.putWord8 $ fromIntegral $ BS.length xs
+                      P.putByteString xs
 
 data LabelHeader
     = Offset !Int
     | Length !Int
 
-putDomainName :: DomainName -> Put
-putDomainName = mapM_ putDomainLabel . nameToLabels
+putDomainName :: DomainName -> Packer CompTable ()
+putDomainName name
+    = do ct <- P.getState
+         case M.lookup name ct of
+           Just n
+               -> do let ptr = runBitPut $
+                               do putBit True
+                                  putBit True
+                                  putNBits 14 n
+                     P.putLazyByteString ptr
+           Nothing
+               -> do offset <- bytesWrote
+                     P.modifyState $ M.insert name offset
+
+                     let (label, rest) = unconsLabel name
+
+                     putCharString label
+
+                     if isRootName rest then
+                         P.putWord8 0
+                       else
+                         putDomainName rest
 
-putDomainLabel :: DomainLabel -> Put
-putDomainLabel l
-    = do putWord8 $ fromIntegral $ BS.length l
-         P.putByteString l
 
 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
     rtToInt       :: rt -> Int
-    putRecordData :: rt -> dt -> Put
+    putRecordData :: rt -> dt -> Packer CompTable ()
     getRecordData :: rt -> Unpacker DecompTable dt
 
-    putRecordType :: rt -> Put
-    putRecordType = putWord16be . fromIntegral . rtToInt
+    putRecordType :: rt -> Packer s ()
+    putRecordType = P.putWord16be . fromIntegral . rtToInt
+
+    putRecordDataWithLength :: rt -> dt -> Packer CompTable ()
+    putRecordDataWithLength rt dt
+        = do -- First, write a dummy data length.
+             offset <- bytesWrote
+             P.putWord16be 0
+
+             -- Second, write data.
+             putRecordData rt dt
+
+             -- Third, rewrite the dummy length to an actual value.
+             offset' <- bytesWrote
+             withOffset offset
+                 $ P.putWord16be (fromIntegral (offset' - offset - 2))
+
+    putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
+    putResourceRecord rr
+        = do putDomainName $ rrName  rr
+             putRecordType $ rrType  rr
+             putBinary     $ rrClass rr
+             P.putWord32be $ rrTTL   rr
+             putRecordDataWithLength (rrType rr) (rrData rr)
+
+    getRecordDataWithLength :: rt -> Unpacker DecompTable dt
+    getRecordDataWithLength rt
+        = do len     <- U.getWord16be
+             offset  <- U.bytesRead
+             dat     <- getRecordData rt
+             offset' <- U.bytesRead
+
+             let consumed = offset' - offset
+             when (consumed /= len)
+                      $ fail ("getRecordData " ++ show rt ++ " consumed " ++ show consumed ++
+                              " bytes but it had to consume " ++ show len ++ " bytes")
+
+             return dat
 
     getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
     getResourceRecord rt
@@ -291,8 +350,7 @@ class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType r
              U.skip 2 -- record type
              cl       <- getBinary
              ttl      <- U.getWord32be
-             U.skip 2 -- data length
-             dat      <- getRecordData rt
+             dat      <- getRecordDataWithLength rt
              return $ ResourceRecord {
                           rrName  = name
                         , rrType  = rt
@@ -309,7 +367,7 @@ instance Show SomeRT where
 instance Eq SomeRT where
     (SomeRT a) == (SomeRT b) = Just a == cast b
 
-putSomeRT :: SomeRT -> Put
+putSomeRT :: SomeRT -> Packer s ()
 putSomeRT (SomeRT rt) = putRecordType rt
 
 getSomeRT :: Unpacker s SomeRT
@@ -320,10 +378,22 @@ getSomeRT = do n <- liftM fromIntegral U.getWord16be
                  Just srt
                      -> return srt
 
+data SOAFields
+    = SOAFields {
+        soaMasterNameServer   :: !DomainName
+      , soaResponsibleMailbox :: !DomainName
+      , soaSerialNumber       :: !Word32
+      , soaRefreshInterval    :: !Word32
+      , soaRetryInterval      :: !Word32
+      , soaExpirationLimit    :: !Word32
+      , soaMinimumTTL         :: !Word32
+      }
+    deriving (Show, Eq, Typeable)
+
 data A = A deriving (Show, Eq, Typeable)
 instance RecordType A HostAddress where
     rtToInt       _ = 1
-    putRecordData _ = putWord32be
+    putRecordData _ = P.putWord32be
     getRecordData _ = U.getWord32be
 
 data NS = NS deriving (Show, Eq, Typeable)
@@ -332,21 +402,129 @@ instance RecordType NS DomainName where
     putRecordData _ = putDomainName
     getRecordData _ = getDomainName
 
+data MD = MD deriving (Show, Eq, Typeable)
+instance RecordType MD DomainName where
+    rtToInt       _ = 3
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data MF = MF deriving (Show, Eq, Typeable)
+instance RecordType MF DomainName where
+    rtToInt       _ = 4
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
 data CNAME = CNAME deriving (Show, Eq, Typeable)
 instance RecordType CNAME DomainName where
     rtToInt       _ = 5
     putRecordData _ = putDomainName
     getRecordData _ = getDomainName
 
+data SOA = SOA deriving (Show, Eq, Typeable)
+instance RecordType SOA SOAFields where
+    rtToInt       _ = 6
+    putRecordData _ = \ soa ->
+                      do putDomainName $ soaMasterNameServer soa
+                         putDomainName $ soaResponsibleMailbox soa
+                         P.putWord32be $ soaSerialNumber soa
+                         P.putWord32be $ soaRefreshInterval soa
+                         P.putWord32be $ soaRetryInterval soa
+                         P.putWord32be $ soaExpirationLimit soa
+                         P.putWord32be $ soaMinimumTTL soa
+    getRecordData _ = do master  <- getDomainName
+                         mail    <- getDomainName
+                         serial  <- U.getWord32be
+                         refresh <- U.getWord32be
+                         retry   <- U.getWord32be
+                         expire  <- U.getWord32be
+                         ttl     <- U.getWord32be
+                         return SOAFields {
+                                      soaMasterNameServer   = master
+                                    , soaResponsibleMailbox = mail
+                                    , soaSerialNumber       = serial
+                                    , soaRefreshInterval    = refresh
+                                    , soaRetryInterval      = retry
+                                    , soaExpirationLimit    = expire
+                                    , soaMinimumTTL         = ttl
+                                    }
+
+data MB = MB deriving (Show, Eq, Typeable)
+instance RecordType MB DomainName where
+    rtToInt       _ = 7
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data MG = MG deriving (Show, Eq, Typeable)
+instance RecordType MG DomainName where
+    rtToInt       _ = 8
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data MR = MR deriving (Show, Eq, Typeable)
+instance RecordType MR DomainName where
+    rtToInt       _ = 9
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data NULL = NULL deriving (Show, Eq, Typeable)
+instance RecordType NULL BS.ByteString where
+    rtToInt                 _ = 10
+    putRecordData         _ _ = fail "putRecordData NULL can't be defined"
+    getRecordData           _ = fail "getRecordData NULL can't be defined"
+    putRecordDataWithLength _ = \ dat ->
+                                do P.putWord16be $ fromIntegral $ BS.length dat
+                                   P.putByteString dat
+    getRecordDataWithLength _ = do len <- U.getWord16be
+                                   U.getByteString $ fromIntegral len
+
+data PTR = PTR deriving (Show, Eq, Typeable)
+instance RecordType PTR DomainName where
+    rtToInt       _ = 12
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
 data HINFO = HINFO deriving (Show, Eq, Typeable)
 instance RecordType HINFO (BS.ByteString, BS.ByteString) where
-    rtToInt       _           = 13
-    putRecordData _ (cpu, os) = do putCharString cpu
-                                   putCharString os
-    getRecordData _           = do cpu <- getCharString
-                                   os  <- getCharString
-                                   return (cpu, os)
-
+    rtToInt       _ = 13
+    putRecordData _ = \ (cpu, os) ->
+                      do putCharString cpu
+                         putCharString os
+    getRecordData _ = do cpu <- getCharString
+                         os  <- getCharString
+                         return (cpu, os)
+
+data MINFO = MINFO deriving (Show, Eq, Typeable)
+instance RecordType MINFO (DomainName, DomainName) where
+    rtToInt       _ = 14
+    putRecordData _ = \ (r, e) ->
+                      do putDomainName r
+                         putDomainName e
+    getRecordData _ = do r <- getDomainName
+                         e <- getDomainName
+                         return (r, e)
+
+data MX = MX deriving (Show, Eq, Typeable)
+instance RecordType MX (Word16, DomainName) where
+    rtToInt       _ = 15
+    putRecordData _ = \ (pref, exch) ->
+                      do P.putWord16be pref
+                         putDomainName exch
+    getRecordData _ = do pref <- U.getWord16be
+                         exch <- getDomainName
+                         return (pref, exch)
+
+data TXT = TXT deriving (Show, Eq, Typeable)
+instance RecordType TXT [BS.ByteString] where
+    rtToInt       _ = 16
+    putRecordData _ = mapM_ putCharString
+    getRecordData _ = fail "getRecordData TXT can't be defined"
+
+    getRecordDataWithLength _ = U.getWord16be >>= worker [] . fromIntegral
+        where
+          worker :: [BS.ByteString] -> Int -> Unpacker s [BS.ByteString]
+          worker soFar 0 = return (reverse soFar)
+          worker soFar n = do str <- getCharString
+                              worker (str : soFar) (0 `max` n - 1 - BS.length str)
 
 {-
 data RecordType
@@ -376,17 +554,18 @@ data RecordType
 -}
 
 instance Binary Message where
-    put m = do put $ msgHeader m
-               putWord16be $ fromIntegral $ length $ msgQuestions m
-               putWord16be $ fromIntegral $ length $ msgAnswers m
-               putWord16be $ fromIntegral $ length $ msgAuthorities m
-               putWord16be $ fromIntegral $ length $ msgAdditionals m
-               mapM_ putQ  $ msgQuestions m
+    put m = P.liftToBinary M.empty $
+            do putBinary $ msgHeader m
+               P.putWord16be $ fromIntegral $ length $ msgQuestions m
+               P.putWord16be $ fromIntegral $ length $ msgAnswers m
+               P.putWord16be $ fromIntegral $ length $ msgAuthorities m
+               P.putWord16be $ fromIntegral $ length $ msgAdditionals m
+               mapM_ putQ      $ msgQuestions m
                mapM_ putSomeRR $ msgAnswers m
                mapM_ putSomeRR $ msgAuthorities m
                mapM_ putSomeRR $ msgAdditionals m
 
-    get = liftToBinary IM.empty $
+    get = U.liftToBinary IM.empty $
           do hdr  <- getBinary
              nQ   <- liftM fromIntegral U.getWord16be
              nAns <- liftM fromIntegral U.getWord16be
@@ -405,8 +584,8 @@ instance Binary Message where
                         }
 
 instance Binary Header where
-    put h = do putWord16be $ hdMessageID h
-               putLazyByteString flags
+    put h = do P'.putWord16be $ hdMessageID h
+               P'.putLazyByteString flags
         where
           flags = runBitPut $
                   do putNBits 1 $ fromEnum $ hdMessageType h
@@ -478,48 +657,26 @@ instance Enum ResponseCode where
 
 {-
 instance Enum RecordType where
-    fromEnum A       = 1
-    fromEnum NS      = 2
-    fromEnum MD      = 3
-    fromEnum MF      = 4
-    fromEnum CNAME   = 5
-    fromEnum SOA     = 6
-    fromEnum MB      = 7
-    fromEnum MG      = 8
-    fromEnum MR      = 9
-    fromEnum NULL    = 10
+    fromEnum A       = 1 /
+    fromEnum NS      = 2 /
+    fromEnum MD      = 3 /
+    fromEnum MF      = 4 /
+    fromEnum CNAME   = 5 /
+    fromEnum SOA     = 6 /
+    fromEnum MB      = 7 /
+    fromEnum MG      = 8 /
+    fromEnum MR      = 9 /
+    fromEnum NULL    = 10 /
     fromEnum WKS     = 11
-    fromEnum PTR     = 12
-    fromEnum HINFO   = 13
-    fromEnum MINFO   = 14
-    fromEnum MX      = 15
-    fromEnum TXT     = 16
+    fromEnum PTR     = 12 /
+    fromEnum HINFO   = 13 /
+    fromEnum MINFO   = 14 /
+    fromEnum MX      = 15 /
+    fromEnum TXT     = 16 /
     fromEnum AXFR    = 252
     fromEnum MAILB   = 253
     fromEnum MAILA   = 254
     fromEnum AnyType = 255
-
-    toEnum 1  = A
-    toEnum 2  = NS
-    toEnum 3  = MD
-    toEnum 4  = MF
-    toEnum 5  = CNAME
-    toEnum 6  = SOA
-    toEnum 7  = MB
-    toEnum 8  = MG
-    toEnum 9  = MR
-    toEnum 10 = NULL
-    toEnum 11 = WKS
-    toEnum 12 = PTR
-    toEnum 13 = HINFO
-    toEnum 14 = MINFO
-    toEnum 15 = MX
-    toEnum 16 = TXT
-    toEnum 252 = AXFR
-    toEnum 253 = MAILB
-    toEnum 254 = MAILA
-    toEnum 255 = AnyType
-    toEnum _  = undefined
 -}
 
 instance Enum RecordClass where
@@ -538,7 +695,7 @@ instance Enum RecordClass where
 
 instance Binary RecordClass where
     get = liftM (toEnum . fromIntegral) G.getWord16be
-    put = putWord16be . fromIntegral . fromEnum
+    put = P'.putWord16be . fromIntegral . fromEnum
 
 
 defaultRTTable :: IntMap SomeRT