# -*- coding: utf-8 -*-
"""
This module contains shared code generation tools for the ufc::*_integral classes.
"""

# Copyright (C) 2008-2009 Martin Sandve Alnes and Simula Research Laboratory
#
# This file is part of SyFi.
#
# SyFi is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# SyFi is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SyFi. If not, see <http://www.gnu.org/licenses/>.
#
# First added:  2008-08-13
# Last changed: 2009-03-10

from itertools import tee
from sfc.common.output import sfc_debug, sfc_error, sfc_assert, sfc_warning
from sfc.common.utilities import indices_subset, unique
from sfc.quadrature import gen_quadrature_rule_definition, gen_quadrature_rule_definitions
from sfc.codegeneration.codeformatting import indent, CodeFormatter, \
    gen_const_token_definitions, gen_token_prints, \
    gen_token_declarations, gen_token_assignments, gen_token_additions

class IntegralCG(object):
    """Shared code generator for the ufc::*_integral classes.

    Produces C++ source fragments (constructor, members, tabulate_tensor, ...)
    from an integral representation ``itgrep``. Subclasses must implement
    gen_constructor, gen_destructor and gen_members.
    """

    def __init__(self, itgrep):
        """Store the integral representation and its code generation options.

        Args:
            itgrep: integral representation; ``itgrep.formrep.options.code.integral``
                holds the options used by this generator (e.g. ``enable_debug_code``).
        """
        self.itgrep = itgrep
        self.options = self.itgrep.formrep.options.code.integral

    def hincludes(self):
        "Return the list of extra includes for the header file (none by default)."
        return []

    def cincludes(self):
        "Return the list of extra includes for the source file (none by default)."
        return []

    def generate_code_dict(self):
        """Return a dict mapping template placeholder names to generated code.

        Every value except "classname" is indented one level so it can be
        inserted directly into a class body.
        """
        code_dict = {
            "classname"                  : self.itgrep.classname,
            "constructor"                : indent(self.gen_constructor()),
            "constructor_arguments"      : indent(self.gen_constructor_arguments()),
            "initializer_list"           : indent(self.gen_initializer_list()),
            "destructor"                 : indent(self.gen_destructor()),
            "members"                    : indent(self.gen_members()),
            "tabulate_tensor"            : indent(self.gen_tabulate_tensor()),
            "tabulate_tensor_quadrature" : indent(self.gen_tabulate_tensor_quadrature()),
            }
        return code_dict

    def generate_support_code(self):
        """Return code to place outside the class: quadrature rule definitions.

        Emits the cell quadrature rule and/or the facet quadrature rules when
        the form representation has them; returns "" otherwise.
        """
        ccode = ""

        # Write quadrature rule if we have one: # TODO: Place quadrature rules in a separate shared file?
        fr = self.itgrep.formrep
        if fr.quad_rule is not None:
            ccode += gen_quadrature_rule_definition(fr.quad_rule)

        if fr.facet_quad_rules is not None:
            ccode += gen_quadrature_rule_definitions(fr.facet_quad_rules)

        # Write sign correction code if needed: TODO: Update sign code
        #nsd = fr.cell.nsd
        #handled_fe_names = set()
        #fe_list = fr.fe_list
        #for fe in fe_list:
        #    if fe_is_signed(fe):
        #        fe_name = strings.finite_element_classname(fe)
        #        if not fe_name in handled_fe_names:
        #            handled_fe_names.add(fe_name)
        #            ccode += tabulate_sign_code(fe, nsd)

        return ccode

    def gen_constructor(self):
        "Generate the constructor body; must be implemented by subclasses."
        raise NotImplementedError

    def gen_constructor_arguments(self):
        "Generate the constructor argument list (empty by default)."
        return ""

    def gen_initializer_list(self):
        "Generate the constructor initializer list (empty by default)."
        return ""

    def gen_destructor(self):
        "Generate the destructor body; must be implemented by subclasses."
        raise NotImplementedError

    def gen_members(self):
        "Generate member declarations; must be implemented by subclasses."
        raise NotImplementedError

    def gen_pre_debug_code(self):
        """Generate debug code printing the coefficient arrays on entry.

        Returns "" unless ``options.enable_debug_code`` is set. Each
        coefficient array ``w[k]`` is printed with the local dimension of the
        element of coefficient k.
        """
        if not self.options.enable_debug_code:
            return ""
        fr = self.itgrep.formrep
        fd = fr.formdata

        code = CodeFormatter()
        code.begin_debug()
        code.begin_block()
        code += 'std::cout << std::endl;'
        code += 'std::cout << "SFC DEBUGGING OUTPUT: " << std::endl;'
        code += 'std::cout << "void %s::tabulate_tensor(...)" << std::endl;' % self.itgrep.classname
        code += 'std::cout << "{" << std::endl;'
        for k in range(fr.num_coefficients):
            # formdata.elements starts with the basis function elements
            # (see gen_tabulate_tensor), so coefficient k is at offset rank + k
            # (assumes coefficient elements follow them, as in UFL FormData).
            element = fd.elements[fr.rank + k]
            rep = fr.element_reps[element]
            code += 'std::cout << "    w[%d][:] = [ ";' % k
            code += "for(int j=0; j<%d; j++)" % rep.local_dimension
            code.begin_block()
            code += 'std::cout << w[%d][j] << ", ";' % k
            code.end_block()
            code += 'std::cout << " ]" << std::endl;'
        code += 'std::cout << "    now computing element tensor..." << std::endl;'
        code.end_block()
        code.end_debug()
        return str(code)

    def gen_post_debug_code(self):
        """Generate debug code printing every element tensor entry on exit.

        Returns "" unless ``options.enable_debug_code`` is set.
        """
        if not self.options.enable_debug_code:
            return ""
        code = CodeFormatter()
        code.begin_debug()
        code.begin_block()
        code += 'std::cout << "    ... done computing element tensor." << std::endl;'
        code += "for(int k=0; k<%d; k++)" % len(self.itgrep.indices)
        code.begin_block()
        # TODO: enhance this output by printing indices
        code += 'std::cout << "    A[" << k << "]" << " = " << A[k] << std::endl;'
        code.end_block()
        code += 'std::cout << "}" << std::endl;'
        code.end_block()
        code.end_debug()
        return str(code)

    def gen_partition_block(self, data, deps, basis_functions):
        """Generate code for a partition of the linearized computational graph.

        Args:
            data: integral data holding the partitions.
            deps: collection of dependency names identifying the partition.
            basis_functions: fixed basis function indices (may be empty).

        Returns:
            A braced, commented C++ block, or the (falsy) result of
            gen_token_assignments when the partition produced no code.
        """
        code = gen_token_assignments(self.itgrep.iter_partition(data, deps, basis_functions))
        if not code:
            return code
        # Sort so the generated comment does not depend on set iteration order.
        description = ", ".join(sorted(deps))
        if basis_functions:
            description += "; using basis functions %s" % str(basis_functions)
        return "// Partition depending on %s\n{\n%s\n}" % (description, indent(code))

    def gen_geometry_block(self):
        """Generate code computing the geometric quantities of the cell.

        Emits const definitions for the initial geometry tokens (vertex
        coordinates, affine map, normal vector, ...). For facet integrals it
        also declares the per-facet tokens and assigns them inside a
        ``switch(facet)`` that throws on an invalid facet number. When debug
        code is enabled, print statements for these tokens are added.
        """
        sfc_debug("Entering IntegralCG.gen_geometry_block")

        code = CodeFormatter()
        debug = self.options.enable_debug_code

        # --- Sign correction code: # TODO
        #sign_sym = itgrep.sign_sym
        #for fe_name in sign_sym.keys():
        #    nbf = len(sign_sym[fe_name])
        #    code += use_sign_code2(fe_name, nbf, nsd)

        # --- Initial geometry tokens
        geometry_tokens = self.itgrep.iter_geometry_tokens()

        # The tokens come as an iterator; duplicate it if debug printing
        # needs to consume it as well.
        if debug:
            geometry_tokens, debug_geometry_tokens = tee(geometry_tokens)

        code += "// Geometric quantities"
        code += gen_const_token_definitions(geometry_tokens)

        if debug:
            code += ""
            code.begin_debug()
            code += gen_token_prints(debug_geometry_tokens)
            code.end_debug()
            code += ""

        # --- Facet tokens code:
        if self.itgrep._on_facet:
            fr = self.itgrep.formrep
            code += "// Geometric quantities on each facet"
            for facet in range(fr.cell.num_facets):
                facet_tokens = self.itgrep.iter_facet_tokens(facet)

                if facet == 0:
                    if debug:
                        facet_tokens, debug_facet_tokens = tee(facet_tokens)

                    # Declare the tokens once before the switch so they stay
                    # in scope after it (assumes every facet yields the same
                    # token symbols -- TODO confirm).
                    facet_tokens, facet_tokens2 = tee(facet_tokens)
                    code += gen_token_declarations(facet_tokens2)
                    code.begin_switch("facet")

                code.begin_case(facet, braces=True)
                code += "// Geometric quantities on facet %d" % facet
                code += gen_token_assignments(facet_tokens)
                code.end_case()
            code += "default:"
            code.indent()
            code += 'throw std::runtime_error("Invalid facet number.");'
            code.dedent()
            code.end_switch()

            if debug:
                code += ""
                code.begin_debug()
                code += 'std::cout << "facet = " << facet << std::endl;'
                code += gen_token_prints(debug_facet_tokens)
                code.end_debug()
                code += ""

        sfc_debug("Leaving IntegralCG.gen_geometry_block")
        return str(code)

    def gen_quadrature_runtime_block(self, data):
        """Generate const definitions of runtime tokens inside the quadrature loop.

        Returns the (falsy) result of gen_const_token_definitions unchanged
        when there are no tokens.
        """
        code = gen_const_token_definitions(self.itgrep.iter_runtime_quad_tokens(data))
        if code:
            code = "// Geometric quantities and coefficients\n" + code
        return code

    def gen_A_assignment_block(self, data):
        """Assign values to the element tensor (pre-integrated tensor components).

        For facet integrals the assignments are generated per facet inside a
        ``switch(facet)``. That switch has no default case; invalid facet
        numbers are rejected by the switch in gen_geometry_block, which
        gen_tabulate_tensor emits earlier.
        """
        # Only one symbolic integral supported:
        assert data.integral is self.itgrep.symbolic_integral, \
            "Only the single symbolic integral is supported."

        if self.itgrep._on_facet:
            # TODO: Each facet block can be extracted as a separate function
            fr = self.itgrep.formrep
            code = CodeFormatter()
            code.begin_switch("facet")
            for facet in range(fr.cell.num_facets):
                A_tokens = self.itgrep.iter_A_tokens(data, facet)
                code.begin_case(facet, braces=True)
                code += "// Integrated element tensor entries"
                code += gen_token_assignments(A_tokens)
                code.end_case()
            code.end_switch()
            return str(code)

        A_tokens = self.itgrep.iter_A_tokens(data)
        code = gen_token_assignments(A_tokens)
        return "// Integrated element tensor entries\n{\n%s\n}" % indent(code)

    def gen_A_reset_block(self):
        "Generate code that zeroes all entries of the element tensor A."
        code  = "// Reset element tensor\n"
        code += "memset(A, 0, sizeof(double)*%d);" % len(self.itgrep.indices)
        return code

    def gen_quadrature_begin_block(self, data):
        """Generate the opening of the quadrature loop (closed by gen_quadrature_end_block).

        The loop body fetches the weight ``quad_weight`` and point ``_p`` of
        quadrature point ``iq`` and defines the coordinates x, y, z up to the
        spatial dimension of the cell.
        """
        assert data.integral in self.itgrep.quadrature_integrals, \
            "Integral is not a quadrature integral."

        if self.itgrep._on_facet:
            num_points = data.facet_quad_rules[0].num_points
            weight_expr = "facet_quad_weights[facet][iq]"
            point_expr = "facet_quad_points[facet][iq]"
        else:
            num_points = data.quad_rule.num_points
            weight_expr = "quad_weights[iq]"
            point_expr = "quad_points[iq]"

        code = CodeFormatter()
        code += "for(int iq=0; iq<%d; iq++)" % num_points
        code += "{"
        code.indent()
        code += "// Fetch quadrature rule"
        code += "const double quad_weight = %s;" % weight_expr
        code += "const double *_p = %s;" % point_expr

        nsd = self.itgrep.formrep.cell.nsd
        for axis, coordinate in enumerate("xyz"[:max(nsd, 0)]):
            code += "const double %s = _p[%d];" % (coordinate, axis)

        code.dedent()
        return str(code)

    def gen_A_accumulation_block(self, data):
        "Generate accumulation (A[...] += ...) of the nonzero element tensor entries."
        assert data.integral in self.itgrep.quadrature_integrals, \
            "Integral is not a quadrature integral."
        A_tokens = self.itgrep.iter_A_tokens(data)
        nonzero_tokens = ((A_sym, A_expr) for (A_sym, A_expr) in A_tokens if A_expr != 0)
        code = gen_token_additions(nonzero_tokens)
        assert code, "No nonzero element tensor entries to accumulate."
        return "// Accumulating element tensor entries\n{\n%s\n}" % indent(code)

    def gen_quadrature_end_block(self):
        "Generate the closing brace of the quadrature loop."
        return "}"

    def gen_symbol_allocation_block(self):
        "Generate the declaration of the intermediate symbol array s, or '' if unused."
        if self.itgrep._symbol_counter > 0:
            return "double s[%d];" % self.itgrep._symbol_counter
        return ""

    def _basis_function_dependencies(self):
        """Return the ways an expression may depend on the basis functions.

        Each entry is a pair ``(iota, keep)``:
          - ``iota``: per basis function, the local index of a fixed basis
            function, or None if the expression does not depend on it;
          - ``keep``: per basis function, whether the expression depends on it
            (None for the entry that depends on no basis function).

        Entries depending on fewer basis functions come first. Only forms of
        rank 0, 1 and 2 are supported; higher ranks are reported via sfc_error.
        """
        fr = self.itgrep.formrep
        fd = fr.formdata
        r = fr.rank

        bf_dep_list = [((), None)]
        if r == 1:
            m = fr.element_reps[fd.elements[0]].local_dimension
            bf_dep_list += [((i,), (True,)) for i in range(m)]
        elif r == 2:
            m = fr.element_reps[fd.elements[0]].local_dimension
            n = fr.element_reps[fd.elements[1]].local_dimension
            # Depending only on basis function 1, whose element has n dofs
            bf_dep_list += [((None, j), (False, True)) for j in range(n)]
            bf_dep_list += [((i, None), (True, False)) for i in range(m)]
            bf_dep_list += [((i, j), (True, True)) for i in range(m) for j in range(n)]
        elif r > 2:
            sfc_error("Support for higher order tensors not implemented.")
        return bf_dep_list

    def gen_tabulate_tensor(self):
        """Generate the tabulate_tensor body composed of reusable code blocks.

        Block order: symbol allocation, pre-debug output, geometry, symbolic
        element tensor assignment (or reset), then for each quadrature
        integral its coordinate-independent partitions followed by the
        quadrature loop, and finally post-debug output. Empty blocks are
        omitted.
        """
        # TODO: Improve code generation here for "robustness", several possibilities:
        # - Use a loop over both trial and test functions.
        # - Use a loop over trial functions, keep inlined application of test functions.
        # - Apply "outlining" to reduce function size, that is, generate helper functions
        #   instead of inlining all expressions explicitly.
        #   For example one for each integral and one for each block.

        sfc_debug("Entering IntegralCG.tabulate_tensor")

        def generate_partition_blocks(data, bf_dep_list, known):
            "Generate code for every existing partition depending on known plus basis functions."
            known = frozenset(known)

            # Select the partitions that exist before generating any code
            todo = []
            for iota, keep in bf_dep_list:
                deps = known
                if keep is not None:
                    deps = deps | frozenset("v%d" % j for (j, k) in enumerate(keep) if k)
                if deps in data.partitions:
                    todo.append((deps, iota))

            # Generate code for each selected partition, skipping empty ones
            pblocks = []
            for deps, iota in todo:
                code = self.gen_partition_block(data, deps, iota)
                if code:
                    pblocks.append(code)
            return pblocks

        blocks = []

        # Print debugging info about input arguments in code
        blocks.append(self.gen_pre_debug_code())

        # Geometry variables like G, Ginv, detG, n
        blocks.append(self.gen_geometry_block())

        # Symbolic integration part or element tensor reset
        integral = self.itgrep.symbolic_integral
        if integral is not None:
            data = self.itgrep.integral_data[integral.measure()]
            # TODO: Do we need any intermediate variable blocks?
            #       Maybe if we apply some optimizations.
            #       We could f.ex. compute the integrands to be
            #       symbolically integrated from the graph,
            #       with the same tokens precomputed as with quadrature.
            blocks.append(self.gen_A_assignment_block(data))
        else:
            blocks.append(self.gen_A_reset_block())

        # List of ways that expressions may depend on basis functions
        bf_dep_list = self._basis_function_dependencies()

        # Code for each quadrature integral
        for integral in self.itgrep.quadrature_integrals:
            # Data about this particular integral, including integration mode and computational graph
            data = self.itgrep.integral_data[integral.measure()]

            # FIXME: These should be precomputed, no dependencies
            blocks += generate_partition_blocks(data, bf_dep_list, ())

            # Compute blocks of intermediate variables, s[...] = ...; that are independent of coordinates
            blocks += generate_partition_blocks(data, bf_dep_list, ("c",))

            # Begin quadrature loop
            blocks.append(self.gen_quadrature_begin_block(data))

            # Quadrature loop body; each block is indented once below
            qblocks = [self.gen_quadrature_runtime_block(data)]

            # FIXME: These should be precomputed for each quadrature point, no runtime dependencies
            qblocks += generate_partition_blocks(data, bf_dep_list, ("x",))

            # Compute blocks of intermediate variables, s[...] = ...;
            qblocks += generate_partition_blocks(data, bf_dep_list, ("c", "x"))

            # Accumulate element tensor values, A[...] += ...;
            qblocks.append(self.gen_A_accumulation_block(data))

            blocks += [indent(code) for code in qblocks]

            # End quadrature loop
            blocks.append(self.gen_quadrature_end_block())

        # Print debugging info about output values in code
        blocks.append(self.gen_post_debug_code())

        # Insert declaration for allocation of symbols at beginning
        code = self.gen_symbol_allocation_block()
        if code:
            blocks.insert(0, code)

        # Compose the final code, dropping empty blocks (e.g. disabled debug code)
        final_code = "\n\n".join(block for block in blocks if block and block.strip())

        sfc_debug("Leaving IntegralCG.tabulate_tensor")
        return final_code

    def gen_tabulate_tensor_quadrature(self):
        "Generate tabulate_tensor_quadrature; currently throws at runtime."
        return 'throw std::runtime_error("Not implemented.");' # FIXME
