]> gitweb @ CieloNegro.org - Lucu.git/blob - Data/Collections/Newtype/TH.hs
auto-derive Foldable
[Lucu.git] / Data / Collections / Newtype / TH.hs
1 {-# LANGUAGE
2     TemplateHaskell
3   , UnicodeSyntax
4   #-}
5 -- |FIXME: doc
6 module Data.Collections.Newtype.TH
7     ( derive
8     )
9     where
10 import Control.Applicative hiding (empty)
11 import Control.Monad.Unicode
12 import Data.Collections
13 import Data.Collections.BaseInstances ()
14 import Data.Data
15 import Data.Generics.Aliases
16 import Data.Generics.Schemes
17 import Data.Maybe
18 import Language.Haskell.TH.Lib
19 import Language.Haskell.TH.Ppr
20 import Language.Haskell.TH.Syntax
21 import Prelude hiding ( concat, concatMap, exp
22                       , foldl, foldr, foldl1, foldr1, null)
23 import Prelude.Unicode
24
25 type Deriver = Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
26
27 -- |FIXME: doc
28 derive ∷ Q [Dec] → Q [Dec]
29 derive = (concat <$>) ∘ (mapM go =≪)
30     where
31       go ∷ Dec → Q [Dec]
32       go (InstanceD c ty _) = deriveInstance c ty
33       go _ = fail "derive: usage: derive [d| instance A; instance B; ... |]"
34
35 deriveInstance ∷ Cxt → Type → Q [Dec]
36 deriveInstance c ty
37     = do (wrapperTy, deriver) ← inspectInstance ty
38          (wrap     , wrapD  ) ← genWrap   wrapperTy
39          (unwrap   , unwrapD) ← genUnwrap wrapperTy
40          instanceDecl         ← deriver (return c     )
41                                         (return ty    )
42                                         (return wrap  )
43                                         (return unwrap)
44          return $ [ d | d ← wrapD  , wrap   `isUsedIn` instanceDecl ]
45                 ⧺ [ d | d ← unwrapD, unwrap `isUsedIn` instanceDecl ]
46                 ⧺ [ instanceDecl ]
47
48 isUsedIn ∷ (Eq α, Typeable α, Data β) ⇒ α → β → Bool
49 isUsedIn α = (> 0) ∘ gcount (mkQ False (≡ α))
50
51 inspectInstance ∷ Type → Q (Type, Deriver)
52 inspectInstance (AppT (AppT (ConT classTy) wrapperTy) _)
53     | classTy ≡ ''Unfoldable
54         = return (wrapperTy, deriveUnfoldable)
55     | classTy ≡ ''Foldable
56         = return (wrapperTy, deriveFoldable)
57 inspectInstance ty
58     = fail $ "deriveInstance: unsupported type: " ⧺ pprint ty
59
60 genWrap ∷ Type → Q (Exp, [Dec])
61 genWrap wrapperTy
62     = do name      ← newName "wrap"
63          (con, ty) ← wrapperConTy wrapperTy
64          decls     ← sequence
65                      [ sigD name [t| $(return ty) → $(return wrapperTy) |]
66                      , pragInlD name (inlineSpecNoPhase True True)
67                      , funD name [clause [] (normalB (conE con)) []]
68                      ]
69          return (VarE name, decls)
70
71 genUnwrap ∷ Type → Q (Exp, [Dec])
72 genUnwrap wrapperTy
73     = do name      ← newName "unwrap"
74          i         ← newName "i"
75          (con, ty) ← wrapperConTy wrapperTy
76          decls     ← sequence
77                      [ sigD name [t| $(return wrapperTy) → $(return ty) |]
78                      , pragInlD name (inlineSpecNoPhase True True)
79                      , funD name [clause [conP con [varP i]] (normalB (varE i)) []]
80                      ]
81          return (VarE name, decls)
82
83 wrapperConTy ∷ Type → Q (Name, Type)
84 wrapperConTy = (conTy =≪) ∘ tyInfo
85     where
86       tyInfo ∷ Type → Q Info
87       tyInfo (ConT name) = reify name
88       tyInfo (AppT ty _) = tyInfo ty
89       tyInfo (SigT ty _) = tyInfo ty
90       tyInfo ty
91           = fail $ "wrapperConTy: unsupported type: " ⧺ pprint ty
92
93       conTy ∷ Info → Q (Name, Type)
94       conTy (TyConI (NewtypeD [] _ [] (NormalC con [(NotStrict, ty)]) []))
95           = return (con, ty)
96       conTy info
97           = fail $ "wrapperConTy: unsupported type: " ⧺ pprint info
98
99 methodNames ∷ Name → Q [Name]
100 methodNames = (names =≪) ∘ reify
101     where
102       names ∷ Info → Q [Name]
103       names (ClassI (ClassD _ _ _ _ decls) _)
104               = return ∘ catMaybes $ map name decls
105       names c = fail $ "methodNames: not a class: " ⧺ pprint c
106
107       name ∷ Dec → Maybe Name
108       name (SigD n _) = Just n
109       name _          = Nothing
110
111 pointfreeMethod ∷ (Name → Q Exp) → Name → [Q Dec]
112 pointfreeMethod f name
113     = [ funD name [clause [] (normalB (f name)) []]
114       -- THINKME: Inserting PragmaD in an InstanceD causes an error
115       -- least GHC 7.0.3. Why?
116       -- , pragInlD name (inlineSpecNoPhase True False)
117       ]
118
119 deriveUnfoldable ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
120 deriveUnfoldable c ty wrap unwrap
121     = do names ← methodNames ''Unfoldable
122          instanceD c ty $ concatMap (pointfreeMethod exp) names
123     where
124       exp ∷ Name → Q Exp
125       exp name
126           | name ≡ 'insert
127               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insert |]
128           | name ≡ 'empty
129               = [| $wrap empty |]
130           | name ≡ 'singleton
131               = [| $wrap ∘ singleton |]
132           | name ≡ 'insertMany
133               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insertMany |]
134           | name ≡ 'insertManySorted
135               = [| ($wrap ∘) ∘ (∘ $unwrap) ∘ insertManySorted |]
136           | otherwise
137               = fail $ "deriveUnfoldable: unknown method: " ⧺ pprint name
138
139 deriveFoldable ∷ Q Cxt → Q Type → Q Exp → Q Exp → Q Dec
140 deriveFoldable c ty _ unwrap
141     = do names ← methodNames ''Foldable
142          instanceD c ty $ concatMap (pointfreeMethod exp) names
143     where
144       exp ∷ Name → Q Exp
145       exp name
146           | name ≡ 'fold
147               = [| fold ∘ $unwrap |]
148           | name ≡ 'foldMap
149               = [| (∘ $unwrap) ∘ foldMap |]
150           | name ≡ 'foldr
151               = [| flip flip $unwrap ∘ ((∘) ∘) ∘ foldr |]
152           | name ≡ 'foldl
153               = [| flip flip $unwrap ∘ ((∘) ∘) ∘ foldl |]
154           | name ≡ 'foldr1
155               = [| (∘ $unwrap) ∘ foldr1 |]
156           | name ≡ 'foldl1
157               = [| (∘ $unwrap) ∘ foldl1 |]
158           | name ≡ 'null
159               = [| null ∘ $unwrap |]
160           | name ≡ 'size
161               = [| size ∘ $unwrap |]
162           | name ≡ 'isSingleton
163               = [| isSingleton ∘ $unwrap |]
164           | otherwise
165               = fail $ "deriveFoldable: unknown method: " ⧺ pprint name