]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Message.hs
7bedacf5a0922b1816280a8ae4162f9aaf3ff698
[haskell-dns.git] / Network / DNS / Message.hs
1 module Network.DNS.Message
2     ( Message(..)
3     , MessageID
4     , MessageType(..)
5     , Header(..)
6     , Opcode(..)
7     , ResponseCode(..)
8     , Question(..)
9     , ResourceRecord(..)
10     , DomainName
11     , DomainLabel
12     , TTL
13     , RecordType
14     , RecordClass(..)
15
16     , SomeQT
17     , SomeRR
18     , SomeRT
19
20     , A(..)
21     , NS(..)
22     , CNAME(..)
23     , HINFO(..)
24
25     , mkDomainName
26     , wrapQueryType
27     , wrapRecordType
28     , wrapRecord
29     )
30     where
31
32 import           Control.Exception
33 import           Control.Monad
34 import           Data.Binary
35 import           Data.Binary.BitPut as BP
36 import           Data.Binary.Get as G
37 import           Data.Binary.Put as P
38 import           Data.Binary.Strict.BitGet as BG
39 import qualified Data.ByteString as BS
40 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
41 import qualified Data.ByteString.Lazy as LBS
42 import           Data.Typeable
43 import qualified Data.IntMap as IM
44 import           Data.IntMap (IntMap)
45 import           Data.Word
46 import           Network.Socket
47
48
49 replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a)
50 replicateM' = worker []
51     where
52       worker :: Monad m => [b] -> Int -> (a -> m (b, a)) -> a -> m ([b], a)
53       worker soFar 0 _ a = return (reverse soFar, a)
54       worker soFar n f a = do (b, a') <- f a
55                               worker (b : soFar) (n - 1) f a'
56
57
58 data Message
59     = Message {
60         msgHeader      :: !Header
61       , msgQuestions   :: ![Question]
62       , msgAnswers     :: ![SomeRR]
63       , msgAuthorities :: ![SomeRR]
64       , msgAdditionals :: ![SomeRR]
65       }
66     deriving (Show, Eq)
67
68 data Header
69     = Header {
70         hdMessageID             :: !MessageID
71       , hdMessageType           :: !MessageType
72       , hdOpcode                :: !Opcode
73       , hdIsAuthoritativeAnswer :: !Bool
74       , hdIsTruncated           :: !Bool
75       , hdIsRecursionDesired    :: !Bool
76       , hdIsRecursionAvailable  :: !Bool
77       , hdResponseCode          :: !ResponseCode
78
79       -- These fields are supressed in this data structure:
80       -- + QDCOUNT
81       -- + ANCOUNT
82       -- + NSCOUNT
83       -- + ARCOUNT
84       }
85     deriving (Show, Eq)
86
87 type MessageID = Word16
88
89 data MessageType
90     = Query
91     | Response
92     deriving (Show, Eq)
93
94 data Opcode
95     = StandardQuery
96     | InverseQuery
97     | ServerStatusRequest
98     deriving (Show, Eq)
99
100 data ResponseCode
101     = NoError
102     | FormatError
103     | ServerFailure
104     | NameError
105     | NotImplemented
106     | Refused
107     deriving (Show, Eq)
108
109 data Question
110     = Question {
111         qName  :: !DomainName
112       , qType  :: !SomeQT
113       , qClass :: !RecordClass
114       }
115     deriving (Show, Eq)
116
117 type SomeQT = SomeRT
118
119 putQ :: Question -> Put
120 putQ q
121     = do putDomainName $ qName q
122          putSomeRT $ qType q
123          put $ qClass q
124
125 getQ :: DecompTable -> Get (Question, DecompTable)
126 getQ dt
127     = do (nm, dt') <- getDomainName dt
128          ty        <- getSomeRT
129          cl        <- get
130          let q = Question {
131                    qName  = nm
132                  , qType  = ty
133                  , qClass = cl
134                  }
135          return (q, dt')
136
137 newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
138 type DomainLabel    = BS.ByteString
139
140 nameToLabels :: DomainName -> [DomainLabel]
141 nameToLabels (DN ls) = ls
142
143 labelsToName :: [DomainLabel] -> DomainName
144 labelsToName = DN
145
146 rootName :: DomainName
147 rootName = DN [BS.empty]
148
149 consLabel :: DomainLabel -> DomainName -> DomainName
150 consLabel x (DN ys) = DN (x:ys)
151
152 mkDomainName :: String -> DomainName
153 mkDomainName = labelsToName . mkLabels [] . notEmpty
154     where
155       notEmpty :: String -> String
156       notEmpty xs = assert (not $ null xs) xs
157
158       mkLabels :: [DomainLabel] -> String -> [DomainLabel]
159       mkLabels soFar [] = reverse (C8.empty : soFar)
160       mkLabels soFar xs = case break (== '.') xs of
161                             (l, ('.':rest))
162                                 -> mkLabels (C8.pack l : soFar) rest
163                             _   -> error ("Illegal domain name: " ++ xs)
164
165 data RecordClass
166     = IN
167     | CS -- Obsolete
168     | CH
169     | HS
170     | AnyClass -- Only for queries
171     deriving (Show, Eq)
172
173
174 data RecordType rt dt => ResourceRecord rt dt
175     = ResourceRecord {
176         rrName  :: !DomainName
177       , rrType  :: !rt
178       , rrClass :: !RecordClass
179       , rrTTL   :: !TTL
180       , rrData  :: !dt
181       }
182     deriving (Show, Eq, Typeable)
183
184
185 putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
186 putRR rr = do putDomainName $ rrName rr
187               putRecordType $ rrType  rr
188               put $ rrClass rr
189               putWord32be $ rrTTL rr
190
191               let dat = runPut $
192                         putRecordData (undefined :: rt) (rrData rr)
193               putWord16be $ fromIntegral $ LBS.length dat
194               putLazyByteString dat
195
196
197 getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable)
198 getRR dt rt
199     = do (nm, dt1)  <- getDomainName dt
200          G.skip 2   -- record type
201          cl         <- get
202          ttl        <- G.getWord32be
203          G.skip 2   -- data length
204          (dat, dt2) <- getRecordData (undefined :: rt) dt1
205
206          let rr = ResourceRecord {
207                     rrName  = nm
208                   , rrType  = rt
209                   , rrClass = cl
210                   , rrTTL   = ttl
211                   , rrData  = dat
212                   }
213          return (rr, dt2)
214
215
216 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
217
218 instance Show SomeRR where
219     show (SomeRR rr) = show rr
220
221 instance Eq SomeRR where
222     (SomeRR a) == (SomeRR b) = Just a == cast b
223
224
225 putSomeRR :: SomeRR -> Put
226 putSomeRR (SomeRR rr) = putRR rr
227
228 getSomeRR :: DecompTable -> Get (SomeRR, DecompTable)
229 getSomeRR dt
230     = do srt <- lookAhead $
231                 do getDomainName dt -- skip
232                    getSomeRT
233          case srt of
234            SomeRT rt -> getRR dt rt >>= \ (rr, dt') -> return (SomeRR rr, dt')
235
236
237 type DecompTable = IntMap DomainName
238 type TTL = Word32
239
240 getDomainName :: DecompTable -> Get (DomainName, DecompTable)
241 getDomainName = worker
242     where
243       worker :: DecompTable -> Get (DomainName, DecompTable)
244       worker dt
245           = do offset <- liftM fromIntegral bytesRead
246                hdr    <- getLabelHeader
247                case hdr of
248                  Offset n
249                      -> case IM.lookup n dt of
250                           Just name
251                               -> return (name, dt)
252                           Nothing
253                               -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
254                  Length 0
255                      -> return (rootName, dt)
256                  Length n
257                      -> do label       <- getByteString n
258                            (rest, dt') <- worker dt
259                            let name = consLabel label rest
260                                dt'' = IM.insert offset name dt'
261                            return (name, dt'')
262
263       getLabelHeader :: Get LabelHeader
264       getLabelHeader
265           = do header <- lookAhead $ getByteString 1
266                let Right h
267                        = runBitGet header $
268                          do a <- getBit
269                             b <- getBit
270                             n <- liftM fromIntegral (getAsWord8 6)
271                             case (a, b) of
272                               ( True,  True) -> return $ Offset n
273                               (False, False) -> return $ Length n
274                               _              -> fail "Illegal label header"
275                case h of
276                  Offset _
277                      -> do header' <- getByteString 2 -- Pointers have 2 octets.
278                            let Right h'
279                                    = runBitGet header' $
280                                      do BG.skip 2
281                                         n <- liftM fromIntegral (getAsWord16 14)
282                                         return $ Offset n
283                            return h'
284                  len@(Length _)
285                      -> do G.skip 1
286                            return len
287
288
289 getCharString :: Get BS.ByteString
290 getCharString = do len <- G.getWord8
291                    getByteString (fromIntegral len)
292
293 putCharString :: BS.ByteString -> Put
294 putCharString = putDomainLabel
295
296 data LabelHeader
297     = Offset !Int
298     | Length !Int
299
300 putDomainName :: DomainName -> Put
301 putDomainName = mapM_ putDomainLabel . nameToLabels
302
303 putDomainLabel :: DomainLabel -> Put
304 putDomainLabel l
305     = do putWord8 $ fromIntegral $ BS.length l
306          P.putByteString l
307
308 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
309     rtToInt       :: rt -> Int
310     putRecordType :: rt -> Put
311     putRecordData :: rt -> dt -> Put
312     getRecordData :: rt -> DecompTable -> Get (dt, DecompTable)
313
314     putRecordType = putWord16be . fromIntegral . rtToInt
315
316 data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
317
318 instance Show SomeRT where
319     show (SomeRT rt) = show rt
320
321 instance Eq SomeRT where
322     (SomeRT a) == (SomeRT b) = Just a == cast b
323
324 putSomeRT :: SomeRT -> Put
325 putSomeRT (SomeRT rt) = putRecordType rt
326
327 getSomeRT :: Get SomeRT
328 getSomeRT = do n <- liftM fromIntegral G.getWord16be
329                case IM.lookup n defaultRTTable of
330                  Nothing
331                      -> fail ("Unknown resource record type: " ++ show n)
332                  Just srt
333                      -> return srt
334
335 data A = A deriving (Show, Eq, Typeable)
336 instance RecordType A HostAddress where
337     rtToInt       _ = 1
338     putRecordData _ = putWord32be
339     getRecordData _ = \ dt ->
340                       do addr <- G.getWord32be
341                          return (addr, dt)
342
343 data NS = NS deriving (Show, Eq, Typeable)
344 instance RecordType NS DomainName where
345     rtToInt       _ = 2
346     putRecordData _ = putDomainName
347     getRecordData _ = getDomainName
348
349 data CNAME = CNAME deriving (Show, Eq, Typeable)
350 instance RecordType CNAME DomainName where
351     rtToInt       _ = 5
352     putRecordData _ = putDomainName
353     getRecordData _ = getDomainName
354
355 data HINFO = HINFO deriving (Show, Eq, Typeable)
356 instance RecordType HINFO (BS.ByteString, BS.ByteString) where
357     rtToInt       _           = 13
358     putRecordData _ (cpu, os) = do putCharString cpu
359                                    putCharString os
360     getRecordData _ dt        = do cpu <- getCharString
361                                    os  <- getCharString
362                                    return ((cpu, os), dt)
363
364
365 {-
366 data RecordType
367     = A
368     | NS
369     | MD
370     | MF
371     | CNAME
372     | SOA
373     | MB
374     | MG
375     | MR
376     | NULL
377     | WKS
378     | PTR
379     | HINFO
380     | MINFO
381     | MX
382     | TXT
383
384     -- Only for queries:
385     | AXFR
386     | MAILB -- Obsolete
387     | MAILA -- Obsolete
388     | AnyType
389     deriving (Show, Eq)
390 -}
391
392 instance Binary Message where
393     put m = do put $ msgHeader m
394                putWord16be $ fromIntegral $ length $ msgQuestions m
395                putWord16be $ fromIntegral $ length $ msgAnswers m
396                putWord16be $ fromIntegral $ length $ msgAuthorities m
397                putWord16be $ fromIntegral $ length $ msgAdditionals m
398                mapM_ putQ  $ msgQuestions m
399                mapM_ putSomeRR $ msgAnswers m
400                mapM_ putSomeRR $ msgAuthorities m
401                mapM_ putSomeRR $ msgAdditionals m
402
403     get = do hdr  <- get
404              nQ   <- liftM fromIntegral G.getWord16be
405              nAns <- liftM fromIntegral G.getWord16be
406              nAth <- liftM fromIntegral G.getWord16be
407              nAdd <- liftM fromIntegral G.getWord16be
408              (qs  , dt1) <- replicateM' nQ   getQ IM.empty
409              (anss, dt2) <- replicateM' nAns getSomeRR dt1
410              (aths, dt3) <- replicateM' nAth getSomeRR dt2
411              (adds, _  ) <- replicateM' nAdd getSomeRR dt3
412              return Message {
413                           msgHeader      = hdr
414                         , msgQuestions   = qs
415                         , msgAnswers     = anss
416                         , msgAuthorities = aths
417                         , msgAdditionals = adds
418                         }
419
420 instance Binary Header where
421     put h = do putWord16be $ hdMessageID h
422                putLazyByteString flags
423         where
424           flags = runBitPut $
425                   do putNBits 1 $ fromEnum $ hdMessageType h
426                      putNBits 4 $ fromEnum $ hdOpcode h
427                      putBit $ hdIsAuthoritativeAnswer h
428                      putBit $ hdIsTruncated h
429                      putBit $ hdIsRecursionDesired h
430                      putBit $ hdIsRecursionAvailable h
431                      putNBits 3 (0 :: Int)
432                      putNBits 4 $ fromEnum $ hdResponseCode h
433
434     get = do mID   <- G.getWord16be
435              flags <- getByteString 2
436              let Right hd
437                      = runBitGet flags $
438                        do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
439                           op <- liftM (toEnum . fromIntegral) $ getAsWord8 4
440                           aa <- getBit
441                           tc <- getBit
442                           rd <- getBit
443                           ra <- getBit
444                           BG.skip 3
445                           rc <- liftM (toEnum . fromIntegral) $ getAsWord8 4
446                           return Header {
447                                        hdMessageID             = mID
448                                      , hdMessageType           = qr
449                                      , hdOpcode                = op
450                                      , hdIsAuthoritativeAnswer = aa
451                                      , hdIsTruncated           = tc
452                                      , hdIsRecursionDesired    = rd
453                                      , hdIsRecursionAvailable  = ra
454                                      , hdResponseCode          = rc
455                                      }
456              return hd
457
458 instance Enum MessageType where
459     fromEnum Query    = 0
460     fromEnum Response = 1
461
462     toEnum 0 = Query
463     toEnum 1 = Response
464     toEnum _ = undefined
465
466 instance Enum Opcode where
467     fromEnum StandardQuery       = 0
468     fromEnum InverseQuery        = 1
469     fromEnum ServerStatusRequest = 2
470
471     toEnum 0 = StandardQuery
472     toEnum 1 = InverseQuery
473     toEnum 2 = ServerStatusRequest
474     toEnum _ = undefined
475
476 instance Enum ResponseCode where
477     fromEnum NoError        = 0
478     fromEnum FormatError    = 1
479     fromEnum ServerFailure  = 2
480     fromEnum NameError      = 3
481     fromEnum NotImplemented = 4
482     fromEnum Refused        = 5
483
484     toEnum 0 = NoError
485     toEnum 1 = FormatError
486     toEnum 2 = ServerFailure
487     toEnum 3 = NameError
488     toEnum 4 = NotImplemented
489     toEnum 5 = Refused
490     toEnum _ = undefined
491
492 {-
493 instance Enum RecordType where
494     fromEnum A       = 1
495     fromEnum NS      = 2
496     fromEnum MD      = 3
497     fromEnum MF      = 4
498     fromEnum CNAME   = 5
499     fromEnum SOA     = 6
500     fromEnum MB      = 7
501     fromEnum MG      = 8
502     fromEnum MR      = 9
503     fromEnum NULL    = 10
504     fromEnum WKS     = 11
505     fromEnum PTR     = 12
506     fromEnum HINFO   = 13
507     fromEnum MINFO   = 14
508     fromEnum MX      = 15
509     fromEnum TXT     = 16
510     fromEnum AXFR    = 252
511     fromEnum MAILB   = 253
512     fromEnum MAILA   = 254
513     fromEnum AnyType = 255
514
515     toEnum 1  = A
516     toEnum 2  = NS
517     toEnum 3  = MD
518     toEnum 4  = MF
519     toEnum 5  = CNAME
520     toEnum 6  = SOA
521     toEnum 7  = MB
522     toEnum 8  = MG
523     toEnum 9  = MR
524     toEnum 10 = NULL
525     toEnum 11 = WKS
526     toEnum 12 = PTR
527     toEnum 13 = HINFO
528     toEnum 14 = MINFO
529     toEnum 15 = MX
530     toEnum 16 = TXT
531     toEnum 252 = AXFR
532     toEnum 253 = MAILB
533     toEnum 254 = MAILA
534     toEnum 255 = AnyType
535     toEnum _  = undefined
536 -}
537
538 instance Enum RecordClass where
539     fromEnum IN       = 1
540     fromEnum CS       = 2
541     fromEnum CH       = 3
542     fromEnum HS       = 4
543     fromEnum AnyClass = 255
544
545     toEnum 1   = IN
546     toEnum 2   = CS
547     toEnum 3   = CH
548     toEnum 4   = HS
549     toEnum 255 = AnyClass
550     toEnum _   = undefined
551
552 instance Binary RecordClass where
553     get = liftM (toEnum . fromIntegral) G.getWord16be
554     put = putWord16be . fromIntegral . fromEnum
555
556
557 defaultRTTable :: IntMap SomeRT
558 defaultRTTable = IM.fromList $ map toPair $
559                  [ wrapRecordType A
560                  , wrapRecordType NS
561                  , wrapRecordType CNAME
562                  , wrapRecordType HINFO
563                  ]
564     where
565       toPair :: SomeRT -> (Int, SomeRT)
566       toPair srt@(SomeRT rt) = (rtToInt rt, srt)
567
568
569 wrapQueryType :: RecordType rt dt => rt -> SomeQT
570 wrapQueryType = SomeRT
571
572 wrapRecordType :: RecordType rt dt => rt -> SomeRT
573 wrapRecordType = SomeRT
574
575 wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
576 wrapRecord = SomeRR