-- | Wire-format encoding and decoding of DNS messages.
--
-- A 'Message' is serialised through its 'Binary' instance.  Decoding
-- threads a 'DecompTable' through every name parser so that compressed
-- names (RFC 1035, section 4.1.4) can be expanded.
--
-- NOTE(review): the code below uses existential types, a multi-parameter
-- type class with a functional dependency, scoped type variables, a
-- datatype context and derived 'Typeable' instances.  No LANGUAGE pragmas
-- are visible in this file, so the extensions are presumably enabled by
-- the build configuration -- confirm.
module Network.DNS.Message
    ( Message(..)
    , MessageID
    , MessageType(..)
    , Header(..)
    , Opcode(..)
    , ResponseCode(..)
    , Question(..)
    , ResourceRecord(..)
    , DomainName
    , DomainLabel
    , TTL
    , RecordType
    , RecordClass(..)

    , SomeRR(..)
    , SomeRT(..)

    , CNAME(..)
    , HINFO(..)
    )
    where

import           Control.Monad
import           Data.Binary
import           Data.Binary.BitPut as BP
import           Data.Binary.Get as G
import           Data.Binary.Put as P
import           Data.Binary.Strict.BitGet as BG
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as LBS
import           Data.Typeable
import qualified Data.IntMap as IM
import           Data.IntMap (IntMap)
import           Data.Word


