module Data.Derive.Internal.Traversal(
TraveralType(..), defaultTraversalType,
traversalDerivation1,
traversalInstance, traversalInstance1,
deriveTraversal
) where
import Language.Haskell
import Data.Derive.Internal.Derivation
import Data.List
import qualified Data.Set as S
import Control.Monad.Trans.Writer
import Control.Applicative
import Data.Generics.Uniplate.DataOnly
import Data.Maybe
import Prelude
-- | An expression in the generated code; traversals are built as
-- unannotated haskell-src-exts expressions.
type Trav = Exp ()
-- | Configuration describing how to derive one traversal-style class. The
-- generic machinery in this module walks the field types of a data
-- declaration; the class-specific code it emits is produced by the functions
-- stored here.
data TraveralType = TraveralType
    { traversalArg :: Int
      -- ^ Position of the traversed type variable, counted from the right
      -- starting at 1 (the numbering of 'argPositions'). Values above 1 are
      -- also appended to the class and method names.
    , traversalCo :: Bool
      -- ^ Flipped each time the traversal descends into the argument of a
      -- function arrow; an occurrence of the traversed variable while this
      -- is 'True' is rejected.
    , traversalName :: QName ()
      -- ^ Base name of the traversal method (suffixed by 'traversalNameN').
    , traversalId :: Trav
      -- ^ Expression used for fields that do not mention the traversed
      -- variable; also compared against to detect "nothing to traverse".
    , traversalDirect :: Trav
      -- ^ Expression used for a field whose type is exactly the traversed
      -- variable.
    , traversalFunc :: QName () -> Trav -> Trav
      -- ^ Given the method name for an argument position and the traversal
      -- of that argument, builds the traversal through a type constructor.
    , traversalPlus :: Trav -> Trav -> Trav
      -- ^ Combines the traversals of several argument positions of one type
      -- constructor.
    , traverseArrow :: Maybe (Trav -> Trav -> Trav)
      -- ^ Combines the traversals of a function's argument and result types.
      -- 'Nothing' means types containing arrows are rejected.
    , traverseTuple :: [Exp ()] -> Exp ()
      -- ^ Builds the result of the tuple lambda from the per-component
      -- expressions.
    , traverseCtor :: String -> [Exp ()] -> Exp ()
      -- ^ Rebuilds a constructor (given by name) from per-field expressions.
    , traverseFunc :: Pat () -> Exp () -> Match ()
      -- ^ Builds one equation of the method from the constructor pattern and
      -- the right-hand side.
    }
-- | A 'TraveralType' with generic defaults, meant to be overridden with
-- record-update syntax.
--
-- 'traversalName' and 'traverseFunc' have no sensible default. Forcing them
-- without overriding them raises an error naming the missing field.
defaultTraversalType :: TraveralType
defaultTraversalType = TraveralType
    { traversalArg = 1
    , traversalCo = False
    , traversalName = error "defaultTraversalType: traversalName must be set"
    , traversalId = var "id"
    , traversalDirect = var "_f"
    , traversalFunc = \x y -> appP (Var () x) y
      -- combine two traversals with the list constructor (:)
    , traversalPlus = \x y -> apps (Con () $ Special () (Cons ())) [paren x, paren y]
    , traverseArrow = Nothing
    , traverseTuple = Tuple () Boxed
    , traverseCtor = \x y -> apps (con x) (map paren y)
    , traverseFunc = error "defaultTraversalType: traverseFunc must be set"
    }
-- | A constraint the generated instance needs: the type-constructor variable
-- named '_requiredDataArg' must be an instance of the traversal class for
-- argument position '_requiredPosition' (see 'deriveTraversalApp', which
-- records these, and 'traversalInstance', which turns them into context).
data RequiredInstance = RequiredInstance
    { _requiredDataArg :: String
    , _requiredPosition :: Int
    }
    deriving (Eq, Ord)
