]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Named.hs
DomainMap: totally untested yet
[haskell-dns.git] / Network / DNS / Named.hs
1 module Network.DNS.Named
2     ( runNamed
3     )
4     where
5
6 import           Control.Concurrent
7 import           Control.Exception
8 import           Control.Monad
9 import           Data.Binary
10 import           Data.Binary.Get
11 import           Data.Binary.Put
12 import qualified Data.ByteString as BS
13 import qualified Data.ByteString.Lazy as LBS
14 import           Data.Dynamic
15 import           Data.Maybe
16 import           Network.Socket
17 import qualified Network.Socket.ByteString as NB
18 import           Network.DNS.Message
19 import           Network.DNS.Named.Config
20 import           Network.DNS.Named.ResponseBuilder
21 import           Network.DNS.Named.Zone
22 import           System.Posix.Signals
23 import           System.IO
24
25
26 runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
27 runNamed cnf findZone
28     = withSocketsDo $
29       do installHandler sigPIPE Ignore Nothing
30          _tcpListenerTID <- forkIO $ tcpListen
31          udpListen
32     where
33       udpListen :: IO ()
34       udpListen = do -- FIXME: we should support IPv6 when the network package supports it.
35                      so <- socket AF_INET Datagram defaultProtocol
36                      bindSocket so $ cnfServerAddress cnf
37                      udpLoop so
38
39       udpLoop :: Socket -> IO ()
40       udpLoop so
41           = do (packet, cameFrom) <- NB.recvFrom so 512
42                _handlerTID <- forkIO $ udpHandler so packet cameFrom
43                udpLoop so
44
45       tcpListen :: IO ()
46       tcpListen = do so <- socket AF_INET Stream defaultProtocol
47                      bindSocket so $ cnfServerAddress cnf
48                      listen so 255
49                      tcpLoop so
50
51       tcpLoop :: Socket -> IO ()
52       tcpLoop so
53           = do (so', _)    <- accept so
54                h           <- socketToHandle so' ReadWriteMode
55                hSetBuffering h $ BlockBuffering Nothing
56                _handlerTID <- forkIO $ tcpHandler h
57                tcpLoop so
58
59       udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
60       udpHandler so packet cameFrom
61           = do msg   <- evaluate $ unpackMessage packet
62                msg'  <- handleMessage msg
63                         `onException`
64                         do let servfail = mkErrorReply ServerFailure msg
65                            NB.sendTo so (packMessage (Just 512) servfail) cameFrom
66                _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
67                return ()
68
69       tcpHandler :: Handle -> IO ()
70       tcpHandler h
71           = do lenB   <- LBS.hGet h 2
72                if LBS.length lenB < 2 then
73                    -- Got EOF
74                    hClose h
75                  else
76                    do let len = runGet getWord16be lenB
77                       packet <- BS.hGet h $ fromIntegral len
78                       msg    <- evaluate $ unpackMessage packet
79                       msg'   <- handleMessage msg
80                                 `onException`
81                                 do let servfail = mkErrorReply ServerFailure msg
82                                        packet'  = packMessage Nothing servfail
83                                        len'     = fromIntegral $ BS.length packet'
84                                    LBS.hPut h $ runPut $ putWord16be len'
85                                    BS.hPut h packet'
86                                    hClose h
87                       let packet' = packMessage Nothing msg'
88                           len'    = fromIntegral $ BS.length packet'
89                       LBS.hPut h $ runPut $ putWord16be len'
90                       BS.hPut h packet'
91                       hFlush h
92                       tcpHandler h
93
94       handleMessage :: Message -> IO Message
95       handleMessage msg
96           = case validateQuery msg of
97               NoError
98                   -> do builders <- mapM handleQuestion $ msgQuestions msg
99
100                         let builder = foldl (>>) (return ()) builders
101                             msg'    = runBuilder msg builder
102
103                         return msg'
104
105               err -> return $ mkErrorReply err msg
106
107       handleQuestion :: SomeQ -> IO (Builder ())
108       handleQuestion (SomeQ q)
109           = do zoneM <- findZone (qName q)
110                case zoneM of
111                  Nothing
112                      -> return $ do unauthorise
113                                     setResponseCode Refused
114                  Just zone
115                      -> handleQuestionForZone (SomeQ q) zone
116
117       handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
118       handleQuestionForZone (SomeQ q) zone
119           | Just (qType q) == cast AXFR
120               = handleAXFR (SomeQ q) zone
121           | otherwise
122               = do allRecords <- zoneResponder zone (qName q)
123                    let filtered = filterRecords (SomeQ q) allRecords
124
125                    additionals <- do xss <- mapM (getAdditionals zone) filtered
126                                      ys  <- case zoneNSRecord zone of
127                                               Just rr -> getAdditionals zone rr
128                                               Nothing -> return []
129                                      return (concat xss ++ ys)
130
131                    return $ do mapM_ addAnswer filtered
132
133                                when (qName q == zoneName zone) $
134                                     do when (Just (qType q) == cast SOA ||
135                                              Just (qType q) == cast ANY   )
136                                                 $ case zoneSOARecord zone of
137                                                     Just rr -> addAnswer rr
138                                                     Nothing -> return ()
139
140                                        when (Just (qType q) == cast NS ||
141                                              Just (qType q) == cast ANY  )
142                                                 $ case zoneNSRecord zone of
143                                                     Just rr -> addAnswer rr
144                                                     Nothing -> return ()
145
146                                mapM_ addAdditional additionals
147
148                                case zoneNSRecord zone of
149                                  Just rr -> addAuthority rr
150                                  Nothing -> unauthorise
151
152       getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
153       getAdditionals zone (SomeRR rr)
154           = case cast (rrData rr) :: Maybe DomainName of
155               Nothing
156                   -> return []
157               Just name
158                   -> do allRecords <- zoneResponder zone name
159
160                         let filtered = filterRecords (SomeQ q') allRecords
161                             q'       = Question {
162                                          qName  = name
163                                        , qType  = A
164                                        , qClass = IN
165                                        }
166                         return filtered
167
168       filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
169       filterRecords (SomeQ q)
170           | Just (qType  q) == cast ANY &&
171             Just (qClass q) == cast ANY    = id
172           | Just (qType  q) == cast ANY    = filter matchClass
173           | Just (qClass q) == cast ANY    = filter matchType
174           | otherwise                      = filter matchBoth
175           where
176             matchClass (SomeRR rr)
177                 = Just (qClass q) == cast (rrClass rr)
178
179             matchType (SomeRR rr)
180                 = Just (qType  q) == cast (rrType  rr) ||
181                   Just CNAME      == cast (rrType  rr)
182
183             matchBoth rr
184                 = matchType rr && matchClass rr
185
186       handleAXFR :: SomeQ -> Zone -> IO (Builder ())
187       handleAXFR (SomeQ q) zone
188           | qName q == zoneName zone &&
189             isJust (zoneSOA zone)    &&
190             cnfAllowTransfer cnf
191               = do names      <- zoneRecordNames zone
192                    allRecords <- liftM concat $ mapM (zoneResponder zone) names
193                    return $ do addAnswer $ fromJust $ zoneSOARecord zone
194                                addAnswer $ fromJust $ zoneNSRecord  zone
195                                mapM_ addAnswer allRecords
196                                addAnswerNonuniquely $ fromJust $ zoneSOARecord zone
197           | otherwise
198               = return $ return ()
199
200
201 validateQuery :: Message -> ResponseCode
202 validateQuery = validateHeader . msgHeader
203     where
204       validateHeader :: Header -> ResponseCode
205       validateHeader hdr
206           | hdMessageType hdr /= Query         = NotImplemented
207           | hdOpcode      hdr /= StandardQuery = NotImplemented
208           | otherwise                          = NoError
209
210
211 packMessage :: Maybe Int -> Message -> BS.ByteString
212 packMessage limM = BS.concat . LBS.toChunks . truncateMsg
213     where
214       truncateMsg :: Message -> LBS.ByteString
215       truncateMsg msg
216           = let packet    = encode msg
217                 needTrunc = fromMaybe False $
218                             do lim <- limM
219                                return $ fromIntegral (LBS.length packet) > lim
220             in
221               if needTrunc then
222                   truncateMsg $ trunc' msg
223               else
224                   packet
225
226       trunc' :: Message -> Message
227       trunc' msg
228           | notNull $ msgAdditionals msg
229               = msg {
230                   msgAdditionals = truncList $ msgAdditionals msg
231                 }
232           | notNull $ msgAuthorities msg
233               = msg {
234                   msgHeader      = setTruncFlag $ msgHeader msg
235                 , msgAuthorities = truncList $ msgAuthorities msg
236                 }
237           | notNull $ msgAnswers msg
238               = msg {
239                   msgHeader      = setTruncFlag $ msgHeader msg
240                 , msgAnswers     = truncList $ msgAnswers msg
241                 }
242           | notNull $ msgQuestions msg
243               = msg {
244                   msgHeader      = setTruncFlag $ msgHeader msg
245                 , msgQuestions   = truncList $ msgQuestions msg
246                 }
247           | otherwise
248               = error ("packMessage: You are already skinny and need no diet: " ++ show msg)
249
250       setTruncFlag :: Header -> Header
251       setTruncFlag hdr = hdr { hdIsTruncated = True }
252
253       notNull :: [a] -> Bool
254       notNull = not . null
255
256       truncList :: [a] -> [a]
257       truncList xs = take (length xs - 1) xs
258
259 unpackMessage :: BS.ByteString -> Message
260 unpackMessage = decode . LBS.fromChunks . return
261
262 mkErrorReply :: ResponseCode -> Message -> Message
263 mkErrorReply err msg
264     = runBuilder msg $ do unauthorise
265                           setResponseCode err