]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Named.hs
Code clean up
[haskell-dns.git] / Network / DNS / Named.hs
1 module Network.DNS.Named
2     ( runNamed
3     )
4     where
5
6 import           Control.Concurrent
7 import           Control.Exception
8 import           Control.Monad
9 import           Data.Binary
10 import           Data.Binary.Get
11 import           Data.Binary.Put
12 import qualified Data.ByteString as BS
13 import qualified Data.ByteString.Lazy as LBS
14 import           Data.Dynamic
15 import           Data.Maybe
16 import           Network.Socket
17 import qualified Network.Socket.ByteString as NB
18 import           Network.DNS.Message
19 import           Network.DNS.Named.Config
20 import           Network.DNS.Named.ResponseBuilder
21 import           Network.DNS.Named.Zone
22 import           System.Posix.Signals
23 import           System.IO
24
25
26 runNamed :: ZoneFinder zf => Config -> zf -> IO ()
27 runNamed cnf zf
28     = withSocketsDo $
29       do installHandler sigPIPE Ignore Nothing
30
31          let hint = defaultHints {
32                       addrFlags      = [AI_PASSIVE, AI_V4MAPPED]
33                     , addrFamily     = AF_INET6
34                     , addrSocketType = NoSocketType
35                     , addrProtocol   = defaultProtocol
36                     }
37          (ai:_) <- getAddrInfo (Just hint) Nothing (Just $ cnfServerPort cnf)
38
39          _tcpListenerTID <- forkIO $ tcpListen ai
40          udpListen ai
41     where
42       udpListen :: AddrInfo -> IO ()
43       udpListen ai
44           = do so <- socket (addrFamily ai) Datagram defaultProtocol
45                setSocketOption so ReuseAddr 1
46                bindSocket so (addrAddress ai)
47                udpLoop so
48
49       udpLoop :: Socket -> IO ()
50       udpLoop so
51           = do (packet, cameFrom) <- NB.recvFrom so 512
52                _handlerTID <- forkIO $ udpHandler so packet cameFrom
53                udpLoop so
54
55       tcpListen :: AddrInfo -> IO ()
56       tcpListen ai
57           = do so <- socket (addrFamily ai) Stream defaultProtocol
58                setSocketOption so ReuseAddr 1
59                bindSocket so (addrAddress ai)
60                listen so 255
61                tcpLoop so
62
63       tcpLoop :: Socket -> IO ()
64       tcpLoop so
65           = do (so', _)    <- accept so
66                h           <- socketToHandle so' ReadWriteMode
67                hSetBuffering h $ BlockBuffering Nothing
68                _handlerTID <- forkIO $ tcpHandler h
69                tcpLoop so
70
71       udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
72       udpHandler so packet cameFrom
73           = do msg   <- evaluate $ unpackMessage packet
74                msg'  <- handleMessage msg
75                         `onException`
76                         do let servfail = mkErrorReply ServerFailure msg
77                            NB.sendTo so (packMessage (Just 512) servfail) cameFrom
78                _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
79                return ()
80
81       tcpHandler :: Handle -> IO ()
82       tcpHandler h
83           = do lenB   <- LBS.hGet h 2
84                if LBS.length lenB < 2 then
85                    -- Got EOF
86                    hClose h
87                  else
88                    do let len = runGet getWord16be lenB
89                       packet <- BS.hGet h $ fromIntegral len
90                       msg    <- evaluate $ unpackMessage packet
91                       msg'   <- handleMessage msg
92                                 `onException`
93                                 do let servfail = mkErrorReply ServerFailure msg
94                                        packet'  = packMessage Nothing servfail
95                                        len'     = fromIntegral $ BS.length packet'
96                                    LBS.hPut h $ runPut $ putWord16be len'
97                                    BS.hPut h packet'
98                                    hClose h
99                       let packet' = packMessage Nothing msg'
100                           len'    = fromIntegral $ BS.length packet'
101                       LBS.hPut h $ runPut $ putWord16be len'
102                       BS.hPut h packet'
103                       hFlush h
104                       tcpHandler h
105
106       handleMessage :: Message -> IO Message
107       handleMessage msg
108           = case validateQuery msg of
109               NoError
110                   -> do builders <- mapM handleQuestion $ msgQuestions msg
111
112                         let builder = foldl (>>) (return ()) builders
113                             msg'    = runBuilder msg builder
114
115                         return msg'
116
117               err -> return $ mkErrorReply err msg
118
119       handleQuestion :: SomeQ -> IO (Builder ())
120       handleQuestion (SomeQ q)
121           = do zoneM <- findZone zf (qName q)
122                case zoneM of
123                  Nothing
124                      -> return $ do unauthorise
125                                     setResponseCode Refused
126                  Just zone
127                      -> handleQuestionForZone q zone
128
129       handleQuestionForZone :: (Zone z, QueryType qt, QueryClass qc) => Question qt qc -> z -> IO (Builder ())
130       handleQuestionForZone q zone
131           | Just (qType q) == cast AXFR
132               = handleAXFR q zone
133           | otherwise
134               = do answers     <- getRecords zone q
135                    authority   <- getRecords zone (Question (zoneName zone) NS IN)
136                    additionals <- liftM concat $ mapM (getAdditionals zone) (answers ++ authority)
137                    isAuth      <- isAuthoritativeZone zone
138                    return $ do mapM_ addAnswer     answers
139                                mapM_ addAuthority  authority
140                                mapM_ addAdditional additionals
141                                unless isAuth unauthorise
142
143       getAdditionals :: Zone z => z -> SomeRR -> IO [SomeRR]
144       getAdditionals zone (SomeRR rr)
145           = case cast (rrData rr) :: Maybe DomainName of
146               Nothing
147                   -> return []
148               Just name
149                   -> do rrA    <- getRecords zone (Question name A    IN)
150                         rrAAAA <- getRecords zone (Question name AAAA IN)
151                         return (rrA ++ rrAAAA)
152
153       handleAXFR :: (Zone z, QueryType qt, QueryClass qc) => Question qt qc -> z -> IO (Builder ())
154       handleAXFR q zone
155           | cnfAllowTransfer cnf
156               = do rs <- getRecords zone q
157                    return $ mapM_ addAnswerNonuniquely rs
158           | otherwise
159               = return $ return ()
160
161
162 validateQuery :: Message -> ResponseCode
163 validateQuery = validateHeader . msgHeader
164     where
165       validateHeader :: Header -> ResponseCode
166       validateHeader hdr
167           | hdMessageType hdr /= Query         = NotImplemented
168           | hdOpcode      hdr /= StandardQuery = NotImplemented
169           | otherwise                          = NoError
170
171
172 packMessage :: Maybe Int -> Message -> BS.ByteString
173 packMessage limM = BS.concat . LBS.toChunks . truncateMsg
174     where
175       truncateMsg :: Message -> LBS.ByteString
176       truncateMsg msg
177           = let packet    = encode msg
178                 needTrunc = fromMaybe False $
179                             do lim <- limM
180                                return $ fromIntegral (LBS.length packet) > lim
181             in
182               if needTrunc then
183                   truncateMsg $ trunc' msg
184               else
185                   packet
186
187       trunc' :: Message -> Message
188       trunc' msg
189           | notNull $ msgAdditionals msg
190               = msg {
191                   msgAdditionals = truncList $ msgAdditionals msg
192                 }
193           | notNull $ msgAuthorities msg
194               = msg {
195                   msgHeader      = setTruncFlag $ msgHeader msg
196                 , msgAuthorities = truncList $ msgAuthorities msg
197                 }
198           | notNull $ msgAnswers msg
199               = msg {
200                   msgHeader      = setTruncFlag $ msgHeader msg
201                 , msgAnswers     = truncList $ msgAnswers msg
202                 }
203           | notNull $ msgQuestions msg
204               = msg {
205                   msgHeader      = setTruncFlag $ msgHeader msg
206                 , msgQuestions   = truncList $ msgQuestions msg
207                 }
208           | otherwise
209               = error ("packMessage: You are already skinny and need no diet: " ++ show msg)
210
211       setTruncFlag :: Header -> Header
212       setTruncFlag hdr = hdr { hdIsTruncated = True }
213
214       notNull :: [a] -> Bool
215       notNull = not . null
216
217       truncList :: [a] -> [a]
218       truncList xs = take (length xs - 1) xs
219
220 unpackMessage :: BS.ByteString -> Message
221 unpackMessage = decode . LBS.fromChunks . return
222
223 mkErrorReply :: ResponseCode -> Message -> Message
224 mkErrorReply err msg
225     = runBuilder msg $ do unauthorise
226                           setResponseCode err