]> gitweb @ CieloNegro.org - haskell-dns.git/blobdiff - Network/DNS/Message.hs
Response parsing
[haskell-dns.git] / Network / DNS / Message.hs
index e6aaaa5e2af26fd3671fc812844148f562adb827..7bedacf5a0922b1816280a8ae4162f9aaf3ff698 100644 (file)
@@ -10,15 +10,26 @@ module Network.DNS.Message
     , DomainName
     , DomainLabel
     , TTL
-    , SomeRR(..)
-    , RecordType(..)
+    , RecordType
     , RecordClass(..)
 
+    , SomeQT
+    , SomeRR
+    , SomeRT
+
+    , A(..)
+    , NS(..)
     , CNAME(..)
     , HINFO(..)
+
+    , mkDomainName
+    , wrapQueryType
+    , wrapRecordType
+    , wrapRecord
     )
     where
 
+import           Control.Exception
 import           Control.Monad
 import           Data.Binary
 import           Data.Binary.BitPut as BP
@@ -26,11 +37,13 @@ import           Data.Binary.Get as G
 import           Data.Binary.Put as P
 import           Data.Binary.Strict.BitGet as BG
 import qualified Data.ByteString as BS
+import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
 import qualified Data.ByteString.Lazy as LBS
 import           Data.Typeable
 import qualified Data.IntMap as IM
 import           Data.IntMap (IntMap)
 import           Data.Word
+import           Network.Socket
 
 
 replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a)
@@ -50,6 +63,7 @@ data Message
       , msgAuthorities :: ![SomeRR]
       , msgAdditionals :: ![SomeRR]
       }
+    deriving (Show, Eq)
 
 data Header
     = Header {
@@ -68,6 +82,7 @@ data Header
       -- + NSCOUNT
       -- + ARCOUNT
       }
+    deriving (Show, Eq)
 
 type MessageID = Word16
 
@@ -94,21 +109,23 @@ data ResponseCode
 data Question
     = Question {
         qName  :: !DomainName
-      , qType  :: !RecordType
+      , qType  :: !SomeQT
       , qClass :: !RecordClass
       }
     deriving (Show, Eq)
 
+type SomeQT = SomeRT
+
 putQ :: Question -> Put
 putQ q
     = do putDomainName $ qName q
