]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Named.hs
AAAA support
[haskell-dns.git] / Network / DNS / Named.hs
1 module Network.DNS.Named
2     ( runNamed
3     )
4     where
5
6 import           Control.Concurrent
7 import           Control.Exception
8 import           Control.Monad
9 import           Data.Binary
10 import           Data.Binary.Get
11 import           Data.Binary.Put
12 import qualified Data.ByteString as BS
13 import qualified Data.ByteString.Lazy as LBS
14 import           Data.Dynamic
15 import           Data.Maybe
16 import           Network.Socket
17 import qualified Network.Socket.ByteString as NB
18 import           Network.DNS.Message
19 import           Network.DNS.Named.Config
20 import           Network.DNS.Named.ResponseBuilder
21 import           Network.DNS.Named.Zone
22 import           System.Posix.Signals
23 import           System.IO
24
25
26 runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
27 runNamed cnf findZone
28     = withSocketsDo $
29       do installHandler sigPIPE Ignore Nothing
30
31          let hint = defaultHints {
32                       addrFlags      = [AI_PASSIVE, AI_V4MAPPED]
33                     , addrFamily     = AF_INET6
34                     , addrSocketType = NoSocketType
35                     , addrProtocol   = defaultProtocol
36                     }
37          (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf)
38
39          _tcpListenerTID <- forkIO $ tcpListen ai
40          udpListen ai
41     where
42       udpListen :: AddrInfo -> IO ()
43       udpListen ai
44           = do so <- socket (addrFamily ai) Datagram defaultProtocol
45                setSocketOption so ReuseAddr 1
46                bindSocket so (addrAddress ai)
47                udpLoop so
48
49       udpLoop :: Socket -> IO ()
50       udpLoop so
51           = do (packet, cameFrom) <- NB.recvFrom so 512
52                _handlerTID <- forkIO $ udpHandler so packet cameFrom
53                udpLoop so
54
55       tcpListen :: AddrInfo -> IO ()
56       tcpListen ai
57           = do so <- socket (addrFamily ai) Stream defaultProtocol
58                setSocketOption so ReuseAddr 1
59                bindSocket so (addrAddress ai)
60                listen so 255
61                tcpLoop so
62
63       tcpLoop :: Socket -> IO ()
64       tcpLoop so
65           = do (so', _)    <- accept so
66                h           <- socketToHandle so' ReadWriteMode
67                hSetBuffering h $ BlockBuffering Nothing
68                _handlerTID <- forkIO $ tcpHandler h
69                tcpLoop so
70
71       udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
72       udpHandler so packet cameFrom
73           = do msg   <- evaluate $ unpackMessage packet
74                msg'  <- handleMessage msg
75                         `onException`
76                         do let servfail = mkErrorReply ServerFailure msg
77                            NB.sendTo so (packMessage (Just 512) servfail) cameFrom
78                _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
79                return ()
80
81       tcpHandler :: Handle -> IO ()
82       tcpHandler h
83           = do lenB   <- LBS.hGet h 2
84                if LBS.length lenB < 2 then
85                    -- Got EOF
86                    hClose h
87                  else
88                    do let len = runGet getWord16be lenB
89                       packet <- BS.hGet h $ fromIntegral len
90                       msg    <- evaluate $ unpackMessage packet
91                       msg'   <- handleMessage msg
92                                 `onException`
93                                 do let servfail = mkErrorReply ServerFailure msg
94                                        packet'  = packMessage Nothing servfail
95                                        len'     = fromIntegral $ BS.length packet'
96                                    LBS.hPut h $ runPut $ putWord16be len'
97                                    BS.hPut h packet'
98                                    hClose h
99                       let packet' = packMessage Nothing msg'
100                           len'    = fromIntegral $ BS.length packet'
101                       LBS.hPut h $ runPut $ putWord16be len'
102                       BS.hPut h packet'
103                       hFlush h
104                       tcpHandler h
105
106       handleMessage :: Message -> IO Message
107       handleMessage msg
108           = case validateQuery msg of
109               NoError
110                   -> do builders <- mapM handleQuestion $ msgQuestions msg
111
112                         let builder = foldl (>>) (return ()) builders
113                             msg'    = runBuilder msg builder
114
115                         return msg'
116
117               err -> return $ mkErrorReply err msg
118
119       handleQuestion :: SomeQ -> IO (Builder ())
120       handleQuestion (SomeQ q)
121           = do zoneM <- findZone (qName q)
122                case zoneM of
123                  Nothing
124                      -> return $ do unauthorise
125                                     setResponseCode Refused
126                  Just zone
127                      -> handleQuestionForZone (SomeQ q) zone
128
129       handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
130       handleQuestionForZone (SomeQ q) zone
131           | Just (qType q) == cast AXFR
132               = handleAXFR (SomeQ q) zone
133           | otherwise
134               = do allRecords <- zoneResponder zone (qName q)
135                    let filtered = filterRecords (SomeQ q) allRecords
136
137                    additionals <- do xss <- mapM (getAdditionals zone) filtered
138                                      ys  <- case zoneNSRecord zone of
139                                               Just rr -> getAdditionals zone rr
140                                               Nothing -> return []
141                                      return (concat xss ++ ys)
142
143                    return $ do mapM_ addAnswer filtered
144
145                                when (qName q == zoneName zone) $
146                                     do when (Just (qType q) == cast SOA ||
147                                              Just (qType q) == cast ANY   )
148                                                 $ case zoneSOARecord zone of
149                                                     Just rr -> addAnswer rr
150                                                     Nothing -> return ()
151
152                                        when (Just (qType q) == cast NS ||
153                                              Just (qType q) == cast ANY  )
154                                                 $ case zoneNSRecord zone of
155                                                     Just rr -> addAnswer rr
156                                                     Nothing -> return ()
157
158                                mapM_ addAdditional additionals
159
160                                case zoneNSRecord zone of
161                                  Just rr -> addAuthority rr
162                                  Nothing -> unauthorise
163
164       getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
165       getAdditionals zone (SomeRR rr)
166           = case cast (rrData rr) :: Maybe DomainName of
167               Nothing
168                   -> return []
169               Just name
170                   -> do allRecords <- zoneResponder zone name
171
172                         let rA = filterRecords (SomeQ qA) allRecords
173                             rB = filterRecords (SomeQ qB) allRecords
174                             qA = Question {
175                                    qName  = name
176                                  , qType  = A
177                                  , qClass = IN
178                                  }
179                             qB = Question {
180                                    qName  = name
181                                  , qType  = AAAA
182                                  , qClass = IN
183                                  }
184                         return (rA ++ rB)
185
186       filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
187       filterRecords (SomeQ q) = filter predicate
188           where
189             predicate rr
190                 = predForType rr && predForClass rr
191
192             predForType (SomeRR rr)
193                 | typeOf (qType q) == typeOf ANY
194                     = True
195
196                 | typeOf (qType q) == typeOf MAILB
197                     = typeOf (rrType rr) == typeOf MR ||
198                       typeOf (rrType rr) == typeOf MB ||
199                       typeOf (rrType rr) == typeOf MG ||
200                       typeOf (rrType rr) == typeOf MINFO
201
202                 | otherwise
203                     = typeOf (rrType rr) == typeOf (qType q) ||
204                       typeOf (rrType rr) == typeOf CNAME
205
206             predForClass (SomeRR rr)
207                 | typeOf (qClass q) == typeOf ANY
208                     = True
209
210                 | otherwise
211                     = typeOf (rrClass rr) == typeOf (qClass q)
212
213       handleAXFR :: SomeQ -> Zone -> IO (Builder ())
214       handleAXFR (SomeQ q) zone
215           | qName q == zoneName zone &&
216             isJust (zoneSOA zone)    &&
217             cnfAllowTransfer cnf
218               = do names      <- zoneRecordNames zone
219                    allRecords <- liftM concat $ mapM (zoneResponder zone) names
220                    return $ do addAnswer $ fromJust $ zoneSOARecord zone
221                                addAnswer $ fromJust $ zoneNSRecord  zone
222                                mapM_ addAnswer allRecords
223                                addAnswerNonuniquely $ fromJust $ zoneSOARecord zone
224           | otherwise
225               = return $ return ()
226
227
228 validateQuery :: Message -> ResponseCode
229 validateQuery = validateHeader . msgHeader
230     where
231       validateHeader :: Header -> ResponseCode
232       validateHeader hdr
233           | hdMessageType hdr /= Query         = NotImplemented
234           | hdOpcode      hdr /= StandardQuery = NotImplemented
235           | otherwise                          = NoError
236
237
238 packMessage :: Maybe Int -> Message -> BS.ByteString
239 packMessage limM = BS.concat . LBS.toChunks . truncateMsg
240     where
241       truncateMsg :: Message -> LBS.ByteString
242       truncateMsg msg
243           = let packet    = encode msg
244                 needTrunc = fromMaybe False $
245                             do lim <- limM
246                                return $ fromIntegral (LBS.length packet) > lim
247             in
248               if needTrunc then
249                   truncateMsg $ trunc' msg
250               else
251                   packet
252
253       trunc' :: Message -> Message
254       trunc' msg
255           | notNull $ msgAdditionals msg
256               = msg {
257                   msgAdditionals = truncList $ msgAdditionals msg
258                 }
259           | notNull $ msgAuthorities msg
260               = msg {
261                   msgHeader      = setTruncFlag $ msgHeader msg
262                 , msgAuthorities = truncList $ msgAuthorities msg
263                 }
264           | notNull $ msgAnswers msg
265               = msg {
266                   msgHeader      = setTruncFlag $ msgHeader msg
267                 , msgAnswers     = truncList $ msgAnswers msg
268                 }
269           | notNull $ msgQuestions msg
270               = msg {
271                   msgHeader      = setTruncFlag $ msgHeader msg
272                 , msgQuestions   = truncList $ msgQuestions msg
273                 }
274           | otherwise
275               = error ("packMessage: You are already skinny and need no diet: " ++ show msg)
276
277       setTruncFlag :: Header -> Header
278       setTruncFlag hdr = hdr { hdIsTruncated = True }
279
280       notNull :: [a] -> Bool
281       notNull = not . null
282
283       truncList :: [a] -> [a]
284       truncList xs = take (length xs - 1) xs
285
286 unpackMessage :: BS.ByteString -> Message
287 unpackMessage = decode . LBS.fromChunks . return
288
289 mkErrorReply :: ResponseCode -> Message -> Message
290 mkErrorReply err msg
291     = runBuilder msg $ do unauthorise
292                           setResponseCode err