]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Message.hs
Add DNSUnitTest.hs
[haskell-dns.git] / Network / DNS / Message.hs
1 module Network.DNS.Message
2     ( Message(..)
3     , MessageID
4     , MessageType(..)
5     , Header(..)
6     , Opcode(..)
7     , ResponseCode(..)
8     , Question(..)
9     , ResourceRecord(..)
10     , DomainName
11     , DomainLabel
12     , TTL
13     , RecordType
14     , RecordClass(..)
15
16     , SomeRR(..)
17     , SomeRT(..)
18
19     , CNAME(..)
20     , HINFO(..)
21
22     , mkQueryType
23     , mkDomainName
24     )
25     where
26
27 import           Control.Exception
28 import           Control.Monad
29 import           Data.Binary
30 import           Data.Binary.BitPut as BP
31 import           Data.Binary.Get as G
32 import           Data.Binary.Put as P
33 import           Data.Binary.Strict.BitGet as BG
34 import qualified Data.ByteString as BS
35 import qualified Data.ByteString.Char8 as C8 hiding (ByteString)
36 import qualified Data.ByteString.Lazy as LBS
37 import           Data.Typeable
38 import qualified Data.IntMap as IM
39 import           Data.IntMap (IntMap)
40 import           Data.Word
41
42
43 replicateM' :: Monad m => Int -> (a -> m (b, a)) -> a -> m ([b], a)
44 replicateM' = worker []
45     where
46       worker :: Monad m => [b] -> Int -> (a -> m (b, a)) -> a -> m ([b], a)
47       worker soFar 0 _ a = return (reverse soFar, a)
48       worker soFar n f a = do (b, a') <- f a
49                               worker (b : soFar) (n - 1) f a'
50
51
52 data Message
53     = Message {
54         msgHeader      :: !Header
55       , msgQuestions   :: ![Question]
56       , msgAnswers     :: ![SomeRR]
57       , msgAuthorities :: ![SomeRR]
58       , msgAdditionals :: ![SomeRR]
59       }
60     deriving (Show, Eq)
61
62 data Header
63     = Header {
64         hdMessageID             :: !MessageID
65       , hdMessageType           :: !MessageType
66       , hdOpcode                :: !Opcode
67       , hdIsAuthoritativeAnswer :: !Bool
68       , hdIsTruncated           :: !Bool
69       , hdIsRecursionDesired    :: !Bool
70       , hdIsRecursionAvailable  :: !Bool
71       , hdResponseCode          :: !ResponseCode
72
73       -- These fields are supressed in this data structure:
74       -- + QDCOUNT
75       -- + ANCOUNT
76       -- + NSCOUNT
77       -- + ARCOUNT
78       }
79     deriving (Show, Eq)
80
81 type MessageID = Word16
82
83 data MessageType
84     = Query
85     | Response
86     deriving (Show, Eq)
87
88 data Opcode
89     = StandardQuery
90     | InverseQuery
91     | ServerStatusRequest
92     deriving (Show, Eq)
93
94 data ResponseCode
95     = NoError
96     | FormatError
97     | ServerFailure
98     | NameError
99     | NotImplemented
100     | Refused
101     deriving (Show, Eq)
102
103 data Question
104     = Question {
105         qName  :: !DomainName
106       , qType  :: !SomeQT
107       , qClass :: !RecordClass
108       }
109     deriving (Show, Eq)
110
111 type SomeQT = SomeRT
112
113 mkQueryType :: RecordType rt dt => rt -> SomeQT
114 mkQueryType = SomeRT
115
116 putQ :: Question -> Put
117 putQ q
118     = do putDomainName $ qName q
119          putSomeRT $ qType q
120          put $ qClass q
121
122 getQ :: DecompTable -> Get (Question, DecompTable)
123 getQ dt
124     = do (nm, dt') <- getDomainName dt
125          ty        <- getSomeRT
126          cl        <- get
127          let q = Question {
128                    qName  = nm
129                  , qType  = ty
130                  , qClass = cl
131                  }
132          return (q, dt')
133
134 newtype DomainName  = DN [DomainLabel] deriving (Eq, Show, Typeable)
135 type DomainLabel    = BS.ByteString
136
137 nameToLabels :: DomainName -> [DomainLabel]
138 nameToLabels (DN ls) = ls
139
140 labelsToName :: [DomainLabel] -> DomainName
141 labelsToName = DN
142
143 mkDomainName :: String -> DomainName
144 mkDomainName = labelsToName . mkLabels [] . notEmpty
145     where
146       notEmpty :: String -> String
147       notEmpty xs = assert (not $ null xs) xs
148
149       mkLabels :: [DomainLabel] -> String -> [DomainLabel]
150       mkLabels soFar [] = reverse (C8.empty : soFar)
151       mkLabels soFar xs = case break (== '.') xs of
152                             (l, ('.':rest))
153                                 -> mkLabels (C8.pack l : soFar) rest
154                             _   -> error ("Illegal domain name: " ++ xs)
155
156 data RecordClass
157     = IN
158     | CS -- Obsolete
159     | CH
160     | HS
161     | AnyClass -- Only for queries
162     deriving (Show, Eq)
163
164
165 data RecordType rt dt => ResourceRecord rt dt
166     = ResourceRecord {
167         rrName  :: !DomainName
168       , rrType  :: !rt
169       , rrClass :: !RecordClass
170       , rrTTL   :: !TTL
171       , rrData  :: !dt
172       }
173     deriving (Show, Eq, Typeable)
174
175
176 putRR :: forall rt dt. RecordType rt dt => ResourceRecord rt dt -> Put
177 putRR rr = do putDomainName $ rrName rr
178               putRecordType $ rrType  rr
179               put $ rrClass rr
180               putWord32be $ rrTTL rr
181
182               let dat = runPut $
183                         putRecordData (undefined :: rt) (rrData rr)
184               putWord16be $ fromIntegral $ LBS.length dat
185               putLazyByteString dat
186
187
188 getRR :: forall rt dt. RecordType rt dt => DecompTable -> rt -> Get (ResourceRecord rt dt, DecompTable)
189 getRR dt rt
190     = do (nm, dt1)  <- getDomainName dt
191          G.skip 2   -- record type
192          cl         <- get
193          ttl        <- G.getWord32be
194          G.skip 2   -- data length
195          (dat, dt2) <- getRecordData (undefined :: rt) dt1
196
197          let rr = ResourceRecord {
198                     rrName  = nm
199                   , rrType  = rt
200                   , rrClass = cl
201                   , rrTTL   = ttl
202                   , rrData  = dat
203                   }
204          return (rr, dt2)
205
206
207 data SomeRR = forall rt dt. RecordType rt dt => SomeRR (ResourceRecord rt dt)
208
209 instance Show SomeRR where
210     show (SomeRR rr) = show rr
211
212 instance Eq SomeRR where
213     (SomeRR a) == (SomeRR b) = Just a == cast b
214
215
216 putSomeRR :: SomeRR -> Put
217 putSomeRR (SomeRR rr) = putRR rr
218
219 getSomeRR :: DecompTable -> Get (SomeRR, DecompTable)
220 getSomeRR dt
221     = do srt <- lookAhead $
222                 do getDomainName dt -- skip
223                    getSomeRT
224          case srt of
225            SomeRT rt -> getRR dt rt >>= \ (rr, dt') -> return (SomeRR rr, dt')
226
227
228 type DecompTable = IntMap BS.ByteString
229 type TTL = Word32
230
231 getDomainName :: DecompTable -> Get (DomainName, DecompTable)
232 getDomainName = flip worker []
233     where
234       worker :: DecompTable -> [DomainLabel] -> Get (DomainName, DecompTable)
235       worker dt soFar
236           = do (l, dt') <- getDomainLabel dt
237                case BS.null l of
238                  True  -> return (labelsToName (reverse (l : soFar)), dt')
239                  False -> worker dt' (l : soFar)
240
241 getDomainLabel :: DecompTable -> Get (DomainLabel, DecompTable)
242 getDomainLabel dt
243     = do header <- getByteString 1
244          let Right h
245                  = runBitGet header $
246                    do a <- getBit
247                       b <- getBit
248                       n <- liftM fromIntegral (getAsWord8 6)
249                       case (a, b) of
250                         ( True,  True) -> return $ Offset n
251                         (False, False) -> return $ Length n
252                         _              -> fail "Illegal label header"
253          case h of
254            Offset n
255                -> do let Just l = IM.lookup n dt
256                      return (l, dt)
257            Length n
258                -> do offset <- liftM fromIntegral bytesRead
259                      label  <- getByteString n
260                      let dt' = IM.insert offset label dt
261                      return (label, dt')
262
263 getCharString :: Get BS.ByteString
264 getCharString = do len <- G.getWord8
265                    getByteString (fromIntegral len)
266
267 putCharString :: BS.ByteString -> Put
268 putCharString = putDomainLabel
269
270 data LabelHeader
271     = Offset !Int
272     | Length !Int
273
274 putDomainName :: DomainName -> Put
275 putDomainName = mapM_ putDomainLabel . nameToLabels
276
277 putDomainLabel :: DomainLabel -> Put
278 putDomainLabel l
279     = do putWord8 $ fromIntegral $ BS.length l
280          P.putByteString l
281
282 class (Show rt, Show dt, Eq rt, Eq dt, Typeable rt, Typeable dt) => RecordType rt dt | rt -> dt where
283     rtToInt       :: rt -> Int
284     putRecordType :: rt -> Put
285     putRecordData :: rt -> dt -> Put
286     getRecordData :: rt -> DecompTable -> Get (dt, DecompTable)
287
288     putRecordType = putWord16be . fromIntegral . rtToInt
289
290 data SomeRT = forall rt dt. RecordType rt dt => SomeRT rt
291
292 instance Show SomeRT where
293     show (SomeRT rt) = show rt
294
295 instance Eq SomeRT where
296     (SomeRT a) == (SomeRT b) = Just a == cast b
297
298 putSomeRT :: SomeRT -> Put
299 putSomeRT (SomeRT rt) = putRecordType rt
300
301 getSomeRT :: Get SomeRT
302 getSomeRT = do n <- liftM fromIntegral G.getWord16be
303                case IM.lookup n defaultRTTable of
304                  Nothing
305                      -> fail ("Unknown resource record type: " ++ show n)
306                  Just srt
307                      -> return srt
308
309 data CNAME = CNAME deriving (Show, Eq, Typeable)
310 instance RecordType CNAME DomainName where
311     rtToInt       _ = 5
312     putRecordData _ = putDomainName
313     getRecordData _ = getDomainName
314
315 data HINFO = HINFO deriving (Show, Eq, Typeable)
316 instance RecordType HINFO (BS.ByteString, BS.ByteString) where
317     rtToInt       _           = 13
318     putRecordData _ (cpu, os) = do putCharString cpu
319                                    putCharString os
320     getRecordData _ dt        = do cpu <- getCharString
321                                    os  <- getCharString
322                                    return ((cpu, os), dt)
323
324 {-
325 data RecordType
326     = A
327     | NS
328     | MD
329     | MF
330     | CNAME
331     | SOA
332     | MB
333     | MG
334     | MR
335     | NULL
336     | WKS
337     | PTR
338     | HINFO
339     | MINFO
340     | MX
341     | TXT
342
343     -- Only for queries:
344     | AXFR
345     | MAILB -- Obsolete
346     | MAILA -- Obsolete
347     | AnyType
348     deriving (Show, Eq)
349 -}
350
351 instance Binary Message where
352     put m = do put $ msgHeader m
353                putWord16be $ fromIntegral $ length $ msgQuestions m
354                putWord16be $ fromIntegral $ length $ msgAnswers m
355                putWord16be $ fromIntegral $ length $ msgAuthorities m
356                putWord16be $ fromIntegral $ length $ msgAdditionals m
357                mapM_ putQ  $ msgQuestions m
358                mapM_ putSomeRR $ msgAnswers m
359                mapM_ putSomeRR $ msgAuthorities m
360                mapM_ putSomeRR $ msgAdditionals m
361
362     get = do hdr  <- get
363              nQ   <- liftM fromIntegral G.getWord16be
364              nAns <- liftM fromIntegral G.getWord16be
365              nAth <- liftM fromIntegral G.getWord16be
366              nAdd <- liftM fromIntegral G.getWord16be
367              (qs  , dt1) <- replicateM' nQ   getQ IM.empty
368              (anss, dt2) <- replicateM' nAns getSomeRR dt1
369              (aths, dt3) <- replicateM' nAth getSomeRR dt2
370              (adds, _  ) <- replicateM' nAdd getSomeRR dt3
371              return Message {
372                           msgHeader      = hdr
373                         , msgQuestions   = qs
374                         , msgAnswers     = anss
375                         , msgAuthorities = aths
376                         , msgAdditionals = adds
377                         }
378
379 instance Binary Header where
380     put h = do putWord16be $ hdMessageID h
381                putLazyByteString flags
382         where
383           flags = runBitPut $
384                   do putNBits 1 $ fromEnum $ hdMessageType h
385                      putNBits 4 $ fromEnum $ hdOpcode h
386                      putBit $ hdIsAuthoritativeAnswer h
387                      putBit $ hdIsTruncated h
388                      putBit $ hdIsRecursionDesired h
389                      putBit $ hdIsRecursionAvailable h
390                      putNBits 3 (0 :: Int)
391                      putNBits 4 $ fromEnum $ hdResponseCode h
392
393     get = do mID   <- G.getWord16be
394              flags <- getByteString 2
395              let Right hd
396                      = runBitGet flags $
397                        do qr <- liftM (toEnum . fromIntegral) $ getAsWord8 1
398                           op <- liftM (toEnum . fromIntegral) $ getAsWord8 4
399                           aa <- getBit
400                           tc <- getBit
401                           rd <- getBit
402                           ra <- getBit
403                           BG.skip 3
404                           rc <- liftM (toEnum . fromIntegral) $ getAsWord8 4
405                           return Header {
406                                        hdMessageID             = mID
407                                      , hdMessageType           = qr
408                                      , hdOpcode                = op
409                                      , hdIsAuthoritativeAnswer = aa
410                                      , hdIsTruncated           = tc
411                                      , hdIsRecursionDesired    = rd
412                                      , hdIsRecursionAvailable  = ra
413                                      , hdResponseCode          = rc
414                                      }
415              return hd
416
417 instance Enum MessageType where
418     fromEnum Query    = 0
419     fromEnum Response = 1
420
421     toEnum 0 = Query
422     toEnum 1 = Response
423     toEnum _ = undefined
424
425 instance Enum Opcode where
426     fromEnum StandardQuery       = 0
427     fromEnum InverseQuery        = 1
428     fromEnum ServerStatusRequest = 2
429
430     toEnum 0 = StandardQuery
431     toEnum 1 = InverseQuery
432     toEnum 2 = ServerStatusRequest
433     toEnum _ = undefined
434
435 instance Enum ResponseCode where
436     fromEnum NoError        = 0
437     fromEnum FormatError    = 1
438     fromEnum ServerFailure  = 2
439     fromEnum NameError      = 3
440     fromEnum NotImplemented = 4
441     fromEnum Refused        = 5
442
443     toEnum 0 = NoError
444     toEnum 1 = FormatError
445     toEnum 2 = ServerFailure
446     toEnum 3 = NameError
447     toEnum 4 = NotImplemented
448     toEnum 5 = Refused
449     toEnum _ = undefined
450
451 {-
452 instance Enum RecordType where
453     fromEnum A       = 1
454     fromEnum NS      = 2
455     fromEnum MD      = 3
456     fromEnum MF      = 4
457     fromEnum CNAME   = 5
458     fromEnum SOA     = 6
459     fromEnum MB      = 7
460     fromEnum MG      = 8
461     fromEnum MR      = 9
462     fromEnum NULL    = 10
463     fromEnum WKS     = 11
464     fromEnum PTR     = 12
465     fromEnum HINFO   = 13
466     fromEnum MINFO   = 14
467     fromEnum MX      = 15
468     fromEnum TXT     = 16
469     fromEnum AXFR    = 252
470     fromEnum MAILB   = 253
471     fromEnum MAILA   = 254
472     fromEnum AnyType = 255
473
474     toEnum 1  = A
475     toEnum 2  = NS
476     toEnum 3  = MD
477     toEnum 4  = MF
478     toEnum 5  = CNAME
479     toEnum 6  = SOA
480     toEnum 7  = MB
481     toEnum 8  = MG
482     toEnum 9  = MR
483     toEnum 10 = NULL
484     toEnum 11 = WKS
485     toEnum 12 = PTR
486     toEnum 13 = HINFO
487     toEnum 14 = MINFO
488     toEnum 15 = MX
489     toEnum 16 = TXT
490     toEnum 252 = AXFR
491     toEnum 253 = MAILB
492     toEnum 254 = MAILA
493     toEnum 255 = AnyType
494     toEnum _  = undefined
495 -}
496
497 instance Enum RecordClass where
498     fromEnum IN       = 1
499     fromEnum CS       = 2
500     fromEnum CH       = 3
501     fromEnum HS       = 4
502     fromEnum AnyClass = 255
503
504     toEnum 1   = IN
505     toEnum 2   = CS
506     toEnum 3   = CH
507     toEnum 4   = HS
508     toEnum 255 = AnyClass
509     toEnum _   = undefined
510
511 instance Binary RecordClass where
512     get = liftM (toEnum . fromIntegral) G.getWord16be
513     put = putWord16be . fromIntegral . fromEnum
514
515
516 defaultRTTable :: IntMap SomeRT
517 defaultRTTable = IM.fromList $ map toPair $
518                  [ SomeRT CNAME
519                  ]
520     where
521       toPair :: SomeRT -> (Int, SomeRT)
522       toPair srt@(SomeRT rt) = (rtToInt rt, srt)