]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Message.hs
Introduce Packer monad so that we can compress binary packets.
[haskell-dns.git] / Network / DNS / Message.hs
1 module Network.DNS.Message
2     ( Message(..)
3     , MessageID
4     , MessageType(..)
5     , Header(..)
6     , Opcode(..)
7     , ResponseCode(..)
8     , Question(..)
9     , ResourceRecord(..)
10     , DomainName
11     , DomainLabel
12     , TTL
13     , RecordType
14     , RecordClass(..)
15
16     , SomeQT
17     , SomeRR
18     , SomeRT
19
20     , A(..)
21     , NS(..)
22     , CNAME(..)
23     , HINFO(..)
24
25     , mkDomainName
26     , wrapQueryType
27     , wrapRecordType
28     , wrapRecord
29     )
30     where
31
32 import           Control.Exception
33 import           Control.Monad
34 import           Data.Binary
35 import           Data.Binary.BitPut as BP
36 import           Data.Binary.Get as G
37 import           Data.Binary.Put as P'
38 import           Data.Binary.Strict.BitGet as BG
39 import qualified Data.ByteString as BS
40 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
41 import           Data.Typeable
42 import qualified Data.IntMap as IM
43 import           Data.IntMap (IntMap)
44 import qualified Data.Map as M
45 import           Data.Map (Map)
46 import           Data.Word
47 import           Network.DNS.Packer as P
48 import           Network.DNS.Unpacker as U
49 import           Network.Socket
50
51
52 data Message
53     = Message {
54         msgHeader      :: !Header
55       , msgQuestions   :: ![Question]
56       , msgAnswers     :: ![SomeRR]
57       , msgAuthorities :: ![SomeRR]
58       , msgAdditionals :: ![SomeRR]
59       }
60     deriving (Show, Eq)
61
62 data Header
63     = Header {
64         hdMessageID             :: !MessageID
65       , hdMessageType           :: !MessageType
66       , hdOpcode                :: !Opcode
67       , hdIsAuthoritativeAnswer :: !Bool
68       , hdIsTruncated           :: !Bool
69       , hdIsRecursionDesired    :: !Bool
70       , hdIsRecursionAvailable  :: !Bool
71       , hdResponseCode          :: !ResponseCode
72
73       -- These fields are supressed in this data structure:
74       -- + QDCOUNT
75       -- + ANCOUNT
76       -- + NSCOUNT
77       -- + ARCOUNT
78       }
79     deriving (Show, Eq)
80
81 type MessageID = Word16
82
83 data MessageType
84     = Query
85     | Response
86     deriving (Show, Eq)
87
88 data Opcode
89     = StandardQuery
90     | InverseQuery
91     | ServerStatusRequest
92     deriving (Show, Eq)
93
94 data ResponseCode
95     = NoError
96     | FormatError
97     | ServerFailure
98     | NameError
99     | NotImplemented
100     | Refused
101     deriving (Show, Eq)
102
103 data Question
104     = Question {
105         qName  :: !DomainName
106       , qType  :: !SomeQT
107       , qClass :: !RecordClass
108       }
109     deriving (Show, Eq)
110
111 type SomeQT = SomeRT
112
113 putQ :: Question -> Packer CompTable ()
114 putQ q
115     = do putDomainName $ qName q
116          putSomeRT $ qType q
117          putBinary $ qClass q
118
119 getQ :: Unpacker DecompTable Question
120 getQ = do nm <- getDomainName
121           ty <- getSomeRT
122           cl <- getBinary
123           return Question {
124                        qName  = nm
125                      , qType  = ty
126                      , qClass = cl
127                      }
128
129
130 newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Ord, Typeable)
131 type DomainLabel    = BS.ByteString
132
133 rootName :: DomainName
134 rootName = DN [BS.empty]
135
136 isRootName :: DomainName -> Bool
137 isRootName (DN [_]) = True
138 isRootName _        = False
139
140 consLabel :: DomainLabel -> DomainName -> DomainName
141 consLabel x (DN ys) = DN (x:ys)
142
143 unconsLabel :: DomainName -> (DomainLabel, DomainName)
144 unconsLabel (DN (x:xs)) = (x, DN xs)
145 unconsLabel x           = error ("Illegal use of unconsLabel: " ++ show x)
146
147 mkDomainName :: String -> DomainName
148 mkDomainName = DN . mkLabels [] . notEmpty
149     where
150       notEmpty :: String -> String
151       notEmpty xs = assert (not $ null xs) xs
152
153       mkLabels :: [DomainLabel] -> String -> [DomainLabel]
154       mkLabels soFar [] = reverse (C8.empty : soFar)
155       mkLabels soFar xs = case break (== '.') xs of
156                             (l, ('.':rest))
157                                 -> mkLabels (C8.pack l : soFar) rest
158                             _   -> error ("Illegal domain name: " ++ xs)
159
160 data RecordClass
161     = IN
162     | CS -- Obsolete
163     | CH
164     | HS
165     | AnyClass -- Only for queries
166     deriving (Show, Eq)
167
168
169 data RecordType rt dt => ResourceRecord rt dt
170     = ResourceRecord {
171         rrName  :: !DomainName
172       , rrType  :: !rt
173       , rrClass :: !RecordClass
174       , rrTTL   :: !TTL
175       , rrData  :: !dt
176       }
177     deriving (Show, Eq, Typeable)
178
179
180 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
181
182 instance Show SomeRR where
183     show (SomeRR rr) = show rr
184
185 instance Eq SomeRR where
186     (SomeRR a) == (SomeRR b) = Just a == cast b
187
188
189 putSomeRR :: SomeRR -> Packer CompTable ()
190 putSomeRR (SomeRR rr) = putResourceRecord rr
191
192 getSomeRR :: Unpacker DecompTable SomeRR
193 getSomeRR = do srt <- U.lookAhead $
194                       do getDomainName -- skip
195                          getSomeRT
196                case srt of
197                  SomeRT rt
198                      -> getResourceRecord rt >>= return . SomeRR
199
200 type CompTable   = Map DomainName Int
201 type DecompTable = IntMap DomainName
202 type TTL         = Word32
203
204 getDomainName :: Unpacker DecompTable DomainName
205 getDomainName = worker
206     where
207       worker :: Unpacker DecompTable DomainName
208       worker
209           = do offset <- U.bytesRead
210                hdr    <- getLabelHeader
211                case hdr of
212                  Offset n
213                      -> do dt <- U.getState
214                            case IM.lookup n dt of
215                              Just name
216                                  -> return name
217                              Nothing
218                                  -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
219                  Length 0
220                      -> return rootName
221                  Length n
222                      -> do label <- U.getByteString n
223                            rest  <- worker
224                            let name = consLabel label rest
225                            U.modifyState $ IM.insert offset name
226                            return name
227
228       getLabelHeader :: Unpacker s LabelHeader
229       getLabelHeader
230           = do header <- U.lookAhead $ U.getByteString 1
231                let Right h
232                        = runBitGet header $
233                          do a <- getBit
234                             b <- getBit
235                             n <- liftM fromIntegral (getAsWord8 6)
236                             case (a, b) of
237                               ( True,  True) -> return $ Offset n
238                               (False, False) -> return $ Length n
239                               _              -> fail "Illegal label header"
240                case h of
241                  Offset _
242                      -> do header' <- U.getByteString 2 -- Pointers have 2 octets.
243                            let Right h'
244                                    = runBitGet header' $
245                                      do BG.skip 2
246                                         n <- liftM fromIntegral (getAsWord16 14)
247                                         return $ Offset n
248                            return h'
249                  len@(Length _)
250                      -> do U.skip 1
251                            return len
252
253
254 getCharString :: Unpacker s BS.ByteString
255 getCharString = do len <- U.getWord8
256                    U.getByteString (fromIntegral len)
257
258 putCharString :: BS.ByteString -> Packer s ()
259 putCharString xs = do P.putWord8 $ fromIntegral $ BS.length xs
260                       P.putByteString xs
261
262 data LabelHeader
263     = Offset !Int
264     | Length !Int
265
266 putDomainName :: DomainName -> Packer CompTable ()
267 putDomainName name
268     = do ct <- P.getState
269          case M.lookup name ct of
270            Just n
271                -> do let ptr = runBitPut $
272                                do putBit True
273                                   putBit True
274                                   putNBits 14 n
275                      P.putLazyByteString ptr
276            Nothing
277                -> do offset <- bytesWrote
278                      P.modifyState $ M.insert name offset
279
280                      let (label, rest) = unconsLabel name
281
282                      putCharString label
283
284                      if isRootName rest then
285                          P.putWord8 0
286                        else
287                          putDomainName rest
288
289
290 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
291     rtToInt       :: rt -> Int
292     putRecordData :: rt -> dt -> Packer CompTable ()
293     getRecordData :: rt -> Unpacker DecompTable dt
294
295     putRecordType :: rt -> Packer s ()
296     putRecordType = P.putWord16be . fromIntegral . rtToInt
297
298     putResourceRecord :: ResourceRecord rt dt -> Packer CompTable ()
299     putResourceRecord rr
300         = do putDomainName $ rrName  rr
301              putRecordType $ rrType  rr
302              putBinary     $ rrClass rr
303              P.putWord32be $ rrTTL   rr
304
305              -- First, write a dummy data length.
306              offset <- bytesWrote
307              P.putWord16be 0
308
309              -- Second, write data.
310              putRecordData (rrType rr) (rrData rr)
311
312              -- Third, rewrite the dummy length to an actual value.
313              offset' <- bytesWrote
314              withOffset offset
315                  $ P.putWord16be (fromIntegral (offset' - offset - 2))
316
317     getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
318     getResourceRecord rt
319         = do name     <- getDomainName
320              U.skip 2 -- record type
321              cl       <- getBinary
322              ttl      <- U.getWord32be
323              U.skip 2 -- data length
324              dat      <- getRecordData rt
325              return $ ResourceRecord {
326                           rrName  = name
327                         , rrType  = rt
328                         , rrClass = cl
329                         , rrTTL   = ttl
330                         , rrData  = dat
331                         }
332
333 data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
334
335 instance Show SomeRT where
336     show (SomeRT rt) = show rt
337
338 instance Eq SomeRT where
339     (SomeRT a) == (SomeRT b) = Just a == cast b
340
341 putSomeRT :: SomeRT -> Packer s ()
342 putSomeRT (SomeRT rt) = putRecordType rt
343
344 getSomeRT :: Unpacker s SomeRT
345 getSomeRT = do n <- liftM fromIntegral U.getWord16be
346                case IM.lookup n defaultRTTable of
347                  Nothing
348                      -> fail ("Unknown resource record type: " ++ show n)
349                  Just srt
350                      -> return srt
351
352 data A = A deriving (Show, Eq, Typeable)
353 instance RecordType A HostAddress where
354     rtToInt       _ = 1
355     putRecordData _ = P.putWord32be
356     getRecordData _ = U.getWord32be
357
358 data NS = NS deriving (Show, Eq, Typeable)
359 instance RecordType NS DomainName where
360     rtToInt       _ = 2
361     putRecordData _ = putDomainName
362     getRecordData _ = getDomainName
363
364 data CNAME = CNAME deriving (Show, Eq, Typeable)
365 instance RecordType CNAME DomainName where
366     rtToInt       _ = 5
367     putRecordData _ = putDomainName
368     getRecordData _ = getDomainName
369
370 data HINFO = HINFO deriving (Show, Eq, Typeable)
371 instance RecordType HINFO (BS.ByteString, BS.ByteString) where
372     rtToInt       _           = 13
373     putRecordData _ (cpu, os) = do putCharString cpu
374                                    putCharString os
375     getRecordData _           = do cpu <- getCharString
376                                    os  <- getCharString
377                                    return (cpu, os)
378
379
380 {-
381 data RecordType
382     = A
383     | NS
384     | MD
385     | MF
386     | CNAME
387     | SOA
388     | MB
389     | MG
390     | MR
391     | NULL
392     | WKS
393     | PTR
394     | HINFO
395     | MINFO
396     | MX
397     | TXT
398
399     -- Only for queries:
400     | AXFR
401     | MAILB -- Obsolete
402     | MAILA -- Obsolete
403     | AnyType
404     deriving (Show, Eq)
405 -}
406
407 instance Binary Message where
408     put m = P.liftToBinary M.empty $
409             do putBinary $ msgHeader m
410                P.putWord16be $ fromIntegral $ length $ msgQuestions m
411                P.putWord16be $ fromIntegral $ length $ msgAnswers m
412                P.putWord16be $ fromIntegral $ length $ msgAuthorities m
413                P.putWord16be $ fromIntegral $ length $ msgAdditionals m
414                mapM_ putQ      $ msgQuestions m
415                mapM_ putSomeRR $ msgAnswers m
416                mapM_ putSomeRR $ msgAuthorities m
417                mapM_ putSomeRR $ msgAdditionals m
418
419     get = U.liftToBinary IM.empty $
420           do hdr  <- getBinary
421              nQ   <- liftM fromIntegral U.getWord16be
422              nAns <- liftM fromIntegral U.getWord16be
423              nAth <- liftM fromIntegral U.getWord16be
424              nAdd <- liftM fromIntegral U.getWord16be
425              qs   <- replicateM nQ   getQ
426              anss <- replicateM nAns getSomeRR
427              aths <- replicateM nAth getSomeRR
428              adds <- replicateM nAdd getSomeRR
429              return Message {
430                           msgHeader      = hdr
431                         , msgQuestions   = qs
432                         , msgAnswers     = anss
433                         , msgAuthorities = aths
434                         , msgAdditionals = adds
435                         }
436
437 instance Binary Header where
438     put h = do P'.putWord16be $ hdMessageID h
439                P'.putLazyByteString flags
440         where
441           flags = runBitPut $
442                   do putNBits 1 $ fromEnum $ hdMessageType h
443                      putNBits 4 $ fromEnum $ hdOpcode h
444                      putBit $ hdIsAuthoritativeAnswer h
445                      putBit $ hdIsTruncated h
446                      putBit $ hdIsRecursionDesired h
447                      putBit $ hdIsRecursionAvailable h
448                      putNBits 3 (0 :: Int)
449                      putNBits 4 $ fromEnum $ hdResponseCode h
450
451     get = do mID   <- G.getWord16be
452              flags <- G.getByteString 2
453              let Right hd
454                      = runBitGet flags $
455                        do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
456                           op <- liftM (toEnum . fromIntegral) $ getAsWord8 4
457                           aa <- getBit
458                           tc <- getBit
459                           rd <- getBit
460                           ra <- getBit
461                           BG.skip 3
462                           rc <- liftM (toEnum . fromIntegral) $ getAsWord8 4
463                           return Header {
464                                        hdMessageID             = mID
465                                      , hdMessageType           = qr
466                                      , hdOpcode                = op
467                                      , hdIsAuthoritativeAnswer = aa
468                                      , hdIsTruncated           = tc
469                                      , hdIsRecursionDesired    = rd
470                                      , hdIsRecursionAvailable  = ra
471                                      , hdResponseCode          = rc
472                                      }
473              return hd
474
475 instance Enum MessageType where
476     fromEnum Query    = 0
477     fromEnum Response = 1
478
479     toEnum 0 = Query
480     toEnum 1 = Response
481     toEnum _ = undefined
482
483 instance Enum Opcode where
484     fromEnum StandardQuery       = 0
485     fromEnum InverseQuery        = 1
486     fromEnum ServerStatusRequest = 2
487
488     toEnum 0 = StandardQuery
489     toEnum 1 = InverseQuery
490     toEnum 2 = ServerStatusRequest
491     toEnum _ = undefined
492
493 instance Enum ResponseCode where
494     fromEnum NoError        = 0
495     fromEnum FormatError    = 1
496     fromEnum ServerFailure  = 2
497     fromEnum NameError      = 3
498     fromEnum NotImplemented = 4
499     fromEnum Refused        = 5
500
501     toEnum 0 = NoError
502     toEnum 1 = FormatError
503     toEnum 2 = ServerFailure
504     toEnum 3 = NameError
505     toEnum 4 = NotImplemented
506     toEnum 5 = Refused
507     toEnum _ = undefined
508
509 {-
510 instance Enum RecordType where
511     fromEnum A       = 1
512     fromEnum NS      = 2
513     fromEnum MD      = 3
514     fromEnum MF      = 4
515     fromEnum CNAME   = 5
516     fromEnum SOA     = 6
517     fromEnum MB      = 7
518     fromEnum MG      = 8
519     fromEnum MR      = 9
520     fromEnum NULL    = 10
521     fromEnum WKS     = 11
522     fromEnum PTR     = 12
523     fromEnum HINFO   = 13
524     fromEnum MINFO   = 14
525     fromEnum MX      = 15
526     fromEnum TXT     = 16
527     fromEnum AXFR    = 252
528     fromEnum MAILB   = 253
529     fromEnum MAILA   = 254
530     fromEnum AnyType = 255
531
532     toEnum 1  = A
533     toEnum 2  = NS
534     toEnum 3  = MD
535     toEnum 4  = MF
536     toEnum 5  = CNAME
537     toEnum 6  = SOA
538     toEnum 7  = MB
539     toEnum 8  = MG
540     toEnum 9  = MR
541     toEnum 10 = NULL
542     toEnum 11 = WKS
543     toEnum 12 = PTR
544     toEnum 13 = HINFO
545     toEnum 14 = MINFO
546     toEnum 15 = MX
547     toEnum 16 = TXT
548     toEnum 252 = AXFR
549     toEnum 253 = MAILB
550     toEnum 254 = MAILA
551     toEnum 255 = AnyType
552     toEnum _  = undefined
553 -}
554
555 instance Enum RecordClass where
556     fromEnum IN       = 1
557     fromEnum CS       = 2
558     fromEnum CH       = 3
559     fromEnum HS       = 4
560     fromEnum AnyClass = 255
561
562     toEnum 1   = IN
563     toEnum 2   = CS
564     toEnum 3   = CH
565     toEnum 4   = HS
566     toEnum 255 = AnyClass
567     toEnum _   = undefined
568
569 instance Binary RecordClass where
570     get = liftM (toEnum . fromIntegral) G.getWord16be
571     put = P'.putWord16be . fromIntegral . fromEnum
572
573
574 defaultRTTable :: IntMap SomeRT
575 defaultRTTable = IM.fromList $ map toPair $
576                  [ wrapRecordType A
577                  , wrapRecordType NS
578                  , wrapRecordType CNAME
579                  , wrapRecordType HINFO
580                  ]
581     where
582       toPair :: SomeRT -> (Int, SomeRT)
583       toPair srt@(SomeRT rt) = (rtToInt rt, srt)
584
585
586 wrapQueryType :: RecordType rt dt => rt -> SomeQT
587 wrapQueryType = SomeRT
588
589 wrapRecordType :: RecordType rt dt => rt -> SomeRT
590 wrapRecordType = SomeRT
591
592 wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
593 wrapRecord = SomeRR