-- | A computation that may record 'RequiredInstance' constraints while
-- producing generated code.
type WithInstances a = Writer (S.Set RequiredInstance) a
-- | @vars mk prefix count@ builds the names @prefix1 .. prefixN@ (with
-- @N = count@) and passes each through @mk@. A count below 1 yields @[]@.
vars :: (Enum n, Num n, Show n) => (String -> a) -> Char -> n -> [a]
vars mk prefix count = map (mk . (prefix :) . show) [1 .. count]
-- | A 'Derivation' for class @nm@ (suffixed with the position when
-- 'traversalArg' is greater than 1), implemented by 'traversalInstance1'.
traversalDerivation1 :: TraveralType -> String -> Derivation
traversalDerivation1 tt nm =
    derivationCustom (suffixed (traversalArg tt)) (traversalInstance1 tt nm)
  where
    suffixed n
        | n > 1 = nm ++ show n
        | otherwise = nm
-- | Callback for 'derivationCustom'. It rejects declarations the traversal
-- cannot handle and otherwise returns the generated instance declaration.
--
-- Returns 'Left' in two cases:
--
-- * the declaration contains a function type while 'traverseArrow' is
--   'Nothing';
-- * the data type has no type parameters.
traversalInstance1 :: TraveralType -> String -> FullDataDecl -> Either String [Decl ()]
traversalInstance1 tt nm (_, dat)
    | arrowsUnsupported =
        Left $ "Can't derive " ++ prettyPrint (traversalName tt) ++ " for types with arrow"
    | dataDeclArity dat == 0 =
        Left "Cannot derive class for data type arity == 0"
    | otherwise =
        Right $ traversalInstance tt nm dat [deriveTraversal tt dat]
  where
    hasArrow = any isTyFun (universeBi dat)
    arrowsUnsupported = isNothing (traverseArrow tt) && hasArrow
-- | Build the instance declaration for a traversal class.
--
-- The method bodies run in the 'WithInstances' writer. Every
-- 'RequiredInstance' they record becomes a constraint in the instance
-- context.
--
-- The type variable at position 'traversalArg' (counted from the right, see
-- 'argPositions') is the traversed one and is dropped from the instance head:
--
-- * the variables to its left are applied to the data type;
-- * the variables to its right remain as further class arguments, in order.
traversalInstance :: TraveralType -> String -> DataDecl -> [WithInstances (Decl ())] -> [Decl ()]
traversalInstance tt nameBase dat bodyM =
    [ simplify $ InstDecl () Nothing instRule (Just $ map (InsDecl ()) body) ]
  where
    instRule = IRule () Nothing (Just ctx) instHead
    -- Apply the class to 'args' left to right, giving @C (T t1 ..) tk ..@.
    -- This matches the argument order used for the constraints in 'ctx'.
    instHead = foldl (IHApp ()) (IHCon () nam) args
    (body, required) = runWriter (sequence bodyM)
    -- A requirement at position p needs the class for position p, applied
    -- to the constructor variable followed by p - 1 fresh variables s1..
    ctx = CxTuple ()
        [ ClassA () (qname $ className p) (tyVar n : vars tyVar 's' (p - 1))
        | RequiredInstance n p <- S.toList required
        ]
    vrs = vars tyVar 't' (dataDeclArity dat)
    -- Lazy pattern: assumes 1 <= traversalArg tt <= arity.
    -- NOTE(review): traversalInstance1 only checks that the arity is non-zero.
    (vrsBefore, _ : vrsAfter) = splitAt (length vrs - traversalArg tt) vrs
    className n = nameBase ++ (if n > 1 then show n else "")
    nam = qname (className (traversalArg tt))
    args = TyParen () (tyApps (tyCon $ dataDeclName dat) vrsBefore) : vrsAfter
