]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Named.hs
The server now accepts IPv6 connections.
[haskell-dns.git] / Network / DNS / Named.hs
1 module Network.DNS.Named
2     ( runNamed
3     )
4     where
5
6 import           Control.Concurrent
7 import           Control.Exception
8 import           Control.Monad
9 import           Data.Binary
10 import           Data.Binary.Get
11 import           Data.Binary.Put
12 import qualified Data.ByteString as BS
13 import qualified Data.ByteString.Lazy as LBS
14 import           Data.Dynamic
15 import           Data.Maybe
16 import           Network.Socket
17 import qualified Network.Socket.ByteString as NB
18 import           Network.DNS.Message
19 import           Network.DNS.Named.Config
20 import           Network.DNS.Named.ResponseBuilder
21 import           Network.DNS.Named.Zone
22 import           System.Posix.Signals
23 import           System.IO
24
25
26 runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
27 runNamed cnf findZone
28     = withSocketsDo $
29       do installHandler sigPIPE Ignore Nothing
30
31          let hint = defaultHints {
32                       addrFlags      = [AI_PASSIVE, AI_V4MAPPED]
33                     , addrFamily     = AF_INET6
34                     , addrSocketType = NoSocketType
35                     , addrProtocol   = defaultProtocol
36                     }
37          (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf)
38
39          _tcpListenerTID <- forkIO $ tcpListen ai
40          udpListen ai
41     where
42       udpListen :: AddrInfo -> IO ()
43       udpListen ai
44           = do so <- socket (addrFamily ai) Datagram defaultProtocol
45                bindSocket so (addrAddress ai)
46                udpLoop so
47
48       udpLoop :: Socket -> IO ()
49       udpLoop so
50           = do (packet, cameFrom) <- NB.recvFrom so 512
51                _handlerTID <- forkIO $ udpHandler so packet cameFrom
52                udpLoop so
53
54       tcpListen :: AddrInfo -> IO ()
55       tcpListen ai
56           = do so <- socket (addrFamily ai) Stream defaultProtocol
57                bindSocket so (addrAddress ai)
58                listen so 255
59                tcpLoop so
60
61       tcpLoop :: Socket -> IO ()
62       tcpLoop so
63           = do (so', _)    <- accept so
64                h           <- socketToHandle so' ReadWriteMode
65                hSetBuffering h $ BlockBuffering Nothing
66                _handlerTID <- forkIO $ tcpHandler h
67                tcpLoop so
68
69       udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
70       udpHandler so packet cameFrom
71           = do msg   <- evaluate $ unpackMessage packet
72                msg'  <- handleMessage msg
73                         `onException`
74                         do let servfail = mkErrorReply ServerFailure msg
75                            NB.sendTo so (packMessage (Just 512) servfail) cameFrom
76                _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
77                return ()
78
79       tcpHandler :: Handle -> IO ()
80       tcpHandler h
81           = do lenB   <- LBS.hGet h 2
82                if LBS.length lenB < 2 then
83                    -- Got EOF
84                    hClose h
85                  else
86                    do let len = runGet getWord16be lenB
87                       packet <- BS.hGet h $ fromIntegral len
88                       msg    <- evaluate $ unpackMessage packet
89                       msg'   <- handleMessage msg
90                                 `onException`
91                                 do let servfail = mkErrorReply ServerFailure msg
92                                        packet'  = packMessage Nothing servfail
93                                        len'     = fromIntegral $ BS.length packet'
94                                    LBS.hPut h $ runPut $ putWord16be len'
95                                    BS.hPut h packet'
96                                    hClose h
97                       let packet' = packMessage Nothing msg'
98                           len'    = fromIntegral $ BS.length packet'
99                       LBS.hPut h $ runPut $ putWord16be len'
100                       BS.hPut h packet'
101                       hFlush h
102                       tcpHandler h
103
104       handleMessage :: Message -> IO Message
105       handleMessage msg
106           = case validateQuery msg of
107               NoError
108                   -> do builders <- mapM handleQuestion $ msgQuestions msg
109
110                         let builder = foldl (>>) (return ()) builders
111                             msg'    = runBuilder msg builder
112
113                         return msg'
114
115               err -> return $ mkErrorReply err msg
116
117       handleQuestion :: SomeQ -> IO (Builder ())
118       handleQuestion (SomeQ q)
119           = do zoneM <- findZone (qName q)
120                case zoneM of
121                  Nothing
122                      -> return $ do unauthorise
123                                     setResponseCode Refused
124                  Just zone
125                      -> handleQuestionForZone (SomeQ q) zone
126
127       handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
128       handleQuestionForZone (SomeQ q) zone
129           | Just (qType q) == cast AXFR
130               = handleAXFR (SomeQ q) zone
131           | otherwise
132               = do allRecords <- zoneResponder zone (qName q)
133                    let filtered = filterRecords (SomeQ q) allRecords
134
135                    additionals <- do xss <- mapM (getAdditionals zone) filtered
136                                      ys  <- case zoneNSRecord zone of
137                                               Just rr -> getAdditionals zone rr
138                                               Nothing -> return []
139                                      return (concat xss ++ ys)
140
141                    return $ do mapM_ addAnswer filtered
142
143                                when (qName q == zoneName zone) $
144                                     do when (Just (qType q) == cast SOA ||
145                                              Just (qType q) == cast ANY   )
146                                                 $ case zoneSOARecord zone of
147                                                     Just rr -> addAnswer rr
148                                                     Nothing -> return ()
149
150                                        when (Just (qType q) == cast NS ||
151                                              Just (qType q) == cast ANY  )
152                                                 $ case zoneNSRecord zone of
153                                                     Just rr -> addAnswer rr
154                                                     Nothing -> return ()
155
156                                mapM_ addAdditional additionals
157
158                                case zoneNSRecord zone of
159                                  Just rr -> addAuthority rr
160                                  Nothing -> unauthorise
161
162       getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
163       getAdditionals zone (SomeRR rr)
164           = case cast (rrData rr) :: Maybe DomainName of
165               Nothing
166                   -> return []
167               Just name
168                   -> do allRecords <- zoneResponder zone name
169
170                         let filtered = filterRecords (SomeQ q') allRecords
171                             q'       = Question {
172                                          qName  = name
173                                        , qType  = A
174                                        , qClass = IN
175                                        }
176                         return filtered
177
178       filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
179       filterRecords (SomeQ q)
180           | Just (qType  q) == cast ANY &&
181             Just (qClass q) == cast ANY    = id
182           | Just (qType  q) == cast ANY    = filter matchClass
183           | Just (qClass q) == cast ANY    = filter matchType
184           | otherwise                      = filter matchBoth
185           where
186             matchClass (SomeRR rr)
187                 = Just (qClass q) == cast (rrClass rr)
188
189             matchType (SomeRR rr)
190                 = Just (qType  q) == cast (rrType  rr) ||
191                   Just CNAME      == cast (rrType  rr)
192
193             matchBoth rr
194                 = matchType rr && matchClass rr
195
196       handleAXFR :: SomeQ -> Zone -> IO (Builder ())
197       handleAXFR (SomeQ q) zone
198           | qName q == zoneName zone &&
199             isJust (zoneSOA zone)    &&
200             cnfAllowTransfer cnf
201               = do names      <- zoneRecordNames zone
202                    allRecords <- liftM concat $ mapM (zoneResponder zone) names
203                    return $ do addAnswer $ fromJust $ zoneSOARecord zone
204                                addAnswer $ fromJust $ zoneNSRecord  zone
205                                mapM_ addAnswer allRecords
206                                addAnswerNonuniquely $ fromJust $ zoneSOARecord zone
207           | otherwise
208               = return $ return ()
209
210
211 validateQuery :: Message -> ResponseCode
212 validateQuery = validateHeader . msgHeader
213     where
214       validateHeader :: Header -> ResponseCode
215       validateHeader hdr
216           | hdMessageType hdr /= Query         = NotImplemented
217           | hdOpcode      hdr /= StandardQuery = NotImplemented
218           | otherwise                          = NoError
219
220
221 packMessage :: Maybe Int -> Message -> BS.ByteString
222 packMessage limM = BS.concat . LBS.toChunks . truncateMsg
223     where
224       truncateMsg :: Message -> LBS.ByteString
225       truncateMsg msg
226           = let packet    = encode msg
227                 needTrunc = fromMaybe False $
228                             do lim <- limM
229                                return $ fromIntegral (LBS.length packet) > lim
230             in
231               if needTrunc then
232                   truncateMsg $ trunc' msg
233               else
234                   packet
235
236       trunc' :: Message -> Message
237       trunc' msg
238           | notNull $ msgAdditionals msg
239               = msg {
240                   msgAdditionals = truncList $ msgAdditionals msg
241                 }
242           | notNull $ msgAuthorities msg
243               = msg {
244                   msgHeader      = setTruncFlag $ msgHeader msg
245                 , msgAuthorities = truncList $ msgAuthorities msg
246                 }
247           | notNull $ msgAnswers msg
248               = msg {
249                   msgHeader      = setTruncFlag $ msgHeader msg
250                 , msgAnswers     = truncList $ msgAnswers msg
251                 }
252           | notNull $ msgQuestions msg
253               = msg {
254                   msgHeader      = setTruncFlag $ msgHeader msg
255                 , msgQuestions   = truncList $ msgQuestions msg
256                 }
257           | otherwise
258               = error ("packMessage: You are already skinny and need no diet: " ++ show msg)
259
260       setTruncFlag :: Header -> Header
261       setTruncFlag hdr = hdr { hdIsTruncated = True }
262
263       notNull :: [a] -> Bool
264       notNull = not . null
265
266       truncList :: [a] -> [a]
267       truncList xs = take (length xs - 1) xs
268
269 unpackMessage :: BS.ByteString -> Message
270 unpackMessage = decode . LBS.fromChunks . return
271
272 mkErrorReply :: ResponseCode -> Message -> Message
273 mkErrorReply err msg
274     = runBuilder msg $ do unauthorise
275                           setResponseCode err