-- | Run a state-threading action the given number of times.
--
-- The results are returned in the order they were produced, together with
-- the state left by the last run.  A count of zero or less runs nothing
-- and returns the initial state unchanged.
replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a)
replicateM' = worker []
    where
      worker :: Monad m => [b] -> Int -> (a -> m (b, a)) -> a -> m ([b], a)
      worker soFar n f a
          -- "<= 0" rather than "== 0" so that a negative count terminates.
          | n <= 0    = return (reverse soFar, a)
          | otherwise = do
              (b, a') <- f a
              worker (b : soFar) (n - 1) f a'


-- | A complete DNS message: the header followed by the four sections.
data Message
    = Message {
        msgHeader      :: !Header
      , msgQuestions   :: ![Question]
      , msgAnswers     :: ![SomeRR]
      , msgAuthorities :: ![SomeRR]
      , msgAdditionals :: ![SomeRR]
      }

-- | The fixed part of a DNS message header.
data Header
    = Header {
        hdMessageID             :: !MessageID
      , hdMessageType           :: !MessageType
      , hdOpcode                :: !Opcode
      , hdIsAuthoritativeAnswer :: !Bool
      , hdIsTruncated           :: !Bool
      , hdIsRecursionDesired    :: !Bool
      , hdIsRecursionAvailable  :: !Bool
      , hdResponseCode          :: !ResponseCode

      -- These fields are suppressed in this data structure:
      -- + QDCOUNT
      -- + ANCOUNT
      -- + NSCOUNT
      -- + ARCOUNT
      -- They are derived from the section lengths of the enclosing
      -- 'Message' when it is written, and consumed when it is read.
      }

type MessageID = Word16

data MessageType
    = Query
    | Response
    deriving (Show, Eq)

data Opcode
    = StandardQuery
    | InverseQuery
    | ServerStatusRequest
    deriving (Show, Eq)

data ResponseCode
    = NoError
    | FormatError
    | ServerFailure
    | NameError
    | NotImplemented
    | Refused
    deriving (Show, Eq)

-- | One entry of the question section.
data Question
    = Question {
        qName  :: !DomainName
      , qType  :: !SomeRT
      , qClass :: !RecordClass
      }
    deriving (Show, Eq)

-- | Serialise a question: name, then type, then class.
putQ :: Question -> Put
putQ q
    = do putDomainName $ qName q
         putSomeRT $ qType q
         put $ qClass q

-- | Parse a question, returning the decompression table extended with the
-- labels seen in its name.
getQ :: DecompTable -> Get (Question, DecompTable)
getQ dt
    = do (nm, dt') <- getDomainName dt
         ty        <- getSomeRT
         cl        <- get
         let q = Question {
                   qName  = nm
                 , qType  = ty
                 , qClass = cl
                 }
         return (q, dt')

-- | A domain name as a list of labels.  Names produced by the decoder end
-- with the empty root label, and 'putDomainName' relies on that label
-- being present to terminate the name on the wire.
newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
type DomainLabel    = BS.ByteString

nameToLabels :: DomainName -> [DomainLabel]
nameToLabels (DN ls) = ls

labelsToName :: [DomainLabel] -> DomainName
labelsToName = DN

data RecordClass
    = IN
    | CS -- Obsolete
    | CH
    | HS
    | AnyClass -- Only for queries
    deriving (Show, Eq)


-- | A resource record whose payload type @dt@ is determined by its record
-- type @rt@.
data RecordType rt dt => ResourceRecord rt dt
    = ResourceRecord {
        rrName  :: !DomainName
      , rrType  :: !rt
      , rrClass :: !RecordClass
      , rrTTL   :: !TTL
      , rrData  :: !dt
      }
    deriving (Show, Eq, Typeable)


-- | Serialise a resource record.  The payload is rendered first into a
-- separate buffer so that its length can be written ahead of it.
putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
putRR rr = do putDomainName $ rrName rr
              putRecordType $ rrType rr
              put $ rrClass rr
              putWord32be $ rrTTL rr

              let dat = runPut $
                        putRecordData (undefined :: rt) (rrData rr)
              putWord16be $ fromIntegral $ LBS.length dat
              putLazyByteString dat


-- | Parse a resource record of an already known record type.  The type
-- and data-length fields on the wire are skipped, not checked.
getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable)
getRR dt rt
    = do (nm, dt1)  <- getDomainName dt
         G.skip 2 -- record type
         cl         <- get
         ttl        <- G.getWord32be
         G.skip 2 -- data length
         (dat, dt2) <- getRecordData (undefined :: rt) dt1

         let rr = ResourceRecord {
                    rrName  = nm
                  , rrType  = rt
                  , rrClass = cl
                  , rrTTL   = ttl
                  , rrData  = dat
                  }
         return (rr, dt2)


-- | A resource record of any supported record type.
data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)

instance Show SomeRR where
    show (SomeRR rr) = show rr

instance Eq SomeRR where
    -- Records of different types never compare equal: 'cast' yields
    -- 'Nothing' unless both sides have the same type.
    (SomeRR a) == (SomeRR b) = Just a == cast b


putSomeRR :: SomeRR -> Put
putSomeRR (SomeRR rr) = putRR rr

-- | Parse a resource record of whatever type its type field announces.
-- The type is discovered by peeking past the owner name, after which the
-- record is parsed from the start with 'getRR'.
getSomeRR :: DecompTable -> Get (SomeRR, DecompTable)
getSomeRR dt
    = do srt <- lookAhead $
                do _ <- getDomainName dt -- skip
                   getSomeRT
         case srt of
           SomeRT rt
               -> getRR dt rt >>= \ (rr, dt') -> return (SomeRR rr, dt')


-- | Decompression table.  It maps the offset of a label's length octet to
-- the labels of the name from that point on, ending with the empty root
-- label.  Offsets are taken from 'bytesRead', so they only agree with
-- compression pointers when decoding starts at the first octet of the
-- message.
type DecompTable = IntMap [DomainLabel]
type TTL         = Word32

-- | Parse a possibly compressed domain name.
--
-- Labels are read until either the empty root label or a compression
-- pointer is met.  A pointer ends the name and is replaced by the labels
-- recorded at its target.  Every literal label read here is added to the
-- returned table so that later names may point at it.  Parsing fails if a
-- pointer refers to an offset that is not in the table.
getDomainName :: DecompTable -> Get (DomainName, DecompTable)
getDomainName dt
    = do (pieces, suffix) <- collect
         let (labels, dt') = foldr register (suffix, dt) pieces
         return (labelsToName labels, dt')
    where
      -- Literal labels paired with their offsets, plus the labels
      -- contributed by a trailing compression pointer (if any).
      collect :: Get ([(Int, DomainLabel)], [DomainLabel])
      collect
          = do -- Pointers target the length octet, so take the offset
               -- before consuming it.
               offset <- liftM fromIntegral bytesRead
               header <- getLabelHeader
               case header of
                 Offset n
                     -> case IM.lookup n dt of
                          Just ls -> return ([], ls)
                          Nothing -> fail ("Dangling compression pointer to offset " ++ show n)
                 Length 0
                     -> return ([(offset, BS.empty)], [])
                 Length n
                     -> do label          <- getByteString n
                           (rest, suffix) <- collect
                           return ((offset, label) : rest, suffix)

      -- Walking from the last label backwards, record the tail of the
      -- name that starts at each literal label.
      register :: (Int, DomainLabel) -> ([DomainLabel], DecompTable) -> ([DomainLabel], DecompTable)
      register (offset, label) (rest, table)
          = let here = label : rest
            in
              (here, IM.insert offset here table)

-- | Parse the octet(s) that introduce a label.
--
-- A first octet below 0x40 (top bits 00) is the length of a literal
-- label.  A first octet of 0xC0 or more (top bits 11) starts a two-octet
-- compression pointer whose remaining 14 bits are the target offset.
-- Anything else is rejected.
getLabelHeader :: Get LabelHeader
getLabelHeader = liftM fromIntegral G.getWord8 >>= decodeHeader
    where
      decodeHeader :: Int -> Get LabelHeader
      decodeHeader b
          | b <  0x40 = return $ Length b
          | b >= 0xC0 = do lo <- liftM fromIntegral G.getWord8
                           return $ Offset ((b - 0xC0) * 256 + lo)
          | otherwise = fail "Illegal label header"

-- | Parse a length-prefixed character string.
getCharString :: Get BS.ByteString
getCharString = do len <- G.getWord8
                   getByteString (fromIntegral len)

-- | Write a length-prefixed character string; the layout is the same as
-- that of a literal label.
putCharString :: BS.ByteString -> Put
putCharString = putDomainLabel

data LabelHeader
    = Offset !Int
    | Length !Int

-- | Write every label of the name, uncompressed.  No terminator is added
-- here; see 'DomainName'.
putDomainName :: DomainName -> Put
putDomainName = mapM_ putDomainLabel . nameToLabels

-- | Write one label as a length octet followed by its bytes.  The length
-- is not range-checked before being narrowed to a single octet.
putDomainLabel :: DomainLabel -> Put
putDomainLabel l
    = do putWord8 $ fromIntegral $ BS.length l
         P.putByteString l

-- | Record types.  The functional dependency ties each record type @rt@
-- to the type @dt@ of its payload.
class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
    -- | Numeric code of the record type as written on the wire.
    rtToInt       :: rt -> Int
    putRecordType :: rt -> Put
    putRecordData :: rt -> dt -> Put
    getRecordData :: rt -> DecompTable -> Get (dt, DecompTable)

    putRecordType = putWord16be . fromIntegral . rtToInt

-- | Any supported record type.
data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt

instance Show SomeRT where
    show (SomeRT rt) = show rt

instance Eq SomeRT where
    (SomeRT a) == (SomeRT b) = Just a == cast b

putSomeRT :: SomeRT -> Put
putSomeRT (SomeRT rt) = putRecordType rt

-- | Parse a record type code, failing for codes that are not listed in
-- 'defaultRTTable'.
getSomeRT :: Get SomeRT
getSomeRT = do n <- liftM fromIntegral G.getWord16be
               case IM.lookup n defaultRTTable of
                 Nothing
                     -> fail ("Unknown resource record type: " ++ show n)
                 Just srt
                     -> return srt

data CNAME = CNAME deriving (Show, Eq, Typeable)
instance RecordType CNAME DomainName where
    rtToInt       _ = 5
    putRecordData _ = putDomainName
    getRecordData _ = getDomainName

data HINFO = HINFO deriving (Show, Eq, Typeable)
instance RecordType HINFO (BS.ByteString, BS.ByteString) where
    rtToInt       _           = 13
    putRecordData _ (cpu, os) = do putCharString cpu
                                   putCharString os
    -- The payload contains no domain names, so the table passes through
    -- unchanged.
    getRecordData _ dt        = do cpu <- getCharString
                                   os  <- getCharString
                                   return ((cpu, os), dt)

{-
data RecordType
    = A
    | NS
    | MD
    | MF
    | CNAME
    | SOA
    | MB
    | MG
    | MR
    | NULL
    | WKS
    | PTR
    | HINFO
    | MINFO
    | MX
    | TXT

    -- Only for queries:
    | AXFR
    | MAILB -- Obsolete
    | MAILA -- Obsolete
    | AnyType
    deriving (Show, Eq)
-}

instance Binary Message where
    put m = do put $ msgHeader m
               -- Section counts, in the order the sections follow.
               putWord16be $ fromIntegral $ length $ msgQuestions m
               putWord16be $ fromIntegral $ length $ msgAnswers m
               putWord16be $ fromIntegral $ length $ msgAuthorities m
               putWord16be $ fromIntegral $ length $ msgAdditionals m
               mapM_ putQ      $ msgQuestions m
               mapM_ putSomeRR $ msgAnswers m
               mapM_ putSomeRR $ msgAuthorities m
               mapM_ putSomeRR $ msgAdditionals m

    get = do hdr  <- get
             nQ   <- liftM fromIntegral G.getWord16be
             nAns <- liftM fromIntegral G.getWord16be
             nAth <- liftM fromIntegral G.getWord16be
             nAdd <- liftM fromIntegral G.getWord16be
             -- The decompression table starts empty and is carried from
             -- one section to the next.
             (qs  , dt1) <- replicateM' nQ   getQ      IM.empty
             (anss, dt2) <- replicateM' nAns getSomeRR dt1
             (aths, dt3) <- replicateM' nAth getSomeRR dt2
             (adds, _  ) <- replicateM' nAdd getSomeRR dt3
             return Message {
                          msgHeader      = hdr
                        , msgQuestions   = qs
                        , msgAnswers     = anss
                        , msgAuthorities = aths
                        , msgAdditionals = adds
                        }

instance Binary Header where
    put h = do putWord16be $ hdMessageID h
               putLazyByteString flags
        where
          -- The 16 flag bits that follow the message ID.
          flags = runBitPut $
                  do putNBits 1 $ fromEnum $ hdMessageType h
                     putNBits 4 $ fromEnum $ hdOpcode h
                     putBit $ hdIsAuthoritativeAnswer h
                     putBit $ hdIsTruncated h
                     putBit $ hdIsRecursionDesired h
                     putBit $ hdIsRecursionAvailable h
                     putNBits 3 (0 :: Int)
                     putNBits 4 $ fromEnum $ hdResponseCode h

    get = do mID   <- G.getWord16be
             flags <- getByteString 2
             let decoded
                     = runBitGet flags $
                       do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
                          op <- liftM (toEnum . fromIntegral) $ getAsWord8 4
                          aa <- getBit
                          tc <- getBit
                          rd <- getBit
                          ra <- getBit
                          BG.skip 3
                          rc <- liftM (toEnum . fromIntegral) $ getAsWord8 4
                          return Header {
                                       hdMessageID             = mID
                                     , hdMessageType           = qr
                                     , hdOpcode                = op
                                     , hdIsAuthoritativeAnswer = aa
                                     , hdIsTruncated           = tc
                                     , hdIsRecursionDesired    = rd
                                     , hdIsRecursionAvailable  = ra
                                     , hdResponseCode          = rc
                                     }
             -- Report a bit-level parse failure as a failure of the
             -- enclosing parser instead of a pattern-match exception.
             either fail return decoded

instance Enum MessageType where
    fromEnum Query    = 0
    fromEnum Response = 1

    toEnum 0 = Query
    toEnum 1 = Response
    toEnum n = error ("Network.DNS.Message: invalid MessageType code: " ++ show n)

instance Enum Opcode where
    fromEnum StandardQuery       = 0
    fromEnum InverseQuery        = 1
    fromEnum ServerStatusRequest = 2

    toEnum 0 = StandardQuery
    toEnum 1 = InverseQuery
    toEnum 2 = ServerStatusRequest
    toEnum n = error ("Network.DNS.Message: invalid Opcode code: " ++ show n)

instance Enum ResponseCode where
    fromEnum NoError        = 0
    fromEnum FormatError    = 1
    fromEnum ServerFailure  = 2
    fromEnum NameError      = 3
    fromEnum NotImplemented = 4
    fromEnum Refused        = 5

    toEnum 0 = NoError
    toEnum 1 = FormatError
    toEnum 2 = ServerFailure
    toEnum 3 = NameError
    toEnum 4 = NotImplemented
    toEnum 5 = Refused
    toEnum n = error ("Network.DNS.Message: invalid ResponseCode code: " ++ show n)

{-
instance Enum RecordType where
    fromEnum A       = 1
    fromEnum NS      = 2
    fromEnum MD      = 3
    fromEnum MF      = 4
    fromEnum CNAME   = 5
    fromEnum SOA     = 6
    fromEnum MB      = 7
    fromEnum MG      = 8
    fromEnum MR      = 9
    fromEnum NULL    = 10
    fromEnum WKS     = 11
    fromEnum PTR     = 12
    fromEnum HINFO   = 13
    fromEnum MINFO   = 14
    fromEnum MX      = 15
    fromEnum TXT     = 16
    fromEnum AXFR    = 252
    fromEnum MAILB   = 253
    fromEnum MAILA   = 254
    fromEnum AnyType = 255

    toEnum 1   = A
    toEnum 2   = NS
    toEnum 3   = MD
    toEnum 4   = MF
    toEnum 5   = CNAME
    toEnum 6   = SOA
    toEnum 7   = MB
    toEnum 8   = MG
    toEnum 9   = MR
    toEnum 10  = NULL
    toEnum 11  = WKS
    toEnum 12  = PTR
    toEnum 13  = HINFO
    toEnum 14  = MINFO
    toEnum 15  = MX
    toEnum 16  = TXT
    toEnum 252 = AXFR
    toEnum 253 = MAILB
    toEnum 254 = MAILA
    toEnum 255 = AnyType
    toEnum _   = undefined
-}

instance Enum RecordClass where
    fromEnum IN       = 1
    fromEnum CS       = 2
    fromEnum CH       = 3
    fromEnum HS       = 4
    fromEnum AnyClass = 255

    toEnum 1   = IN
    toEnum 2   = CS
    toEnum 3   = CH
    toEnum 4   = HS
    toEnum 255 = AnyClass
    toEnum n   = error ("Network.DNS.Message: invalid RecordClass code: " ++ show n)

instance Binary RecordClass where
    get = liftM (toEnum . fromIntegral) G.getWord16be
    put = putWord16be . fromIntegral . fromEnum


-- | Record types recognised by 'getSomeRT', keyed by their numeric code.
defaultRTTable :: IntMap SomeRT
defaultRTTable = IM.fromList $ map toPair $
                 [ SomeRT CNAME
                 , SomeRT HINFO
                 ]
    where
      toPair :: SomeRT -> (Int, SomeRT)
      toPair srt@(SomeRT rt) = (rtToInt rt, srt)