1 module Network.DNS.Named
6 import Control.Concurrent
7 import Control.Exception
10 import Data.Binary.Get
11 import Data.Binary.Put
12 import qualified Data.ByteString as BS
13 import qualified Data.ByteString.Lazy as LBS
17 import qualified Network.Socket.ByteString as NB
18 import Network.DNS.Message
19 import Network.DNS.Named.Config
20 import Network.DNS.Named.ResponseBuilder
21 import Network.DNS.Named.Zone
22 import System.Posix.Signals
-- Entry point of the name server. Takes the server configuration and a
-- lookup action that yields the Zone (if any) for a domain name.
26 runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
-- Install Ignore as the SIGPIPE handler.
-- NOTE(review): presumably so that writing to a connection the peer has
-- already closed cannot terminate the process -- confirm.
29 do installHandler sigPIPE Ignore Nothing
-- Hints for getAddrInfo: a passive IPv6 address with AI_V4MAPPED set. The
-- socket type is left open here; udpListen and tcpListen each create their
-- own socket with the type they need.
31 let hint = defaultHints {
32 addrFlags = [AI_PASSIVE, AI_V4MAPPED]
33 , addrFamily = AF_INET6
34 , addrSocketType = NoSocketType
35 , addrProtocol = defaultProtocol
-- No host name is passed, only the configured port; the first result is used.
37 (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf)
-- The TCP listener runs on its own thread.
39 _tcpListenerTID <- forkIO $ tcpListen ai
-- Creates a datagram socket in the address family of the given AddrInfo and
-- binds it to that AddrInfo's address.
42 udpListen :: AddrInfo -> IO ()
44 = do so <- socket (addrFamily ai) Datagram defaultProtocol
45 bindSocket so (addrAddress ai)
-- Receives one datagram (at most 512 bytes) together with the sender's
-- address and hands both to udpHandler on a freshly forked thread.
-- NOTE(review): 512 presumably reflects the classic DNS-over-UDP message size
-- limit; the same figure is passed to packMessage in udpHandler.
48 udpLoop :: Socket -> IO ()
50 = do (packet, cameFrom) <- NB.recvFrom so 512
51 _handlerTID <- forkIO $ udpHandler so packet cameFrom
-- Creates a stream socket in the address family of the given AddrInfo and
-- binds it to that AddrInfo's address.
54 tcpListen :: AddrInfo -> IO ()
56 = do so <- socket (addrFamily ai) Stream defaultProtocol
57 bindSocket so (addrAddress ai)
-- Accepts one connection on the listening socket and serves it on a freshly
-- forked thread.
61 tcpLoop :: Socket -> IO ()
-- The peer address returned by accept is discarded.
63 = do (so', _) <- accept so
-- Wrap the connected socket in a read/write Handle with block buffering of
-- the default size, so tcpHandler can use Handle-based I/O.
64 h <- socketToHandle so' ReadWriteMode
65 hSetBuffering h $ BlockBuffering Nothing
66 _handlerTID <- forkIO $ tcpHandler h
-- Handles a single UDP request: decodes the packet, computes the reply with
-- handleMessage and sends it back to the address the packet came from.
-- Replies are packed with a 512-byte limit.
69 udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
70 udpHandler so packet cameFrom
-- evaluate forces the decoded message (to weak head normal form) at this
-- point rather than leaving it as an unevaluated thunk.
71 = do msg <- evaluate $ unpackMessage packet
72 msg' <- handleMessage msg
-- Fallback: a ServerFailure reply derived from the request.
-- NOTE(review): confirm under which condition this branch runs (presumably
-- when handleMessage throws).
74 do let servfail = mkErrorReply ServerFailure msg
75 NB.sendTo so (packMessage (Just 512) servfail) cameFrom
-- Normal path: send the computed reply; the byte count is not used.
76 _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
-- Serves DNS requests arriving on a TCP connection wrapped in a Handle.
-- Each message in either direction is preceded by a two-byte big-endian
-- length field.
79 tcpHandler :: Handle -> IO ()
-- Read the two-byte length prefix; fewer than two bytes gets separate handling.
81 = do lenB <- LBS.hGet h 2
82 if LBS.length lenB < 2 then
-- Decode the prefix and read exactly that many bytes as the request.
86 do let len = runGet getWord16be lenB
87 packet <- BS.hGet h $ fromIntegral len
88 msg <- evaluate $ unpackMessage packet
89 msg' <- handleMessage msg
-- Fallback: a ServerFailure reply derived from the request, packed without a
-- size limit and written after its length prefix.
-- NOTE(review): confirm under which condition this branch runs.
91 do let servfail = mkErrorReply ServerFailure msg
92 packet' = packMessage Nothing servfail
93 len' = fromIntegral $ BS.length packet'
94 LBS.hPut h $ runPut $ putWord16be len'
-- Normal path: pack the reply without a size limit and write its length prefix.
-- NOTE(review): len' is narrowed to 16 bits by fromIntegral; a packed reply
-- longer than 65535 bytes would wrap silently -- confirm that cannot happen.
97 let packet' = packMessage Nothing msg'
98 len' = fromIntegral $ BS.length packet'
99 LBS.hPut h $ runPut $ putWord16be len'
-- Computes the reply for a request message. The request is validated first;
-- any response code that reaches the final alternative is turned into an
-- error reply built from the request.
104 handleMessage :: Message -> IO Message
106 = case validateQuery msg of
-- One Builder action per question in the request.
108 -> do builders <- mapM handleQuestion $ msgQuestions msg
-- Chain the per-question builders in question order and run the combined
-- builder against the request to obtain the reply.
110 let builder = foldl (>>) (return ()) builders
111 msg' = runBuilder msg builder
115 err -> return $ mkErrorReply err msg
-- Produces the Builder action answering one question. The zone for the
-- question's name is looked up with findZone; one alternative marks the reply
-- as non-authoritative with response code Refused, the other delegates to
-- handleQuestionForZone with the zone that was found.
117 handleQuestion :: SomeQ -> IO (Builder ())
118 handleQuestion (SomeQ q)
119 = do zoneM <- findZone (qName q)
122 -> return $ do unauthorise
123 setResponseCode Refused
125 -> handleQuestionForZone (SomeQ q) zone
-- Produces the Builder action answering one question from a given zone.
127 handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
128 handleQuestionForZone (SomeQ q) zone
-- Zone transfer questions are delegated to handleAXFR. The comparison goes
-- through cast, so it is False whenever the question's type is of a different
-- Haskell type than AXFR.
129 | Just (qType q) == cast AXFR
130 = handleAXFR (SomeQ q) zone
-- Otherwise: ask the zone for every record at the question's name and keep
-- the ones that match the question.
132 = do allRecords <- zoneResponder zone (qName q)
133 let filtered = filterRecords (SomeQ q) allRecords
-- Additional records are gathered for each answer record and, when the zone
-- has an NS record, for that record as well.
135 additionals <- do xss <- mapM (getAdditionals zone) filtered
136 ys <- case zoneNSRecord zone of
137 Just rr -> getAdditionals zone rr
139 return (concat xss ++ ys)
-- The resulting builder adds the matching records as answers first.
141 return $ do mapM_ addAnswer filtered
-- Only when the question is about the zone's own name: SOA and ANY questions
-- also get the zone's SOA record, NS and ANY questions also get its NS record.
143 when (qName q == zoneName zone) $
144 do when (Just (qType q) == cast SOA ||
145 Just (qType q) == cast ANY )
146 $ case zoneSOARecord zone of
147 Just rr -> addAnswer rr
150 when (Just (qType q) == cast NS ||
151 Just (qType q) == cast ANY )
152 $ case zoneNSRecord zone of
153 Just rr -> addAnswer rr
156 mapM_ addAdditional additionals
-- The zone's NS record goes into the authority section; a zone without one
-- makes the reply non-authoritative instead.
158 case zoneNSRecord zone of
159 Just rr -> addAuthority rr
160 Nothing -> unauthorise
-- Collects additional records related to a resource record.
162 getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
163 getAdditionals zone (SomeRR rr)
-- The record's data is inspected via cast to see whether it is a DomainName.
164 = case cast (rrData rr) :: Maybe DomainName of
-- Look up the zone's records for a name and filter them against a derived
-- question q'.
-- NOTE(review): presumably `name` is the DomainName obtained from the cast
-- above -- confirm.
168 -> do allRecords <- zoneResponder zone name
170 let filtered = filterRecords (SomeQ q') allRecords
-- Builds the record filter for a question. ANY acts as a wildcard:
--   * type ANY and class ANY: every record is kept,
--   * type ANY only:          records are filtered by class,
--   * class ANY only:         records are filtered by type,
--   * otherwise:              records must match both.
178 filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
179 filterRecords (SomeQ q)
180 | Just (qType q) == cast ANY &&
181 Just (qClass q) == cast ANY = id
182 | Just (qType q) == cast ANY = filter matchClass
183 | Just (qClass q) == cast ANY = filter matchType
184 | otherwise = filter matchBoth
-- True when the record's class equals the question's class.
186 matchClass (SomeRR rr)
187 = Just (qClass q) == cast (rrClass rr)
-- True when the record's type equals the question's type, and also for every
-- CNAME record regardless of the question's type.
189 matchType (SomeRR rr)
190 = Just (qType q) == cast (rrType rr) ||
191 Just CNAME == cast (rrType rr)
194 = matchType rr && matchClass rr
-- Produces the Builder action answering a zone transfer (AXFR) question.
196 handleAXFR :: SomeQ -> Zone -> IO (Builder ())
197 handleAXFR (SomeQ q) zone
-- The transfer is only served when the question names the zone itself and
-- the zone has an SOA.
-- NOTE(review): the builder below applies fromJust to zoneSOARecord and
-- zoneNSRecord; confirm these guards ensure both are present.
198 | qName q == zoneName zone &&
199 isJust (zoneSOA zone) &&
-- Collect the records of every name in the zone.
201 = do names <- zoneRecordNames zone
202 allRecords <- liftM concat $ mapM (zoneResponder zone) names
-- Answer layout: SOA record, NS record, all collected records, then the SOA
-- record once more; the trailing copy uses addAnswerNonuniquely.
203 return $ do addAnswer $ fromJust $ zoneSOARecord zone
204 addAnswer $ fromJust $ zoneNSRecord zone
205 mapM_ addAnswer allRecords
206 addAnswerNonuniquely $ fromJust $ zoneSOARecord zone
-- Validates a request message by checking its header with validateHeader;
-- the response code returned is whatever validateHeader reports.
validateQuery :: Message -> ResponseCode
validateQuery msg = validateHeader (msgHeader msg)
-- Checks a request header: anything that is not a Query message with the
-- StandardQuery opcode is answered with NotImplemented; otherwise the
-- result is NoError.
214 validateHeader :: Header -> ResponseCode
216 | hdMessageType hdr /= Query = NotImplemented
217 | hdOpcode hdr /= StandardQuery = NotImplemented
218 | otherwise = NoError
-- Serialises a message into a strict ByteString. The first argument is an
-- optional size limit; the message passes through truncateMsg, and the
-- resulting lazy ByteString is flattened into a single strict one.
221 packMessage :: Maybe Int -> Message -> BS.ByteString
222 packMessage limM = BS.concat . LBS.toChunks . truncateMsg
-- Encodes a message, shrinking it first for as long as truncation is needed.
224 truncateMsg :: Message -> LBS.ByteString
226 = let packet = encode msg
-- Truncation is needed when the encoded length exceeds the limit `lim`;
-- fromMaybe supplies False when the Maybe computation yields Nothing.
-- NOTE(review): presumably `lim` is taken from packMessage's optional limit,
-- so no limit means no truncation -- confirm.
227 needTrunc = fromMaybe False $
229 return $ fromIntegral (LBS.length packet) > lim
-- Drop one record via trunc' and try again with the smaller message.
232 truncateMsg $ trunc' msg
-- Removes the last record from the first non-empty section of the message,
-- trying the sections in this order: additionals, authorities, answers,
-- questions.
236 trunc' :: Message -> Message
-- NOTE(review): unlike the cases below, no truncation-flag update is visible
-- for the additionals case -- confirm that this is intended.
238 | notNull $ msgAdditionals msg
240 msgAdditionals = truncList $ msgAdditionals msg
-- Dropping an authority, answer or question record also sets the header's
-- truncation flag.
242 | notNull $ msgAuthorities msg
244 msgHeader = setTruncFlag $ msgHeader msg
245 , msgAuthorities = truncList $ msgAuthorities msg
247 | notNull $ msgAnswers msg
249 msgHeader = setTruncFlag $ msgHeader msg
250 , msgAnswers = truncList $ msgAnswers msg
252 | notNull $ msgQuestions msg
254 msgHeader = setTruncFlag $ msgHeader msg
255 , msgQuestions = truncList $ msgQuestions msg
-- Fallback: raises an error whose text includes the shown message.
-- NOTE(review): presumably reached only when every section is already empty.
258 = error ("packMessage: You are already skinny and need no diet: " ++ show msg)
-- Returns the given header with its hdIsTruncated field set to True; every
-- other field is carried over unchanged.
setTruncFlag :: Header -> Header
setTruncFlag header = header { hdIsTruncated = True }
-- List predicate used by trunc' to pick the section to shrink.
-- NOTE(review): presumably True exactly for non-empty lists -- confirm
-- against the defining equation.
263 notNull :: [a] -> Bool
-- Drops the last element of a list. An empty list is returned unchanged,
-- because take with a negative count yields the empty list.
truncList :: [a] -> [a]
truncList xs = take keep xs
  where
    -- Number of leading elements to keep: everything except the last one.
    keep = length xs - 1
-- Decodes a Message from a strict ByteString. The input is wrapped as a
-- single-chunk lazy ByteString and handed to decode.
unpackMessage :: BS.ByteString -> Message
unpackMessage packet = decode (LBS.fromChunks [packet])
272 mkErrorReply :: ResponseCode -> Message -> Message
274 = runBuilder msg $ do unauthorise