]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Message.hs
Introduce Unpacker monad to clean up things.
[haskell-dns.git] / Network / DNS / Message.hs
1 module Network.DNS.Message
2     ( Message(..)
3     , MessageID
4     , MessageType(..)
5     , Header(..)
6     , Opcode(..)
7     , ResponseCode(..)
8     , Question(..)
9     , ResourceRecord(..)
10     , DomainName
11     , DomainLabel
12     , TTL
13     , RecordType
14     , RecordClass(..)
15
16     , SomeQT
17     , SomeRR
18     , SomeRT
19
20     , A(..)
21     , NS(..)
22     , CNAME(..)
23     , HINFO(..)
24
25     , mkDomainName
26     , wrapQueryType
27     , wrapRecordType
28     , wrapRecord
29     )
30     where
31
32 import           Control.Exception
33 import           Control.Monad
34 import           Data.Binary
35 import           Data.Binary.BitPut as BP
36 import           Data.Binary.Get as G
37 import           Data.Binary.Put as P
38 import           Data.Binary.Strict.BitGet as BG
39 import qualified Data.ByteString as BS
40 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
41 import qualified Data.ByteString.Lazy as LBS
42 import           Data.Typeable
43 import qualified Data.IntMap as IM
44 import           Data.IntMap (IntMap)
45 import           Data.Word
46 import           Network.DNS.Unpacker as U
47 import           Network.Socket
48
49
50 data Message
51     = Message {
52         msgHeader      :: !Header
53       , msgQuestions   :: ![Question]
54       , msgAnswers     :: ![SomeRR]
55       , msgAuthorities :: ![SomeRR]
56       , msgAdditionals :: ![SomeRR]
57       }
58     deriving (Show, Eq)
59
60 data Header
61     = Header {
62         hdMessageID             :: !MessageID
63       , hdMessageType           :: !MessageType
64       , hdOpcode                :: !Opcode
65       , hdIsAuthoritativeAnswer :: !Bool
66       , hdIsTruncated           :: !Bool
67       , hdIsRecursionDesired    :: !Bool
68       , hdIsRecursionAvailable  :: !Bool
69       , hdResponseCode          :: !ResponseCode
70
71       -- These fields are supressed in this data structure:
72       -- + QDCOUNT
73       -- + ANCOUNT
74       -- + NSCOUNT
75       -- + ARCOUNT
76       }
77     deriving (Show, Eq)
78
79 type MessageID = Word16
80
81 data MessageType
82     = Query
83     | Response
84     deriving (Show, Eq)
85
86 data Opcode
87     = StandardQuery
88     | InverseQuery
89     | ServerStatusRequest
90     deriving (Show, Eq)
91
92 data ResponseCode
93     = NoError
94     | FormatError
95     | ServerFailure
96     | NameError
97     | NotImplemented
98     | Refused
99     deriving (Show, Eq)
100
101 data Question
102     = Question {
103         qName  :: !DomainName
104       , qType  :: !SomeQT
105       , qClass :: !RecordClass
106       }
107     deriving (Show, Eq)
108
109 type SomeQT = SomeRT
110
111 putQ :: Question -> Put
112 putQ q
113     = do putDomainName $ qName q
114          putSomeRT $ qType q
115          put $ qClass q
116
117 getQ :: Unpacker DecompTable Question
118 getQ = do nm <- getDomainName
119           ty <- getSomeRT
120           cl <- getBinary
121           return Question {
122                        qName  = nm
123                      , qType  = ty
124                      , qClass = cl
125                      }
126
127
128 newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
129 type DomainLabel    = BS.ByteString
130
131 nameToLabels :: DomainName -> [DomainLabel]
132 nameToLabels (DN ls) = ls
133
134 labelsToName :: [DomainLabel] -> DomainName
135 labelsToName = DN
136
137 rootName :: DomainName
138 rootName = DN [BS.empty]
139
140 consLabel :: DomainLabel -> DomainName -> DomainName
141 consLabel x (DN ys) = DN (x:ys)
142
143 mkDomainName :: String -> DomainName
144 mkDomainName = labelsToName . mkLabels [] . notEmpty
145     where
146       notEmpty :: String -> String
147       notEmpty xs = assert (not $ null xs) xs
148
149       mkLabels :: [DomainLabel] -> String -> [DomainLabel]
150       mkLabels soFar [] = reverse (C8.empty : soFar)
151       mkLabels soFar xs = case break (== '.') xs of
152                             (l, ('.':rest))
153                                 -> mkLabels (C8.pack l : soFar) rest
154                             _   -> error ("Illegal domain name: " ++ xs)
155
156 data RecordClass
157     = IN
158     | CS -- Obsolete
159     | CH
160     | HS
161     | AnyClass -- Only for queries
162     deriving (Show, Eq)
163
164
165 data RecordType rt dt => ResourceRecord rt dt
166     = ResourceRecord {
167         rrName  :: !DomainName
168       , rrType  :: !rt
169       , rrClass :: !RecordClass
170       , rrTTL   :: !TTL
171       , rrData  :: !dt
172       }
173     deriving (Show, Eq, Typeable)
174
175
176 putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
177 putRR rr = do putDomainName $ rrName rr
178               putRecordType $ rrType  rr
179               put $ rrClass rr
180               putWord32be $ rrTTL rr
181
182               let dat = runPut $
183                         putRecordData (undefined :: rt) (rrData rr)
184               putWord16be $ fromIntegral $ LBS.length dat
185               putLazyByteString dat
186
187
188 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
189
190 instance Show SomeRR where
191     show (SomeRR rr) = show rr
192
193 instance Eq SomeRR where
194     (SomeRR a) == (SomeRR b) = Just a == cast b
195
196
197 putSomeRR :: SomeRR -> Put
198 putSomeRR (SomeRR rr) = putRR rr
199
200 getSomeRR :: Unpacker DecompTable SomeRR
201 getSomeRR = do srt <- U.lookAhead $
202                       do getDomainName -- skip
203                          getSomeRT
204                case srt of
205                  SomeRT rt
206                      -> getResourceRecord rt >>= return . SomeRR
207
208 type DecompTable = IntMap DomainName
209 type TTL = Word32
210
211 getDomainName :: Unpacker DecompTable DomainName
212 getDomainName = worker
213     where
214       worker :: Unpacker DecompTable DomainName
215       worker
216           = do offset <- U.bytesRead
217                hdr    <- getLabelHeader
218                case hdr of
219                  Offset n
220                      -> do dt <- getState
221                            case IM.lookup n dt of
222                              Just name
223                                  -> return name
224                              Nothing
225                                  -> fail ("Illegal offset of label pointer: " ++ show (n, dt))
226                  Length 0
227                      -> return rootName
228                  Length n
229                      -> do label <- U.getByteString n
230                            rest  <- worker
231                            let name = consLabel label rest
232                            modifyState $ IM.insert offset name
233                            return name
234
235       getLabelHeader :: Unpacker s LabelHeader
236       getLabelHeader
237           = do header <- U.lookAhead $ U.getByteString 1
238                let Right h
239                        = runBitGet header $
240                          do a <- getBit
241                             b <- getBit
242                             n <- liftM fromIntegral (getAsWord8 6)
243                             case (a, b) of
244                               ( True,  True) -> return $ Offset n
245                               (False, False) -> return $ Length n
246                               _              -> fail "Illegal label header"
247                case h of
248                  Offset _
249                      -> do header' <- U.getByteString 2 -- Pointers have 2 octets.
250                            let Right h'
251                                    = runBitGet header' $
252                                      do BG.skip 2
253                                         n <- liftM fromIntegral (getAsWord16 14)
254                                         return $ Offset n
255                            return h'
256                  len@(Length _)
257                      -> do U.skip 1
258                            return len
259
260
261 getCharString :: Unpacker s BS.ByteString
262 getCharString = do len <- U.getWord8
263                    U.getByteString (fromIntegral len)
264
265 putCharString :: BS.ByteString -> Put
266 putCharString = putDomainLabel
267
268 data LabelHeader
269     = Offset !Int
270     | Length !Int
271
272 putDomainName :: DomainName -> Put
273 putDomainName = mapM_ putDomainLabel . nameToLabels
274
275 putDomainLabel :: DomainLabel -> Put
276 putDomainLabel l
277     = do putWord8 $ fromIntegral $ BS.length l
278          P.putByteString l
279
280 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
281     rtToInt       :: rt -> Int
282     putRecordData :: rt -> dt -> Put
283     getRecordData :: rt -> Unpacker DecompTable dt
284
285     putRecordType :: rt -> Put
286     putRecordType = putWord16be . fromIntegral . rtToInt
287
288     getResourceRecord :: rt -> Unpacker DecompTable (ResourceRecord rt dt)
289     getResourceRecord rt
290         = do name     <- getDomainName
291              U.skip 2 -- record type
292              cl       <- getBinary
293              ttl      <- U.getWord32be
294              U.skip 2 -- data length
295              dat      <- getRecordData rt
296              return $ ResourceRecord {
297                           rrName  = name
298                         , rrType  = rt
299                         , rrClass = cl
300                         , rrTTL   = ttl
301                         , rrData  = dat
302                         }
303
304 data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
305
306 instance Show SomeRT where
307     show (SomeRT rt) = show rt
308
309 instance Eq SomeRT where
310     (SomeRT a) == (SomeRT b) = Just a == cast b
311
312 putSomeRT :: SomeRT -> Put
313 putSomeRT (SomeRT rt) = putRecordType rt
314
315 getSomeRT :: Unpacker s SomeRT
316 getSomeRT = do n <- liftM fromIntegral U.getWord16be
317                case IM.lookup n defaultRTTable of
318                  Nothing
319                      -> fail ("Unknown resource record type: " ++ show n)
320                  Just srt
321                      -> return srt
322
323 data A = A deriving (Show, Eq, Typeable)
324 instance RecordType A HostAddress where
325     rtToInt       _ = 1
326     putRecordData _ = putWord32be
327     getRecordData _ = U.getWord32be
328
329 data NS = NS deriving (Show, Eq, Typeable)
330 instance RecordType NS DomainName where
331     rtToInt       _ = 2
332     putRecordData _ = putDomainName
333     getRecordData _ = getDomainName
334
335 data CNAME = CNAME deriving (Show, Eq, Typeable)
336 instance RecordType CNAME DomainName where
337     rtToInt       _ = 5
338     putRecordData _ = putDomainName
339     getRecordData _ = getDomainName
340
341 data HINFO = HINFO deriving (Show, Eq, Typeable)
342 instance RecordType HINFO (BS.ByteString, BS.ByteString) where
343     rtToInt       _           = 13
344     putRecordData _ (cpu, os) = do putCharString cpu
345                                    putCharString os
346     getRecordData _           = do cpu <- getCharString
347                                    os  <- getCharString
348                                    return (cpu, os)
349
350
351 {-
352 data RecordType
353     = A
354     | NS
355     | MD
356     | MF
357     | CNAME
358     | SOA
359     | MB
360     | MG
361     | MR
362     | NULL
363     | WKS
364     | PTR
365     | HINFO
366     | MINFO
367     | MX
368     | TXT
369
370     -- Only for queries:
371     | AXFR
372     | MAILB -- Obsolete
373     | MAILA -- Obsolete
374     | AnyType
375     deriving (Show, Eq)
376 -}
377
378 instance Binary Message where
379     put m = do put $ msgHeader m
380                putWord16be $ fromIntegral $ length $ msgQuestions m
381                putWord16be $ fromIntegral $ length $ msgAnswers m
382                putWord16be $ fromIntegral $ length $ msgAuthorities m
383                putWord16be $ fromIntegral $ length $ msgAdditionals m
384                mapM_ putQ  $ msgQuestions m
385                mapM_ putSomeRR $ msgAnswers m
386                mapM_ putSomeRR $ msgAuthorities m
387                mapM_ putSomeRR $ msgAdditionals m
388
389     get = liftToBinary IM.empty $
390           do hdr  <- getBinary
391              nQ   <- liftM fromIntegral U.getWord16be
392              nAns <- liftM fromIntegral U.getWord16be
393              nAth <- liftM fromIntegral U.getWord16be
394              nAdd <- liftM fromIntegral U.getWord16be
395              qs   <- replicateM nQ   getQ
396              anss <- replicateM nAns getSomeRR
397              aths <- replicateM nAth getSomeRR
398              adds <- replicateM nAdd getSomeRR
399              return Message {
400                           msgHeader      = hdr
401                         , msgQuestions   = qs
402                         , msgAnswers     = anss
403                         , msgAuthorities = aths
404                         , msgAdditionals = adds
405                         }
406
407 instance Binary Header where
408     put h = do putWord16be $ hdMessageID h
409                putLazyByteString flags
410         where
411           flags = runBitPut $
412                   do putNBits 1 $ fromEnum $ hdMessageType h
413                      putNBits 4 $ fromEnum $ hdOpcode h
414                      putBit $ hdIsAuthoritativeAnswer h
415                      putBit $ hdIsTruncated h
416                      putBit $ hdIsRecursionDesired h
417                      putBit $ hdIsRecursionAvailable h
418                      putNBits 3 (0 :: Int)
419                      putNBits 4 $ fromEnum $ hdResponseCode h
420
421     get = do mID   <- G.getWord16be
422              flags <- G.getByteString 2
423              let Right hd
424                      = runBitGet flags $
425                        do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
426                           op <- liftM (toEnum . fromIntegral) $ getAsWord8 4
427                           aa <- getBit
428                           tc <- getBit
429                           rd <- getBit
430                           ra <- getBit
431                           BG.skip 3
432                           rc <- liftM (toEnum . fromIntegral) $ getAsWord8 4
433                           return Header {
434                                        hdMessageID             = mID
435                                      , hdMessageType           = qr
436                                      , hdOpcode                = op
437                                      , hdIsAuthoritativeAnswer = aa
438                                      , hdIsTruncated           = tc
439                                      , hdIsRecursionDesired    = rd
440                                      , hdIsRecursionAvailable  = ra
441                                      , hdResponseCode          = rc
442                                      }
443              return hd
444
445 instance Enum MessageType where
446     fromEnum Query    = 0
447     fromEnum Response = 1
448
449     toEnum 0 = Query
450     toEnum 1 = Response
451     toEnum _ = undefined
452
453 instance Enum Opcode where
454     fromEnum StandardQuery       = 0
455     fromEnum InverseQuery        = 1
456     fromEnum ServerStatusRequest = 2
457
458     toEnum 0 = StandardQuery
459     toEnum 1 = InverseQuery
460     toEnum 2 = ServerStatusRequest
461     toEnum _ = undefined
462
463 instance Enum ResponseCode where
464     fromEnum NoError        = 0
465     fromEnum FormatError    = 1
466     fromEnum ServerFailure  = 2
467     fromEnum NameError      = 3
468     fromEnum NotImplemented = 4
469     fromEnum Refused        = 5
470
471     toEnum 0 = NoError
472     toEnum 1 = FormatError
473     toEnum 2 = ServerFailure
474     toEnum 3 = NameError
475     toEnum 4 = NotImplemented
476     toEnum 5 = Refused
477     toEnum _ = undefined
478
479 {-
480 instance Enum RecordType where
481     fromEnum A       = 1
482     fromEnum NS      = 2
483     fromEnum MD      = 3
484     fromEnum MF      = 4
485     fromEnum CNAME   = 5
486     fromEnum SOA     = 6
487     fromEnum MB      = 7
488     fromEnum MG      = 8
489     fromEnum MR      = 9
490     fromEnum NULL    = 10
491     fromEnum WKS     = 11
492     fromEnum PTR     = 12
493     fromEnum HINFO   = 13
494     fromEnum MINFO   = 14
495     fromEnum MX      = 15
496     fromEnum TXT     = 16
497     fromEnum AXFR    = 252
498     fromEnum MAILB   = 253
499     fromEnum MAILA   = 254
500     fromEnum AnyType = 255
501
502     toEnum 1  = A
503     toEnum 2  = NS
504     toEnum 3  = MD
505     toEnum 4  = MF
506     toEnum 5  = CNAME
507     toEnum 6  = SOA
508     toEnum 7  = MB
509     toEnum 8  = MG
510     toEnum 9  = MR
511     toEnum 10 = NULL
512     toEnum 11 = WKS
513     toEnum 12 = PTR
514     toEnum 13 = HINFO
515     toEnum 14 = MINFO
516     toEnum 15 = MX
517     toEnum 16 = TXT
518     toEnum 252 = AXFR
519     toEnum 253 = MAILB
520     toEnum 254 = MAILA
521     toEnum 255 = AnyType
522     toEnum _  = undefined
523 -}
524
525 instance Enum RecordClass where
526     fromEnum IN       = 1
527     fromEnum CS       = 2
528     fromEnum CH       = 3
529     fromEnum HS       = 4
530     fromEnum AnyClass = 255
531
532     toEnum 1   = IN
533     toEnum 2   = CS
534     toEnum 3   = CH
535     toEnum 4   = HS
536     toEnum 255 = AnyClass
537     toEnum _   = undefined
538
539 instance Binary RecordClass where
540     get = liftM (toEnum . fromIntegral) G.getWord16be
541     put = putWord16be . fromIntegral . fromEnum
542
543
544 defaultRTTable :: IntMap SomeRT
545 defaultRTTable = IM.fromList $ map toPair $
546                  [ wrapRecordType A
547                  , wrapRecordType NS
548                  , wrapRecordType CNAME
549                  , wrapRecordType HINFO
550                  ]
551     where
552       toPair :: SomeRT -> (Int, SomeRT)
553       toPair srt@(SomeRT rt) = (rtToInt rt, srt)
554
555
556 wrapQueryType :: RecordType rt dt => rt -> SomeQT
557 wrapQueryType = SomeRT
558
559 wrapRecordType :: RecordType rt dt => rt -> SomeRT
560 wrapRecordType = SomeRT
561
562 wrapRecord :: RecordType rt dt => ResourceRecord rt dt -> SomeRR
563 wrapRecord = SomeRR