-         put $ qType  q
+         putSomeRT $ qType q
          put $ qClass q
 
 getQ :: DecompTable -> Get (Question, DecompTable)
 getQ dt
     = do (nm, dt') <- getDomainName dt
-         ty        <- get
+         ty        <- getSomeRT
          cl        <- get
          let q = Question {
                    qName  = nm
@@ -117,8 +134,33 @@ getQ dt
                  }
          return (q, dt')
 
-type DomainName  = [DomainLabel]
-type DomainLabel = BS.ByteString
+newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
+type DomainLabel    = BS.ByteString
+
+nameToLabels :: DomainName -> [DomainLabel]
+nameToLabels (DN ls) = ls
+
+labelsToName :: [DomainLabel] -> DomainName
+labelsToName = DN
+
+rootName :: DomainName
+rootName = DN [BS.empty]
+
+consLabel :: DomainLabel -> DomainName -> DomainName
+consLabel x (DN ys) = DN (x:ys)
+
+mkDomainName :: String -> DomainName
+mkDomainName = labelsToName . mkLabels [] . notEmpty
+    where
+      notEmpty :: String -> String
+      notEmpty xs = assert (not $ null xs) xs
+
+      mkLabels :: [DomainLabel] -> String -> [DomainLabel]
+      mkLabels soFar [] = reverse (C8.empty : soFar)
+      mkLabels soFar xs = case break (== '.') xs of
+                            (l, ('.':rest))
+                                -> mkLabels (C8.pack l : soFar) rest
+                            _   -> error ("Illegal domain name: " ++ xs)
 
 data RecordClass
     = IN
@@ -128,122 +170,121 @@ data RecordClass
     | AnyClass -- Only for queries
     deriving (Show, Eq)
 
-class (Typeable rr, Show rr, Eq rr) => ResourceRecord rr where
-    rrName    :: rr -> DomainName
-    rrType    :: rr -> RecordType
-    rrClass   :: rr -> RecordClass
-    rrTTL     :: rr -> TTL
-    rrPutData :: rr -> Put
-    rrGetData :: DecompTable -> DomainName -> RecordClass -> TTL -> Get (rr, DecompTable)
-    toRR      :: rr -> SomeRR
-    fromRR    :: SomeRR -> Maybe rr
 
-    toRR   rr           = SomeRR rr
-    fromRR (SomeRR rr') = cast rr'
+data RecordType rt dt => ResourceRecord rt dt
+    = ResourceRecord {
+        rrName  :: !DomainName
+      , rrType  :: !rt
+      , rrClass :: !RecordClass
+      , rrTTL   :: !TTL
+      , rrData  :: !dt
+      }
+    deriving (Show, Eq, Typeable)
+
 
-putRR :: ResourceRecord rr => rr -> Put
+putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
 putRR rr = do putDomainName $ rrName rr
-              put $ rrType  rr
+              putRecordType $ rrType  rr
               put $ rrClass rr
               putWord32be $ rrTTL rr
 
-              let dat = runPut $ rrPutData rr
+              let dat = runPut $
+                        putRecordData (undefined :: rt) (rrData rr)
               putWord16be $ fromIntegral $ LBS.length dat
               putLazyByteString dat
 
-getRR :: DecompTable -> Get (SomeRR, DecompTable)
-getRR dt
-    = do (nm, dt') <- getDomainName dt
-         ty        <- get
-         cl        <- get
-         ttl       <- G.getWord32be
-         case ty of
-           CNAME   -> do (rr, dt'') <- rrGetData dt' nm cl ttl
-                         return (toRR (rr :: CNAME), dt'')
-           HINFO   -> do (rr, dt'') <- rrGetData dt' nm cl ttl
-                         return (toRR (rr :: HINFO), dt'')
-           AXFR    -> onlyForQuestions "AXFR"
-           MAILB   -> onlyForQuestions "MAILB"
-           MAILA   -> onlyForQuestions "MAILA"
-           AnyType -> onlyForQuestions "ANY"
-    where
-      onlyForQuestions name
-          = fail (name ++ " is only for questions, not an actual resource record.")
-
-data SomeRR = forall rr. ResourceRecord rr => SomeRR rr
-              deriving Typeable
-instance ResourceRecord SomeRR where
-    rrName    (SomeRR rr) = rrName  rr
-    rrType    (SomeRR rr) = rrType  rr
-    rrClass   (SomeRR rr) = rrClass rr
-    rrTTL     (SomeRR rr) = rrTTL   rr
-    rrPutData (SomeRR rr) = rrPutData rr
-    rrGetData _ _ _ _     = fail "SomeRR can't directly be constructed."
-    toRR   = id
-    fromRR = Just
-instance Eq SomeRR where
-    (SomeRR a) == (SomeRR b) = Just a == cast b
+
+getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable)
+getRR dt rt
+    = do (nm, dt1)  <- getDomainName dt
+         G.skip 2   -- record type
+         cl         <- get
+         ttl        <- G.getWord32be
+         G.skip 2   -- data length
+         (dat, dt2) <- getRecordData (undefined :: rt) dt1
+
+         let rr = ResourceRecord {
+                    rrName  = nm
+                  , rrType  = rt
+                  , rrClass = cl
+                  , rrTTL   = ttl
+                  , rrData  = dat
+                  }
+         return (rr, dt2)
+
+
+data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
+
 instance Show SomeRR where
     show (SomeRR rr) = show rr
 
-type DecompTable = IntMap BS.ByteString
-type TTL = Word32
+instance Eq SomeRR where
+    (SomeRR a) == (SomeRR b) = Just a == cast b
 
-data CNAME = CNAME' !DomainName !RecordClass !TTL !DomainName
-             deriving (Eq, Show, Typeable)
-instance ResourceRecord CNAME where
-    rrName    (CNAME' n _ _ _) = n
-    rrType    _                = CNAME
-    rrClass   (CNAME' _ c _ _) = c
-    rrTTL     (CNAME' _ _ t _) = t
-    rrGetData dt n c t         = do (d, dt') <- getDomainName dt
-                                    return (CNAME' n c t d, dt')
-    rrPutData (CNAME' _ _ _ d) = putDomainName d
-
-data HINFO = HINFO' !DomainName !RecordClass !TTL !BS.ByteString !BS.ByteString
-             deriving (Eq, Show, Typeable)
-instance ResourceRecord HINFO where
-    rrName    (HINFO' n _ _ _ _) = n
-    rrType    _                  = HINFO
-    rrClass   (HINFO' _ c _ _ _) = c
-    rrTTL     (HINFO' _ _ t _ _) = t
-    rrGetData dt n c t           = do cpu <- getCharString
-                                      os  <- getCharString
-                                      return (HINFO' n c t cpu os, dt)
-    rrPutData (HINFO' _ _ _ c o) = do putCharString c
-                                      putCharString o
+
+putSomeRR :: SomeRR -> Put
+putSomeRR (SomeRR rr) = putRR rr
+
+getSomeRR :: DecompTable -> Get (SomeRR, DecompTable)
+getSomeRR dt
+    = do srt <- lookAhead $
+                do getDomainName dt -- skip
+                   getSomeRT
+         case srt of
+           SomeRT rt -> getRR dt rt >>= \ (rr, dt') -> return (SomeRR rr, dt')
+
+
+type DecompTable = IntMap DomainName
+type TTL = Word32
 
 getDomainName :: DecompTable -> Get (DomainName, DecompTable)
-getDomainName = flip worker []
+getDomainName = worker
     where
-      worker :: DecompTable -> [DomainLabel] -> Get ([DomainLabel], DecompTable)
-      worker dt soFar
-          = do (l, dt') <- getDomainLabel dt
-               case BS.null l of
-                 True  -> return (reverse (l : soFar), dt')
-                 False -> worker dt' (l : soFar)
-
-getDomainLabel :: DecompTable -> Get (DomainLabel, DecompTable)
-getDomainLabel dt
-    = do header <- getByteString 1
-         let Right h
-                 = runBitGet header $
-                   do a <- getBit
-                      b <- getBit
-                      n <- liftM fromIntegral (getAsWord8 6)
-                      case (a, b) of
-                        ( True,  True) -> return $ Offset n
-                        (False, False) -> return $ Length n
-                        _              -> fail "Illegal label header"
-         case h of
-           Offset n
-               -> do let Just l = IM.lookup n dt
-                     return (l, dt)
-           Length n
-               -> do offset <- liftM fromIntegral bytesRead
-                     label  <- getByteString n
-                     let dt' = IM.insert offset label dt
-                     return (label, dt')
+      worker :: DecompTable -> Get (DomainName, DecompTable)
+      worker dt
+          = do offset <- liftM fromIntegral bytesRead
+               hdr    <- getLabelHeader
+               case hdr of
+                 Offset n
+                     -> case IM.lookup n dt of
+                          Just name
+                              -> return (name, dt)
+                          Nothing
+                              -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
+                 Length 0
+                     -> return (rootName, dt)
+                 Length n
+                     -> do label       <- getByteString n
+                           (rest, dt') <- worker dt
+                           let name = consLabel label rest
+                               dt'' = IM.insert offset name dt'
+                           return (name, dt'')
+
+      getLabelHeader :: Get LabelHeader
+      getLabelHeader
+          = do header <- lookAhead $ getByteString 1
+               let Right h
+                       = runBitGet header $
+                         do a <- getBit
+                            b <- getBit
+                            n <- liftM fromIntegral (getAsWord8 6)
+                            case (a, b) of
+                              ( True,  True) -> return $ Offset n
+                              (False, False) -> return $ Length n
+                              _              -> fail "Illegal label header"
+               case h of
+                 Offset _
+                     -> do header' <- getByteString 2 -- Pointers have 2 octets.
+                           let Right h'
+                                   = runBitGet header' $
+                                     do BG.skip 2
+                                        n <- liftM fromIntegral (getAsWord16 14)
+                                        return $ Offset n
+                           return h'
+                 len@(Length _)
+                     -> do G.skip 1
+                           return len
+
 
 getCharString :: Get BS.ByteString
 getCharString = do len <- G.getWord8
@@ -257,13 +298,71 @@ data LabelHeader
     | Length !Int
 
 putDomainName :: DomainName -> Put
-putDomainName = mapM_ putDomainLabel
+putDomainName = mapM_ putDomainLabel . nameToLabels
 
 putDomainLabel :: DomainLabel -> Put
 putDomainLabel l
     = do putWord8 $ fromIntegral $ BS.length l
          P.putByteString l
 
+class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
+    rtToInt       :: rt -> Int
+    putRecordType :: rt -> Put
+    putRecordData :: rt -> dt -> Put
+    getRecordData :: rt -> DecompTable -> Get (dt, DecompTable)
+
+    putRecordType = putWord16be . fromIntegral . rtToInt
+
+data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
+
+instance Show SomeRT where
+    show (SomeRT rt) = show rt
+
+instance Eq SomeRT where
+    (SomeRT a) == (SomeRT b) = Just a == cast b
+
+putSomeRT :: SomeRT -> Put
+putSomeRT (SomeRT rt) = putRecordType rt
+
+getSomeRT :: Get SomeRT
+getSomeRT = do n <- liftM fromIntegral G.getWord16be
+               case IM.lookup n defaultRTTable of
+                 Nothing
+                     -> fail ("Unknown resource record type: " ++ show n)
+                 Just srt
+                     -> return srt
+
+data A = A deriving (Show, Eq, Typeable)
+instance RecordType A HostAddress where
+    rtToInt       _ = 1
+    putRecordData _ = putWord32be
+    getRecordData _ = \ dt ->
+                      do addr <- G.getWord32be
+                         return (addr, dt)
+
+data NS = NS deriving (Show, Eq, Typeable)
+instance RecordType NS DomainName where
+    rtToInt       _ = 2
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data CNAME = CNAME deriving (Show, Eq, Typeable)
+instance RecordType CNAME DomainName where
+    rtToInt       _ = 5
+    putRecordData _ = putDomainName
+    getRecordData _ = getDomainName
+
+data HINFO = HINFO deriving (Show, Eq, Typeable)
+instance RecordType HINFO (BS.ByteString, BS.ByteString) where
+    rtToInt       _           = 13
+    putRecordData _ (cpu, os) = do putCharString cpu
+                                   putCharString os
+    getRecordData _ dt        = do cpu <- getCharString
+                                   os  <- getCharString
+                                   return ((cpu, os), dt)
+
+
+{-
 data RecordType
     = A
     | NS
@@ -288,6 +387,7 @@ data RecordType
     | MAILA -- Obsolete
     | AnyType
     deriving (Show, Eq)
+-}
 
 instance Binary Message where
     put m = do put $ msgHeader m
@@ -296,19 +396,19 @@ instance Binary Message where
                putWord16be $ fromIntegral $ length $ msgAuthorities m
                putWord16be $ fromIntegral $ length $ msgAdditionals m
                mapM_ putQ  $ msgQuestions m
-               mapM_ putRR $ msgAnswers m
-               mapM_ putRR $ msgAuthorities m
-               mapM_ putRR $ msgAdditionals m
+               mapM_ putSomeRR $ msgAnswers m
+               mapM_ putSomeRR $ msgAuthorities m
+               mapM_ putSomeRR $ msgAdditionals m
 
     get = do hdr  <- get
              nQ   <- liftM fromIntegral G.getWord16be
              nAns <- liftM fromIntegral G.getWord16be
              nAth <- liftM fromIntegral G.getWord16be
              nAdd <- liftM fromIntegral G.getWord16be
-             (qs  , dt1) <- replicateM' nQ   getQ  IM.empty
-             (anss, dt2) <- replicateM' nAns getRR dt1
-             (aths, dt3) <- replicateM' nAth getRR dt2
-             (adds, _  ) <- replicateM' nAdd getRR dt3
+             (qs  , dt1) <- replicateM' nQ   getQ IM.empty
+             (anss, dt2) <- replicateM' nAns getSomeRR dt1
+             (aths, dt3) <- replicateM' nAth getSomeRR dt2
+             (adds, _  ) <- replicateM' nAdd getSomeRR dt3
              return Message {
                           msgHeader      = hdr
                         , msgQuestions   = qs
@@ -389,6 +489,7 @@ instance Enum ResponseCode where
     toEnum 5 = Refused
     toEnum _ = undefined
 
+{-
 instance Enum RecordType where
     fromEnum A       = 1
     fromEnum NS      = 2
@@ -432,6 +533,7 @@ instance Enum RecordType where
     toEnum 254 = MAILA
     toEnum 255 = AnyType
     toEnum _  = undefined
+-}
 
 instance Enum RecordClass where
     fromEnum IN       = 1
@@ -447,10 +549,28 @@ instance Enum RecordClass where
     toEnum 255 = AnyClass
     toEnum _   = undefined
 
-instance Binary RecordType where
+instance Binary RecordClass where
     get = liftM (toEnum . fromIntegral) G.getWord16be
     put = putWord16be . fromIntegral . fromEnum
 
-instance Binary RecordClass where
-    get = liftM (toEnum . fromIntegral) G.getWord16be
-    put = putWord16be . fromIntegral . fromEnum
\ No newline at end of file
+
+defaultRTTable :: IntMap SomeRT
+defaultRTTable = IM.fromList $ map toPair $
+                 [ wrapRecordType A
+                 , wrapRecordType NS
+                 , wrapRecordType CNAME
+                 , wrapRecordType HINFO
+                 ]
+    where
+      toPair :: SomeRT -> (Int, SomeRT)
+      toPair srt@(SomeRT rt) = (rtToInt rt, srt)
+
+
+wrapQueryType :: RecordType rt dt => rt -> SomeQT
+wrapQueryType = SomeRT
+
+wrapRecordType :: RecordType rt dt => rt -> SomeRT
+wrapRecordType = SomeRT
+
+wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
+wrapRecord = SomeRR
\ No newline at end of file