]> gitweb @ CieloNegro.org - Lucu.git/blob - Data/Collections/Newtype/TH.hs
auto-derive Map
[Lucu.git] / Data / Collections / Newtype / TH.hs
1 {-# LANGUAGE
2     TemplateHaskell
3   , UnicodeSyntax
4   #-}
5 -- |FIXME: doc
6 module Data.Collections.Newtype.TH
7     ( derive
8     )
9     where
10 import Control.Applicative hiding (empty)
11 import Control.Arrow
12 import Control.Monad.Unicode
13 import Data.Collections
14 import Data.Collections.BaseInstances ()
15 import Data.Data
16 import Data.Generics.Aliases
17 import Data.Generics.Schemes
18 import Data.Maybe
19 import Language.Haskell.TH.Lib
20 import Language.Haskell.TH.Ppr
21 import Language.Haskell.TH.Syntax
22 import Prelude hiding ( concat, concatMap, exp, filter
23                       , foldl, foldr, foldl1, foldr1
24                       , lookup, null
25                       )
26 import Prelude.Unicode
27
28 type Deriver = Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
29
30 -- |FIXME: doc
31 derive ∷ Q [Dec] → Q [Dec]
32 derive = (concat <$>) ∘ (mapM go =≪)
33     where
34       go ∷ Dec → Q [Dec]
35       go (InstanceD c ty _) = deriveInstance c ty
36       go _ = fail "derive: usage: derive [d| instance A; instance B; ... |]"
37
38 deriveInstance ∷ Cxt → Type → Q [Dec]
39 deriveInstance c ty
40     = do (wrapperTy, deriver) ← inspectInstance ty
41          (wrap     , wrapD  ) ← genWrap   wrapperTy
42          (unwrap   , unwrapD) ← genUnwrap wrapperTy
43          instanceDecl         ← deriver (return c     )
44                                         (return ty    )
45                                         (return wrap  )
46                                         (return unwrap)
47          return $ [ d | d ← wrapD  , wrap   `isUsedIn` instanceDecl ]
48                 ⧺ [ d | d ← unwrapD, unwrap `isUsedIn` instanceDecl ]
49                 ⧺ [ instanceDecl ]
50
51 isUsedIn ∷ (Eq α, Typeable α, Data β) ⇒ α → β → Bool
52 isUsedIn α = (> 0) ∘ gcount (mkQ False (≡ α))
53
54 inspectInstance ∷ Type → Q (Type, Deriver)
55 inspectInstance (AppT (AppT (ConT classTy) wrapperTy) _)
56     | classTy ≡ ''Unfoldable
57         = return (wrapperTy, deriveUnfoldable)
58     | classTy ≡ ''Foldable
59         = return (wrapperTy, deriveFoldable)
60     | classTy ≡ ''Collection
61         = return (wrapperTy, deriveCollection)
62     | classTy ≡ ''SortingCollection
63         = return (wrapperTy, deriveSortingCollection)
64 inspectInstance (AppT (AppT (AppT (ConT classTy) wrapperTy) _) _)
65     | classTy ≡ ''Indexed
66         = return (wrapperTy, deriveIndexed)
67     | classTy ≡ ''Map
68         = return (wrapperTy, deriveMap)
69 inspectInstance ty
70     = fail $ "deriveInstance: unsupported type: " ⧺ pprint ty
71
72 genWrap ∷ Type → Q (Exp, [Dec])
73 genWrap wrapperTy
74     = do name      ← newName "wrap"
75          (con, ty) ← wrapperConTy wrapperTy
76          decls     ← sequence
77                      [ sigD name [t| $(return ty) → $(return wrapperTy) |]
78                      , pragInlD name (inlineSpecNoPhase True True)
79                      , funD name [clause [] (normalB (conE con)) []]
80                      ]
81          return (VarE name, decls)
82
83 genUnwrap ∷ Type → Q (Exp, [Dec])
84 genUnwrap wrapperTy
85     = do name      ← newName "unwrap"
86          i         ← newName "i"
87          (con, ty) ← wrapperConTy wrapperTy
88          decls     ← sequence
89                      [ sigD name [t| $(return wrapperTy) → $(return ty) |]
90                      , pragInlD name (inlineSpecNoPhase True True)
91                      , funD name [clause [conP con [varP i]] (normalB (varE i)) []]
92                      ]
93          return (VarE name, decls)
94
95 wrapperConTy ∷ Type → Q (Name, Type)
96 wrapperConTy = (conTy =≪) ∘ tyInfo
97     where
98       tyInfo ∷ Type → Q Info
99       tyInfo (ConT name) = reify name
100       tyInfo (AppT ty _) = tyInfo ty
101       tyInfo (SigT ty _) = tyInfo ty
102       tyInfo ty
103           = fail $ "wrapperConTy: unsupported type: " ⧺ pprint ty
104
105       conTy ∷ Info → Q (Name, Type)
106       conTy (TyConI (NewtypeD [] _ [] (NormalC con [(NotStrict, ty)]) []))
107           = return (con, ty)
108       conTy info
109           = fail $ "wrapperConTy: unsupported type: " ⧺ pprint info
110
111 methodNames ∷ Name → Q [Name]
112 methodNames = (names =≪) ∘ reify
113     where
114       names ∷ Info → Q [Name]
115       names (ClassI (ClassD _ _ _ _ decls) _)
116               = return ∘ catMaybes $ map name decls
117       names c = fail $ "methodNames: not a class: " ⧺ pprint c
118
119       name ∷ Dec → Maybe Name
120       name (SigD n _) = Just n
121       name _          = Nothing
122
123 pointfreeMethod ∷ (Name → Q Exp) → Name → [Q Dec]
124 pointfreeMethod f name
125     = [ funD name [clause [] (normalB (f name)) []]
126       -- THINKME: Inserting PragmaD in an InstanceD causes an error
127       -- least GHC 7.0.3. Why?
128       -- , pragInlD name (inlineSpecNoPhase True False)
129       ]
130
131 deriveUnfoldable ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
132 deriveUnfoldable c ty wrap unwrap
133     = do names ← methodNames ''Unfoldable
134          instanceD c ty $ concatMap (pointfreeMethod exp) names
135     where
136       exp ∷ Name → Q Exp
137       exp name
138           | name ≡ 'insert
139               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insert |]
140           | name ≡ 'empty
141               = [| $wrap empty |]
142           | name ≡ 'singleton
143               = [| $wrap ∘ singleton |]
144           | name ≡ 'insertMany
145               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insertMany |]
146           | name ≡ 'insertManySorted
147               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insertManySorted |]
148           | otherwise
149               = fail $ "deriveUnfoldable: unknown method: " ⧺ pprint name
150
151 deriveFoldable ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
152 deriveFoldable c ty _ unwrap
153     = do names ← methodNames ''Foldable
154          instanceD c ty $ concatMap (pointfreeMethod exp) names
155     where
156       exp ∷ Name → Q Exp
157       exp name
158           | name ≡ 'fold
159               = [| fold ∘ $unwrap |]
160           | name ≡ 'foldMap
161               = [| (∘ $unwrap) ∘ foldMap |]
162           | name ≡ 'foldr
163               = [| flip flip $unwrap ∘ ((∘) ∘) ∘ foldr |]
164           | name ≡ 'foldl
165               = [| flip flip $unwrap ∘ ((∘) ∘) ∘ foldl |]
166           | name ≡ 'foldr1
167               = [| (∘ $unwrap) ∘ foldr1 |]
168           | name ≡ 'foldl1
169               = [| (∘ $unwrap) ∘ foldl1 |]
170           | name ≡ 'null
171               = [| null ∘ $unwrap |]
172           | name ≡ 'size
173               = [| size ∘ $unwrap |]
174           | name ≡ 'isSingleton
175               = [| isSingleton ∘ $unwrap |]
176           | otherwise
177               = fail $ "deriveFoldable: unknown method: " ⧺ pprint name
178
179 deriveCollection ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
180 deriveCollection c ty wrap unwrap
181     = do names ← methodNames ''Collection
182          instanceD c ty $ concatMap (pointfreeMethod exp) names
183     where
184       exp ∷ Name → Q Exp
185       exp name
186           | name ≡ 'filter
187               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ filter |]
188           | otherwise
189               = fail $ "deriveCollection: unknown method: " ⧺ pprint name
190
191 deriveIndexed ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
192 deriveIndexed c ty wrap unwrap
193     = do names ← methodNames ''Indexed
194          instanceD c ty $ concatMap (pointfreeMethod exp) names
195     where
196       exp ∷ Name → Q Exp
197       exp name
198           | name ≡ 'index
199               = [| (∘ $unwrap) ∘ index |]
200           | name ≡ 'adjust
201               = [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ adjust |]
202           | name ≡ 'inDomain
203               = [| (∘ $unwrap) ∘ inDomain |]
204           | name ≡ '(//)
205               = [| ($wrap ∘) ∘ (//) ∘ $unwrap |]
206           | name ≡ 'accum
207               = [| (($wrap ∘) ∘) ∘ (∘ $unwrap) ∘ accum |]
208           | otherwise
209               = fail $ "deriveIndexed: unknown method: " ⧺ pprint name
210
211 deriveMap ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
212 deriveMap c ty wrap unwrap
213     = do names ← methodNames ''Map
214          instanceD c ty $ concatMap (pointfreeMethod exp) names
215     where
216       exp ∷ Name → Q Exp
217       exp name
218           | name ≡ 'delete
219               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ delete |]
220           | name ≡ 'member
221               = [| (∘ $unwrap) ∘ member |]
222           | name ≡ 'union
223               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ union ∘ $unwrap |]
224           | name ≡ 'intersection
225               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ intersection ∘ $unwrap |]
226           | name ≡ 'difference
227               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ difference ∘ $unwrap |]
228           | name ≡ 'isSubset
229               = [| (∘ $unwrap) ∘ isSubset ∘ $unwrap |]
230           | name ≡ 'isProperSubset
231               = [| (∘ $unwrap) ∘ isProperSubset ∘ $unwrap |]
232           | name ≡ 'lookup
233               = [| (∘ $unwrap) ∘ lookup |]
234           | name ≡ 'alter
235               = [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ alter |]
236           | name ≡ 'insertWith
237               = [| ((($wrap ∘) ∘) ∘) ∘ flip flip $unwrap ∘ ((flip ∘ ((∘) ∘)) ∘) ∘ insertWith |]
238           | name ≡ 'fromFoldableWith
239               = [| ($wrap ∘) ∘ fromFoldableWith |]
240           | name ≡ 'foldGroups
241               = [| (($wrap ∘) ∘) ∘ foldGroups |]
242           | name ≡ 'mapWithKey
243               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ mapWithKey |]
244           | name ≡ 'unionWith
245               = [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ unionWith |]
246           | name ≡ 'intersectionWith
247               = [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ intersectionWith |]
248           | name ≡ 'differenceWith
249               = [| (($wrap ∘) ∘) ∘ flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ differenceWith |]
250           | name ≡ 'isSubmapBy
251               = [| flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ isSubmapBy |]
252           | name ≡ 'isProperSubmapBy
253               = [| flip flip $unwrap ∘ ((∘) ∘) ∘ (∘ $unwrap) ∘ isProperSubmapBy |]
254           | otherwise
255               = fail $ "deriveMap: unknown method: " ⧺ pprint name
256
257 deriveSortingCollection ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
258 deriveSortingCollection c ty wrap unwrap
259     = do names ← methodNames ''SortingCollection
260          instanceD c ty $ concatMap (pointfreeMethod exp) names
261     where
262       exp ∷ Name → Q Exp
263       exp name
264           | name ≡ 'minView
265               = [| (second $wrap <$>) ∘ minView ∘ $unwrap |]
266           | otherwise
267               = fail $ "deriveSortingCollection: unknown method: " ⧺ pprint name