-- | Encoding and decoding of DNS messages (header, question section and
-- resource-record sections) with domain-name compression.
module Network.DNS.Message
    ( Message(..)
    , MessageID
    , MessageType(..)
    , Header(..)
    , Opcode(..)
    , ResponseCode(..)

    , Question(..)

    , ResourceRecord(..)
    , DomainName
    , DomainLabel
    , TTL
    , RecordType
    , RecordClass(..)

    , SOAFields(..)

    , SomeQT
    , SomeRR
    , SomeRT

    , A(..)
    , NS(..)
    , MD(..)
    , MF(..)
    , CNAME(..)
    , SOA(..)
    , MB(..)
    , MG(..)
    , MR(..)
    , NULL(..)
    , PTR(..)
    , HINFO(..)
    , MINFO(..)
    , MX(..)
    , TXT(..)

    , mkDomainName
    , wrapQueryType
    , wrapRecordType
    , wrapRecord
    )
    where

import           Control.Exception
import           Control.Monad
import           Data.Binary
import           Data.Binary.BitPut as BP
import           Data.Binary.Get as G
import           Data.Binary.Put as P'
import           Data.Binary.Strict.BitGet as BG
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
import           Data.Typeable
import qualified Data.IntMap as IM
import           Data.IntMap (IntMap)
import qualified Data.Map as M
import           Data.Map (Map)
import           Data.Word
import           Network.DNS.Packer as P
import           Network.DNS.Unpacker as U
import           Network.Socket


-- | A DNS message: a header followed by the question, answer, authority
-- and additional sections, in that order on the wire.
data Message
    = Message {
        msgHeader      :: !Header
      , msgQuestions   :: ![Question]
      , msgAnswers     :: ![SomeRR]
      , msgAuthorities :: ![SomeRR]
      , msgAdditionals :: ![SomeRR]
      }
    deriving (Show, Eq)

-- | The fixed message header, minus the four section counts.
data Header
    = Header {
        hdMessageID             :: !MessageID
      , hdMessageType           :: !MessageType
      , hdOpcode                :: !Opcode
      , hdIsAuthoritativeAnswer :: !Bool
      , hdIsTruncated           :: !Bool
      , hdIsRecursionDesired    :: !Bool
      , hdIsRecursionAvailable  :: !Bool
      , hdResponseCode          :: !ResponseCode

      -- These fields are suppressed in this data structure; the Binary
      -- Message instance derives them from the section list lengths:
      -- + QDCOUNT
      -- + ANCOUNT
      -- + NSCOUNT
      -- + ARCOUNT
      }
    deriving (Show, Eq)

type MessageID = Word16

data MessageType
    = Query
    | Response
    deriving (Show, Eq)

data Opcode
    = StandardQuery
    | InverseQuery
    | ServerStatusRequest
    deriving (Show, Eq)

data ResponseCode
    = NoError
    | FormatError
    | ServerFailure
    | NameError
    | NotImplemented
    | Refused
    deriving (Show, Eq)

-- | One entry of the question section.
data Question
    = Question {
        qName  :: !DomainName
      , qType  :: !SomeQT
      , qClass :: !RecordClass
      }
    deriving (Show, Eq)

-- | Query types share their representation with record types.
type SomeQT = SomeRT

-- | Write a question entry: name, type, class.
putQ :: Question -> Packer CompTable ()
putQ q
    = do putDomainName $ qName  q
         putSomeRT     $ qType  q
         putBinary     $ qClass q

-- | Read a question entry: name, type, class.
getQ :: Unpacker DecompTable Question
getQ
    = do nm <- getDomainName
         ty <- getSomeRT
         cl <- getBinary
         return Question {
                      qName  = nm
                    , qType  = ty
                    , qClass = cl
                    }


-- | A domain name as a list of labels. Every name built by this module
-- ends with an empty label, which stands for the root.
newtype DomainName = DN [DomainLabel]
    deriving (Eq, Show, Ord, Typeable)

type DomainLabel = BS.ByteString

rootName :: DomainName
rootName = DN [BS.empty]

-- | True for a name consisting of a single (the terminating) label.
isRootName :: DomainName -> Bool
isRootName (DN [_]) = True
isRootName _        = False

consLabel :: DomainLabel -> DomainName -> DomainName
consLabel x (DN ys) = DN (x:ys)

unconsLabel :: DomainName -> (DomainLabel, DomainName)
unconsLabel (DN (x:xs)) = (x, DN xs)
unconsLabel x           = error ("Illegal use of unconsLabel: " ++ show x)

