]> gitweb @ CieloNegro.org - haskell-dns.git/blob - Network/DNS/Named.hs
Farewell to the Sanity.hs
[haskell-dns.git] / Network / DNS / Named.hs
1 module Network.DNS.Named
2     ( runNamed
3     )
4     where
5
6 import           Control.Concurrent
7 import           Control.Exception
8 import           Control.Monad
9 import           Data.Binary
10 import qualified Data.ByteString as BS
11 import qualified Data.ByteString.Lazy as LBS
12 import           Data.Dynamic
13 import           Data.Maybe
14 import           Network.Socket
15 import qualified Network.Socket.ByteString as NB
16 import           Network.DNS.Message
17 import           Network.DNS.Named.Config
18 import           Network.DNS.Named.ResponseBuilder
19 import           Network.DNS.Named.Zone
20 import           System.Posix.Signals
21
22
23 runNamed :: Config -> (DomainName -> IO (Maybe Zone)) -> IO ()
24 runNamed cnf findZone
25     = withSocketsDo $
26       do installHandler sigPIPE Ignore Nothing
27          _tcpListenerTID <- forkIO $ tcpListen
28          udpListen
29     where
30       udpListen :: IO ()
31       udpListen = do -- FIXME: we should support IPv6 when the network package supports it.
32                      so <- socket AF_INET Datagram defaultProtocol
33                      print cnf
34                      bindSocket so $ cnfServerAddress cnf
35                      udpLoop so
36
37       udpLoop :: Socket -> IO ()
38       udpLoop so
39           = do (packet, cameFrom) <- NB.recvFrom so 512
40                _handlerTID <- forkIO $ udpHandler so packet cameFrom
41                udpLoop so
42
43       tcpListen :: IO ()
44       tcpListen = putStrLn "FIXME: tcpListen is not implemented yet."
45
46       udpHandler :: Socket -> BS.ByteString -> SockAddr -> IO ()
47       udpHandler so packet cameFrom
48           = do msg   <- evaluate $ unpackMessage packet
49                msg'  <- handleMessage msg
50                         `onException`
51                         do let servfail = mkErrorReply ServerFailure msg
52                            NB.sendTo so (packMessage (Just 512) servfail) cameFrom
53                _sent <- NB.sendTo so (packMessage (Just 512) msg') cameFrom
54                return ()
55
56       handleMessage :: Message -> IO Message
57       handleMessage msg
58           = case validateQuery msg of
59               NoError
60                   -> do builders <- mapM handleQuestion $ msgQuestions msg
61
62                         let builder = foldl (>>) (return ()) builders
63                             msg'    = runBuilder msg builder
64
65                         return msg'
66
67               err -> return $ mkErrorReply err msg
68
69       handleQuestion :: SomeQ -> IO (Builder ())
70       handleQuestion (SomeQ q)
71           = do zoneM <- findZone (qName q)
72                case zoneM of
73                  Nothing
74                      -> return $ do unauthorise
75                                     setResponseCode Refused
76                  Just zone
77                      -> handleQuestionForZone (SomeQ q) zone
78
79       handleQuestionForZone :: SomeQ -> Zone -> IO (Builder ())
80       handleQuestionForZone (SomeQ q) zone
81           | Just (qType q) == cast AXFR
82               = handleAXFR (SomeQ q) zone
83           | otherwise
84               = do allRecords <- zoneResponder zone (qName q)
85                    let filtered = filterRecords (SomeQ q) allRecords
86
87                    additionals <- do xss <- mapM (getAdditionals zone) filtered
88                                      ys  <- case zoneNSRecord zone of
89                                               Just rr -> getAdditionals zone rr
90                                               Nothing -> return []
91                                      return (concat xss ++ ys)
92
93                    return $ do mapM_ addAnswer filtered
94
95                                when (qName q == zoneName zone) $
96                                     do when (Just (qType q) == cast SOA ||
97                                              Just (qType q) == cast ANY   )
98                                                 $ case zoneSOARecord zone of
99                                                     Just rr -> addAnswer rr
100                                                     Nothing -> return ()
101
102                                        when (Just (qType q) == cast NS ||
103                                              Just (qType q) == cast ANY  )
104                                                 $ case zoneNSRecord zone of
105                                                     Just rr -> addAnswer rr
106                                                     Nothing -> return ()
107
108                                mapM_ addAdditional additionals
109
110                                case zoneNSRecord zone of
111                                  Just rr -> addAuthority rr
112                                  Nothing -> unauthorise
113
114       getAdditionals :: Zone -> SomeRR -> IO [SomeRR]
115       getAdditionals zone (SomeRR rr)
116           = case cast (rrData rr) :: Maybe DomainName of
117               Nothing
118                   -> return []
119               Just name
120                   -> do allRecords <- zoneResponder zone name
121
122                         let filtered = filterRecords (SomeQ q') allRecords
123                             q'       = Question {
124                                          qName  = name
125                                        , qType  = A
126                                        , qClass = IN
127                                        }
128                         return filtered
129
130       filterRecords :: SomeQ -> [SomeRR] -> [SomeRR]
131       filterRecords (SomeQ q)
132           | Just (qType  q) == cast ANY &&
133             Just (qClass q) == cast ANY    = id
134           | Just (qType  q) == cast ANY    = filter matchClass
135           | Just (qClass q) == cast ANY    = filter matchType
136           | otherwise                      = filter matchBoth
137           where
138             matchClass (SomeRR rr)
139                 = Just (qClass q) == cast (rrClass rr)
140
141             matchType (SomeRR rr)
142                 = Just (qType  q) == cast (rrType  rr) ||
143                   Just CNAME      == cast (rrType  rr)
144
145             matchBoth rr
146                 = matchType rr && matchClass rr
147
148       handleAXFR :: SomeQ -> Zone -> IO (Builder ())
149       handleAXFR (SomeQ _q) _zone
150           = fail "FIXME: not implemented yet"
151
152
153 validateQuery :: Message -> ResponseCode
154 validateQuery = validateHeader . msgHeader
155     where
156       validateHeader :: Header -> ResponseCode
157       validateHeader hdr
158           | hdMessageType hdr /= Query         = NotImplemented
159           | hdOpcode      hdr /= StandardQuery = NotImplemented
160           | otherwise                          = NoError
161
162
163 packMessage :: Maybe Int -> Message -> BS.ByteString
164 packMessage limM = BS.concat . LBS.toChunks . truncateMsg
165     where
166       truncateMsg :: Message -> LBS.ByteString
167       truncateMsg msg
168           = let packet    = encode msg
169                 needTrunc = fromMaybe False $
170                             do lim <- limM
171                                return $ fromIntegral (LBS.length packet) > lim
172             in
173               if needTrunc then
174                   truncateMsg $ trunc' msg
175               else
176                   packet
177
178       trunc' :: Message -> Message
179       trunc' msg
180           | notNull $ msgAdditionals msg
181               = msg {
182                   msgAdditionals = truncList $ msgAdditionals msg
183                 }
184           | notNull $ msgAuthorities msg
185               = msg {
186                   msgHeader      = setTruncFlag $ msgHeader msg
187                 , msgAuthorities = truncList $ msgAuthorities msg
188                 }
189           | notNull $ msgAnswers msg
190               = msg {
191                   msgHeader      = setTruncFlag $ msgHeader msg
192                 , msgAnswers     = truncList $ msgAnswers msg
193                 }
194           | notNull $ msgQuestions msg
195               = msg {
196                   msgHeader      = setTruncFlag $ msgHeader msg
197                 , msgQuestions   = truncList $ msgQuestions msg
198                 }
199           | otherwise
200               = error ("packMessage: You are already skinny and need no diet: " ++ show msg)
201
202       setTruncFlag :: Header -> Header
203       setTruncFlag hdr = hdr { hdIsTruncated = True }
204
205       notNull :: [a] -> Bool
206       notNull = not . null
207
208       truncList :: [a] -> [a]
209       truncList xs = take (length xs - 1) xs
210
211 unpackMessage :: BS.ByteString -> Message
212 unpackMessage = decode . LBS.fromChunks . return
213
214 mkErrorReply :: ResponseCode -> Message -> Message
215 mkErrorReply err msg
216     = runBuilder msg $ do unauthorise
217                           setResponseCode err