-- | Generate the traversal method as a single 'FunBind'. There is one
-- equation per constructor, and every equation is renamed to the method name
-- for position 'traversalArg'.
deriveTraversal :: TraveralType -> DataDecl -> WithInstances (Decl ())
deriveTraversal tt dat =
    fmap renameMatches $ mapM (deriveTraversalCtor tt positions) (dataDeclCtors dat)
  where
    positions = argPositions dat
    methodName = unqual $ traversalNameN tt $ traversalArg tt
    renameMatches ms =
        FunBind () [ Match () methodName pats rhs binds | Match () _ pats rhs binds <- ms ]
    unqual (Qual () _ x) = x
    unqual (UnQual () x) = x
-- | Generate the method equation for one constructor.
--
-- The fields are bound to @a1 .. aN@. Each field is wrapped in the traversal
-- for its type, and the constructor is rebuilt with 'traverseCtor'.
deriveTraversalCtor :: TraveralType -> ArgPositions -> CtorDecl -> WithInstances (Match ())
deriveTraversalCtor tt ap ctor = do
    fieldTravs <- mapM (deriveTraversalType tt ap . snd) (ctorDeclFields ctor)
    let ctorName = ctorDeclName ctor
        fieldCount = ctorDeclArity ctor
        lhs = PParen () $ PApp () (qname ctorName) (vars pVar 'a' fieldCount)
        rhs = traverseCtor tt ctorName $ zipWith (App ()) fieldTravs (vars var 'a' fieldCount)
    return (traverseFunc tt lhs rhs)
-- | Compute the traversal expression for a value of the given field type.
--
-- * Types that do not mention the traversed variable yield 'traversalId'.
-- * The traversed variable itself yields 'traversalDirect'. It is rejected
--   if it sits inside the argument of an odd number of enclosing arrows
--   ('traversalCo' set).
-- * Function types are combined with 'traverseArrow'.
-- * Tuples are handled by the tuple case of 'deriveTraversalApp', which uses
--   'traverseTuple'.
-- * Other type applications are handled by 'deriveTraversalApp'.
--
-- Unsupported types raise an 'error'. 'fail' is not usable here: the writer
-- is built on 'Identity', which has no 'MonadFail' instance on current GHCs.
deriveTraversalType :: TraveralType -> ArgPositions -> Type () -> WithInstances Trav
deriveTraversalType tt ap (TyParen () x) = deriveTraversalType tt ap x
deriveTraversalType _ _ TyForall{} = error "forall not supported in traversal deriving"
deriveTraversalType tt ap (TyFun () a b) =
    arrow
        <$> deriveTraversalType tt{traversalCo = not $ traversalCo tt} ap a
        <*> deriveTraversalType tt ap b
  where
    -- traversalInstance1 rejects arrows when traverseArrow is Nothing; this
    -- message only appears if this function is called directly on such a type.
    arrow = fromMaybe
        (error "deriveTraversalType: function type found but traverseArrow is Nothing")
        (traverseArrow tt)
deriveTraversalType tt ap (TyApp () a b) = deriveTraversalApp tt ap a [b]
deriveTraversalType tt ap (TyList () a) =
    deriveTraversalType tt ap $ TyApp () (TyCon () $ Special () $ ListCon ()) a
-- Hand tuples to the dedicated tuple case of deriveTraversalApp. Converting
-- them to a TupleCon application would make that case, and traverseTuple,
-- unreachable.
deriveTraversalType tt ap (TyTuple () b a) = deriveTraversalApp tt ap (TyTuple () b []) a
deriveTraversalType tt _ (TyCon () _) = return $ traversalId tt
deriveTraversalType tt ap (TyVar () (Ident () n))
    | ap n /= traversalArg tt = return $ traversalId tt
    | traversalCo tt = error "tyvar used in covariant position"
    | otherwise = return $ traversalDirect tt
deriveTraversalType _ _ t =
    error $ "deriveTraversalType: unsupported type in traversal deriving: " ++ prettyPrint t