-- | Build a 'DomainName' from its dotted, fully-qualified textual form.
--
-- The input must end with a dot (e.g. @\"example.com.\"@); @\".\"@ denotes
-- the root. Empty interior labels (@\"a..b.\"@) are rejected because they
-- would be encoded as a premature name terminator.
mkDomainName :: String -> DomainName
mkDomainName "."  = rootName
mkDomainName name = DN $ mkLabels [] $ notEmpty name
    where
      notEmpty :: String -> String
      notEmpty xs = assert (not $ null xs) xs

      mkLabels :: [DomainLabel] -> String -> [DomainLabel]
      mkLabels soFar [] = reverse (C8.empty : soFar)
      mkLabels soFar xs
          = case break (== '.') xs of
              (l@(_:_), ('.':rest))
                  -> mkLabels (C8.pack l : soFar) rest
              _   -> error ("Illegal domain name: " ++ xs)

data RecordClass
    = IN
    | CS       -- Obsolete
    | CH
    | HS
    | AnyClass -- Only for queries
    deriving (Show, Eq)

-- | A resource record whose type tag is @rt@ and whose RDATA is @dt@.
--
-- (The former datatype context @RecordType rt dt =>@ was dropped: it is a
-- deprecated extension and only added constraints; every function that
-- works with records already requires 'RecordType'.)
data ResourceRecord rt dt
    = ResourceRecord {
        rrName  :: !DomainName
      , rrType  :: !rt
      , rrClass :: !RecordClass
      , rrTTL   :: !TTL
      , rrData  :: !dt
      }
    deriving (Show, Eq, Typeable)

-- | A resource record of any registered type.
data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)

instance Show SomeRR where
    show (SomeRR rr) = show rr

-- Two wrapped records are equal only if they have the same concrete types
-- (checked with 'cast') and compare equal.
instance Eq SomeRR where
    (SomeRR a) == (SomeRR b) = Just a == cast b

putSomeRR :: SomeRR -> Packer CompTable ()
putSomeRR (SomeRR rr) = putResourceRecord rr

-- | Read a resource record of any type listed in 'defaultRTTable'.
-- The type field is peeked first (without consuming input) to pick the
-- concrete decoder.
getSomeRR :: Unpacker DecompTable SomeRR
getSomeRR
    = do srt <- U.lookAhead $
                do getDomainName -- skip
                   getSomeRT
         case srt of
           SomeRT rt -> liftM SomeRR (getResourceRecord rt)

-- | Name -> offset of its first occurrence in the message being written.
type CompTable   = Map DomainName Int
-- | Offset -> name that was decoded starting at that offset.
type DecompTable = IntMap DomainName
type TTL         = Word32

-- | Read a (possibly compressed) domain name. Every suffix decoded from
-- literal labels is recorded in the 'DecompTable' under the offset it
-- started at, so that later pointers can refer to it.
getDomainName :: Unpacker DecompTable DomainName
getDomainName = worker
    where
      worker :: Unpacker DecompTable DomainName
      worker
          = do offset <- U.bytesRead
               hdr    <- getLabelHeader
               case hdr of
                 Offset n
                     -> do dt <- U.getState
                           case IM.lookup n dt of
                             Just name
                                 -> return name
                             Nothing
                                 -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
                 Length 0
                     -> return rootName
                 Length n
                     -> do label <- U.getByteString n
                           rest  <- worker
                           let name = consLabel label rest
                           U.modifyState $ IM.insert offset name
                           return name

-- | Decode the header of the next label: either a 2-octet compression
-- pointer (top bits @11@) or a 1-octet length (top bits @00@). Any other
-- bit pattern makes the unpacker fail.
getLabelHeader :: Unpacker s LabelHeader
getLabelHeader
    = do header <- U.lookAhead $ U.getByteString 1
         case runBitGet header decodeFirstOctet of
           Left err
               -> fail err
           Right (Offset _)
               -> do -- Pointers have 2 octets.
                     header' <- U.getByteString 2
                     case runBitGet header' decodePointer of
                       Left err -> fail err
                       Right h' -> return h'
           Right len@(Length _)
               -> do U.skip 1
                     return len
    where
      decodeFirstOctet :: BitGet LabelHeader
      decodeFirstOctet
          = do a <- getBit
               b <- getBit
               n <- liftM fromIntegral (getAsWord8 6)
               case (a, b) of
                 ( True,  True) -> return $ Offset n
                 (False, False) -> return $ Length n
                 _              -> fail "Illegal label header"

      decodePointer :: BitGet LabelHeader
      decodePointer
          = do BG.skip 2
               n <- liftM fromIntegral (getAsWord16 14)
               return $ Offset n

-- | Read a length-prefixed string (one length octet, then that many octets).
getCharString :: Unpacker s BS.ByteString
getCharString
    = do len <- U.getWord8
         U.getByteString (fromIntegral len)

-- | Write a length-prefixed string. The length must fit in one octet;
-- longer strings make the packer fail instead of silently truncating
-- the length prefix.
putCharString :: BS.ByteString -> Packer s ()
putCharString xs
    | len > 255
        = fail ("putCharString: string of " ++ show len ++ " octets exceeds 255")
    | otherwise
        = do P.putWord8 $ fromIntegral len
             P.putByteString xs
    where
      len :: Int
      len = BS.length xs

data LabelHeader
    = Offset !Int
    | Length !Int

-- | Write a domain name with message compression: a name (or suffix)
-- already recorded in the 'CompTable' is emitted as a 2-octet pointer
-- (two 1 bits followed by a 14-bit offset).
--
-- Only offsets that fit in 14 bits are recorded, since larger ones cannot
-- be expressed as a pointer. Labels longer than 63 octets are rejected:
-- their length octet would have its top bits set and decode as a pointer
-- or an illegal header.
putDomainName :: DomainName -> Packer CompTable ()
putDomainName name
    | isRootName name
        = P.putWord8 0
    | otherwise
        = do ct <- P.getState
             case M.lookup name ct of
               Just n
                   -> do let ptr = runBitPut $
                                   do putBit True
                                      putBit True
                                      putNBits 14 n
                         P.putLazyByteString ptr
               Nothing
                   -> do offset <- bytesWrote
                         when (offset <= maxPointerOffset) $
                             P.modifyState $ M.insert name offset
                         let (label, rest) = unconsLabel name
                         when (BS.length label > maxLabelLength) $
                             fail ("putDomainName: label longer than 63 octets: " ++ show label)
                         putCharString label
                         putDomainName rest
    where
      maxPointerOffset :: Int
      maxPointerOffset = 0x3FFF

      maxLabelLength :: Int
      maxLabelLength = 63


-- | A record type tag @rt@ together with its RDATA representation @dt@.
class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
    rtToInt       :: rt -> Int
    putRecordData :: rt -> dt -> Packer CompTable ()
    getRecordData :: rt -> Unpacker DecompTable dt

    putRecordType :: rt -> Packer s ()
    putRecordType = P.putWord16be . fromIntegral . rtToInt

    -- | Write RDLENGTH followed by RDATA.
    putRecordDataWithLength :: rt -> dt -> Packer CompTable ()
    putRecordDataWithLength rt dt
        = do -- First, write a dummy data length.
             offset <- bytesWrote
             P.putWord16be 0

             -- Second, write data.
             putRecordData rt dt

             -- Third, rewrite the dummy length to an actual value.
             offset' <- bytesWrote
             withOffset offset $
                 P.putWord16be (fromIntegral (offset' - offset - 2))

    putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
    putResourceRecord rr
        = do putDomainName $ rrName  rr
             putRecordType $ rrType  rr
             putBinary     $ rrClass rr
             P.putWord32be $ rrTTL   rr
             putRecordDataWithLength (rrType rr) (rrData rr)

    -- | Read RDLENGTH and RDATA, failing if 'getRecordData' did not consume
    -- exactly RDLENGTH octets.
    getRecordDataWithLength :: rt -> Unpacker DecompTable dt
    getRecordDataWithLength rt
        = do len     <- U.getWord16be
             offset  <- U.bytesRead
             dat     <- getRecordData rt
             offset' <- U.bytesRead

             let consumed = offset' - offset
             -- len is a Word16; convert it before comparing with the offsets.
             when (consumed /= fromIntegral len)
                 $ fail ("getRecordData " ++ show rt ++ " consumed " ++ show consumed ++
                         " bytes but it had to consume " ++ show len ++ " bytes")

             return dat

    getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
    getResourceRecord rt
        = do name <- getDomainName
             U.skip 2 -- record type
             cl   <- getBinary
             ttl  <- U.getWord32be
             dat  <- getRecordDataWithLength rt
             return ResourceRecord {
                          rrName  = name
                        , rrType  = rt
                        , rrClass = cl
                        , rrTTL   = ttl
                        , rrData  = dat
                        }

-- | A record type tag of any registered type.
data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt

instance Show SomeRT where
    show (SomeRT rt) = show rt

instance Eq SomeRT where
    (SomeRT a) == (SomeRT b) = Just a == cast b

putSomeRT :: SomeRT -> Packer s ()
putSomeRT (SomeRT rt) = putRecordType rt

-- | Read a 16-bit type code and look it up in 'defaultRTTable'.
getSomeRT :: Unpacker s SomeRT
getSomeRT
    = do n <- liftM fromIntegral U.getWord16be
         case IM.lookup n defaultRTTable of
           Nothing
               -> fail ("Unknown resource record type: " ++ show n)
           Just srt
               -> return srt

-- | RDATA of an SOA record.
data SOAFields
    = SOAFields {
        soaMasterNameServer   :: !DomainName
      , soaResponsibleMailbox :: !DomainName
      , soaSerialNumber       :: !Word32
      , soaRefreshInterval    :: !Word32
      , soaRetryInterval      :: !Word32
      , soaExpirationLimit    :: !Word32
      , soaMinimumTTL         :: !Word32
      }
    deriving (Show, Eq, Typeable)

data A = A deriving (Show, Eq, Typeable)
instance RecordType A HostAddress where
    rtToInt       _ = 1
    putRecordData _ = P.putWord32be
    getRecordData _ = U.getWord32be

data NS = NS deriving (Show, Eq, Typeable)
instance RecordType NS DomainName where
    rtToInt       _ = 2
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data MD = MD deriving (Show, Eq, Typeable)
instance RecordType MD DomainName where
    rtToInt       _ = 3
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data MF = MF deriving (Show, Eq, Typeable)
instance RecordType MF DomainName where
    rtToInt       _ = 4
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data CNAME = CNAME deriving (Show, Eq, Typeable)
instance RecordType CNAME DomainName where
    rtToInt       _ = 5
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data SOA = SOA deriving (Show, Eq, Typeable)
instance RecordType SOA SOAFields where
    rtToInt       _ = 6
    putRecordData _ = \ soa ->
                      do putDomainName $ soaMasterNameServer soa
                         putDomainName $ soaResponsibleMailbox soa
                         P.putWord32be $ soaSerialNumber soa
                         P.putWord32be $ soaRefreshInterval soa
                         P.putWord32be $ soaRetryInterval soa
                         P.putWord32be $ soaExpirationLimit soa
                         P.putWord32be $ soaMinimumTTL soa
    getRecordData _ = do master  <- getDomainName
                         mail    <- getDomainName
                         serial  <- U.getWord32be
                         refresh <- U.getWord32be
                         retry   <- U.getWord32be
                         expire  <- U.getWord32be
                         ttl     <- U.getWord32be
                         return SOAFields {
                                      soaMasterNameServer   = master
                                    , soaResponsibleMailbox = mail
                                    , soaSerialNumber       = serial
                                    , soaRefreshInterval    = refresh
                                    , soaRetryInterval      = retry
                                    , soaExpirationLimit    = expire
                                    , soaMinimumTTL         = ttl
                                    }

data MB = MB deriving (Show, Eq, Typeable)
instance RecordType MB DomainName where
    rtToInt       _ = 7
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data MG = MG deriving (Show, Eq, Typeable)
instance RecordType MG DomainName where
    rtToInt       _ = 8
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data MR = MR deriving (Show, Eq, Typeable)
instance RecordType MR DomainName where
    rtToInt       _ = 9
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

-- | NULL RDATA is opaque, so its length handling is overridden directly:
-- the RDLENGTH octets are taken verbatim.
data NULL = NULL deriving (Show, Eq, Typeable)
instance RecordType NULL BS.ByteString where
    rtToInt       _   = 10
    putRecordData _ _ = fail "putRecordData NULL can't be defined"
    getRecordData _   = fail "getRecordData NULL can't be defined"
    putRecordDataWithLength _ = \ dat ->
                                do P.putWord16be $ fromIntegral $ BS.length dat
                                   P.putByteString dat
    getRecordDataWithLength _ = do len <- U.getWord16be
                                   U.getByteString $ fromIntegral len

data PTR = PTR deriving (Show, Eq, Typeable)
instance RecordType PTR DomainName where
    rtToInt       _ = 12
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data HINFO = HINFO deriving (Show, Eq, Typeable)
instance RecordType HINFO (BS.ByteString, BS.ByteString) where
    rtToInt       _ = 13
    putRecordData _ = \ (cpu, os) ->
                      do putCharString cpu
                         putCharString os
    getRecordData _ = do cpu <- getCharString
                         os  <- getCharString
                         return (cpu, os)

data MINFO = MINFO deriving (Show, Eq, Typeable)
instance RecordType MINFO (DomainName, DomainName) where
    rtToInt       _ = 14
    putRecordData _ = \ (r, e) ->
                      do putDomainName r
                         putDomainName e
    getRecordData _ = do r <- getDomainName
                         e <- getDomainName
                         return (r, e)

data MX = MX deriving (Show, Eq, Typeable)
instance RecordType MX (Word16, DomainName) where
    rtToInt       _ = 15
    putRecordData _ = \ (pref, exch) ->
                      do P.putWord16be pref
                         putDomainName exch
    getRecordData _ = do pref <- U.getWord16be
                         exch <- getDomainName
                         return (pref, exch)

-- | TXT RDATA is a sequence of character-strings filling RDLENGTH octets,
-- so decoding is driven by the remaining length.
data TXT = TXT deriving (Show, Eq, Typeable)
instance RecordType TXT [BS.ByteString] where
    rtToInt       _ = 16
    putRecordData _ = mapM_ putCharString
    getRecordData _ = fail "getRecordData TXT can't be defined"

    getRecordDataWithLength _ = U.getWord16be >>= worker [] . fromIntegral
        where
          worker :: [BS.ByteString] -> Int -> Unpacker s [BS.ByteString]
          worker soFar 0 = return (reverse soFar)
          worker soFar n
              = do str <- getCharString
                   -- One length octet plus the string body.
                   let remaining = n - 1 - BS.length str
                   -- A string running past RDATA would otherwise send the
                   -- count negative and never reach 0.
                   when (remaining < 0) $
                       fail ("getRecordDataWithLength TXT: character-string overruns RDATA by "
                             ++ show (negate remaining) ++ " octet(s)")
                   worker (str : soFar) remaining

{-
data RecordType
    = A
    | NS
    | MD
    | MF
    | CNAME
    | SOA
    | MB
    | MG
    | MR
    | NULL
    | WKS
    | PTR
    | HINFO
    | MINFO
    | MX
    | TXT

    -- Only for queries:
    | AXFR
    | MAILB -- Obsolete
    | MAILA -- Obsolete
    | AnyType
    deriving (Show, Eq)
-}

-- | Section counts are written from the list lengths and read back to
-- decide how many entries each section has.
instance Binary Message where
    put m = P.liftToBinary M.empty $
            do putBinary $ msgHeader m
               P.putWord16be $ fromIntegral $ length $ msgQuestions m
               P.putWord16be $ fromIntegral $ length $ msgAnswers m
               P.putWord16be $ fromIntegral $ length $ msgAuthorities m
               P.putWord16be $ fromIntegral $ length $ msgAdditionals m
               mapM_ putQ      $ msgQuestions m
               mapM_ putSomeRR $ msgAnswers m
               mapM_ putSomeRR $ msgAuthorities m
               mapM_ putSomeRR $ msgAdditionals m

    get = U.liftToBinary IM.empty $
          do hdr  <- getBinary
             nQ   <- liftM fromIntegral U.getWord16be
             nAns <- liftM fromIntegral U.getWord16be
             nAth <- liftM fromIntegral U.getWord16be
             nAdd <- liftM fromIntegral U.getWord16be
             qs   <- replicateM nQ   getQ
             anss <- replicateM nAns getSomeRR
             aths <- replicateM nAth getSomeRR
             adds <- replicateM nAdd getSomeRR
             return Message {
                          msgHeader      = hdr
                        , msgQuestions   = qs
                        , msgAnswers     = anss
                        , msgAuthorities = aths
                        , msgAdditionals = adds
                        }

-- | Message ID followed by a 16-bit flags word:
-- QR(1) Opcode(4) AA(1) TC(1) RD(1) RA(1) zero(3) RCODE(4).
instance Binary Header where
    put h = do P'.putWord16be $ hdMessageID h
               P'.putLazyByteString flags
        where
          flags = runBitPut $
                  do putNBits 1 $ fromEnum $ hdMessageType h
                     putNBits 4 $ fromEnum $ hdOpcode h
                     putBit $ hdIsAuthoritativeAnswer h
                     putBit $ hdIsTruncated h
                     putBit $ hdIsRecursionDesired h
                     putBit $ hdIsRecursionAvailable h
                     putNBits 3 (0 :: Int)
                     putNBits 4 $ fromEnum $ hdResponseCode h

    get = do mID   <- G.getWord16be
             flags <- G.getByteString 2
             case runBitGet flags (decodeFlags mID) of
               Left  err -> fail err
               Right hd  -> return hd
        where
          decodeFlags :: MessageID -> BitGet Header
          decodeFlags mID
              = do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
                   op <- liftM (toEnum . fromIntegral) $ getAsWord8 4
                   aa <- getBit
                   tc <- getBit
                   rd <- getBit
                   ra <- getBit
                   BG.skip 3
                   rc <- liftM (toEnum . fromIntegral) $ getAsWord8 4
                   return Header {
                                hdMessageID             = mID
                              , hdMessageType           = qr
                              , hdOpcode                = op
                              , hdIsAuthoritativeAnswer = aa
                              , hdIsTruncated           = tc
                              , hdIsRecursionDesired    = rd
                              , hdIsRecursionAvailable  = ra
                              , hdResponseCode          = rc
                              }

instance Enum MessageType where
    fromEnum Query    = 0
    fromEnum Response = 1

    toEnum 0 = Query
    toEnum 1 = Response
    toEnum n = error ("toEnum MessageType: unknown value " ++ show n)

instance Enum Opcode where
    fromEnum StandardQuery       = 0
    fromEnum InverseQuery        = 1
    fromEnum ServerStatusRequest = 2

    toEnum 0 = StandardQuery
    toEnum 1 = InverseQuery
    toEnum 2 = ServerStatusRequest
    toEnum n = error ("toEnum Opcode: unknown value " ++ show n)

instance Enum ResponseCode where
    fromEnum NoError        = 0
    fromEnum FormatError    = 1
    fromEnum ServerFailure  = 2
    fromEnum NameError      = 3
    fromEnum NotImplemented = 4
    fromEnum Refused        = 5

    toEnum 0 = NoError
    toEnum 1 = FormatError
    toEnum 2 = ServerFailure
    toEnum 3 = NameError
    toEnum 4 = NotImplemented
    toEnum 5 = Refused
    toEnum n = error ("toEnum ResponseCode: unknown value " ++ show n)

{-
instance Enum RecordType where
    fromEnum A       = 1
    fromEnum NS      = 2
    fromEnum MD      = 3
    fromEnum MF      = 4
    fromEnum CNAME   = 5
    fromEnum SOA     = 6
    fromEnum MB      = 7
    fromEnum MG      = 8
    fromEnum MR      = 9
    fromEnum NULL    = 10
    fromEnum WKS     = 11
    fromEnum PTR     = 12
    fromEnum HINFO   = 13
    fromEnum MINFO   = 14
    fromEnum MX      = 15
    fromEnum TXT     = 16
    fromEnum AXFR    = 252
    fromEnum MAILB   = 253
    fromEnum MAILA   = 254
    fromEnum AnyType = 255
-}

instance Enum RecordClass where
    fromEnum IN       = 1
    fromEnum CS       = 2
    fromEnum CH       = 3
    fromEnum HS       = 4
    fromEnum AnyClass = 255

    toEnum 1   = IN
    toEnum 2   = CS
    toEnum 3   = CH
    toEnum 4   = HS
    toEnum 255 = AnyClass
    toEnum n   = error ("toEnum RecordClass: unknown value " ++ show n)

instance Binary RecordClass where
    get = liftM (toEnum . fromIntegral) G.getWord16be
    put = P'.putWord16be . fromIntegral . fromEnum

-- | Type code -> decoder, for every record type defined in this module.
defaultRTTable :: IntMap SomeRT
defaultRTTable = IM.fromList $ map toPair
                 [ wrapRecordType A
                 , wrapRecordType NS
                 , wrapRecordType MD
                 , wrapRecordType MF
                 , wrapRecordType CNAME
                 , wrapRecordType SOA
                 , wrapRecordType MB
                 , wrapRecordType MG
                 , wrapRecordType MR
                 , wrapRecordType NULL
                 , wrapRecordType PTR
                 , wrapRecordType HINFO
                 , wrapRecordType MINFO
                 , wrapRecordType MX
                 , wrapRecordType TXT
                 ]
    where
      toPair :: SomeRT -> (Int, SomeRT)
      toPair srt@(SomeRT rt) = (rtToInt rt, srt)


wrapQueryType :: RecordType rt dt => rt -> SomeQT
wrapQueryType = SomeRT

wrapRecordType :: RecordType rt dt => rt -> SomeRT
wrapRecordType = SomeRT

wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
wrapRecord = SomeRR