-- | Traversal expression for a type constructor applied to arguments.
--
-- The first equation flattens nested 'TyApp's, so @args@ holds every
-- argument in source order.
--
-- Tuples produce a lambda over the components built with 'traverseTuple'.
-- If no component needs traversing, the result is 'traversalId' instead.
--
-- For any other head, handling works by argument position @i@, counted from
-- the right:
--
-- * each argument that needs traversing is lifted with the method for
--   position @i@ ('traverseArg');
-- * those results are combined with 'traversalPlus';
-- * when the head is a type variable of the data declaration, a
--   'RequiredInstance' is recorded for each such position.
--
-- Errors are raised with 'error'. The underlying 'Identity' has no
-- 'MonadFail' instance, so 'fail' would not compile on current GHCs.
deriveTraversalApp :: TraveralType -> ArgPositions -> Type () -> [Type ()] -> WithInstances Trav
deriveTraversalApp tt ap (TyApp () a b) args = deriveTraversalApp tt ap a (b : args)
deriveTraversalApp tt ap TyTuple{} args = do
    tArgs <- mapM (deriveTraversalType tt ap) args
    let arity = length args
    return $
        if all (== traversalId tt) tArgs
            then traversalId tt
            else Lambda () [PTuple () Boxed (vars pVar 't' arity)]
                     (traverseTuple tt $ zipWith (App ()) tArgs (vars var 't' arity))
deriveTraversalApp tt ap tycon args = do
    -- The head's own traversal is unused, but running it keeps its checks
    -- (and any error it raises) part of the result.
    _ <- deriveTraversalType tt ap tycon
    tArgs <- mapM (deriveTraversalType tt ap) args
    -- (position from the right, traversal) for every argument that needs work
    let needed = [ (i, t) | (t, i) <- zip (reverse tArgs) [1 ..], t /= traversalId tt ]
    case tycon of
        TyVar () (Ident () n)
            | ap n == traversalArg tt -> error "kind error: type used type constructor"
            | otherwise -> tell $ S.fromList [ RequiredInstance n i | (i, _) <- needed ]
        _ -> return ()
    return $ case [ traverseArg tt i t | (i, t) <- needed ] of
        [] -> traversalId tt
        nonId -> foldl1 (traversalPlus tt) nonId
-- | Lift the traversal @e@ of an argument through a type constructor. The
-- lifting uses the class method for argument position @n@.
traverseArg :: TraveralType -> Int -> Trav -> Trav
traverseArg tt n = traversalFunc tt (traversalNameN tt n)
-- | Name of the traversal method for argument position @n@.
--
-- For @n <= 1@ this is 'traversalName' unchanged. Otherwise @n@ is appended
-- to the base name (a base name @foo@ becomes @foo2@ for position 2).
--
-- Special names and operator symbols cannot take a numeric suffix. They raise
-- an 'error' that says so.
traversalNameN :: TraveralType -> Int -> QName ()
traversalNameN tt n
    | n <= 1 = base
    | otherwise = suffixQName base
  where
    base = traversalName tt
    suffix = show n
    suffixQName (Qual () m x) = Qual () m (suffixName x)
    suffixQName (UnQual () x) = UnQual () (suffixName x)
    suffixQName q = error $ "traversalNameN: cannot add a numeric suffix to " ++ prettyPrint q
    suffixName (Ident () x) = Ident () (x ++ suffix)
    suffixName (Symbol () x) = error $ "traversalNameN: cannot add a numeric suffix to operator " ++ x
-- | Maps a type-variable name of the data declaration to its position,
-- counted from the right starting at 1 (see 'argPositions').
type ArgPositions = String -> Int
-- | Position of each type variable of the declaration, counted from the
-- right and starting at 1. For @data T a b c@, @c@ is 1, @b@ is 2 and @a@
-- is 3. This is the numbering used by 'traversalArg'.
--
-- Calls 'error' for a name that is not a variable of the declaration.
argPositions :: DataDecl -> String -> Int
argPositions dat = \nm -> case elemIndex nm args of
    Nothing -> error "impossible: tyvar not in scope"
    Just k -> arity - k
  where
    args = dataDeclVars dat
    arity = length args