# Copyright (c) 2021 ARM Limited
# All rights reserved.
#
# The license below extends only to copyright in the software and shall
# not be construed as granting a license to any other intellectual
# property including but not limited to intellectual property relating
# to a hardware implementation of the functionality of the software
# licensed hereunder.  You may use the software subject to the license
# terms below provided that you ensure that this notice is replicated
# unmodified and in its entirety in all distributions of the software,
# modified or unmodified, in source code or in binary form.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met: redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer;
# redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution;
# neither the name of the copyright holders nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#

import math

from m5.util import fatal
from m5.params import *
from m5.objects import *

from m5.defines import buildEnv
if buildEnv['PROTOCOL'] == 'CHI':
    import ruby.CHI as CHI

from topologies.BaseTopology import SimpleTopology

class CustomMesh(SimpleTopology):
    description = 'CustomMesh'

    def __init__(self, controllers):
        """Store the controllers that this topology will work on.

        controllers: kept as-is in self.nodes; makeTopology iterates over
            self.nodes to classify the nodes by type.
        """
        self.nodes = controllers

    #--------------------------------------------------------------------------
    # _makeMesh
    #--------------------------------------------------------------------------

    def _makeMesh(self, IntLink, link_latency, num_rows, num_columns,
                  cross_links, cross_link_latency):

        # East->West, West->East, North->South, South->North
        # XY routing weights
        link_weights = [1, 1, 2, 2]

        # East output to West input links
        for row in range(num_rows):
            for col in range(num_columns):
                if (col + 1 < num_columns):
                    east_out = col + (row * num_columns)
                    west_in = (col + 1) + (row * num_columns)
                    llat = cross_link_latency \
                                if (east_out, west_in) in cross_links \
                                else link_latency
                    self._int_links.append(\
                                IntLink(link_id=self._link_count,
                                        src_node=self._routers[east_out],
                                        dst_node=self._routers[west_in],
                                        dst_inport="West",
                                        latency = llat,
                                        weight=link_weights[0]))
                    self._link_count += 1

        # West output to East input links
        for row in range(num_rows):
            for col in range(num_columns):
                if (col + 1 < num_columns):
                    east_in = col + (row * num_columns)
                    west_out = (col + 1) + (row * num_columns)
                    llat = cross_link_latency \
                                if (west_out, east_in) in cross_links \
                                else link_latency
                    self._int_links.append(\
                                IntLink(link_id=self._link_count,
                                        src_node=self._routers[west_out],
                                        dst_node=self._routers[east_in],
                                        dst_inport="East",
                                        latency = llat,
                                        weight=link_weights[1]))
                    self._link_count += 1

        # North output to South input links
        for col in range(num_columns):
            for row in range(num_rows):
                if (row + 1 < num_rows):
                    north_out = col + (row * num_columns)
                    south_in = col + ((row + 1) * num_columns)
                    llat = cross_link_latency \
                            if (north_out, south_in) in cross_links \
                            else link_latency
                    self._int_links.append(\
                                IntLink(link_id=self._link_count,
                                        src_node=self._routers[north_out],
                                        dst_node=self._routers[south_in],
                                        dst_inport="South",
                                        latency = llat,
                                        weight=link_weights[2]))
                    self._link_count += 1

        # South output to North input links
        for col in range(num_columns):
            for row in range(num_rows):
                if (row + 1 < num_rows):
                    north_in = col + (row * num_columns)
                    south_out = col + ((row + 1) * num_columns)
                    llat = cross_link_latency \
                            if (south_out, north_in) in cross_links \
                            else link_latency
                    self._int_links.append(\
                                IntLink(link_id=self._link_count,
                                        src_node=self._routers[south_out],
                                        dst_node=self._routers[north_in],
                                        dst_inport="North",
                                        latency = llat,
                                        weight=link_weights[3]))
                    self._link_count += 1

    #--------------------------------------------------------------------------
    # _createRNFRouter / distributeNodes
    #--------------------------------------------------------------------------

    def _createRNFRouter(self, mesh_router):
        # Create a zero-latency router bridging node controllers
        # and the mesh router
        node_router = self._Router(router_id = len(self._routers),
                                    latency = 0)
        self._routers.append(node_router)

        # connect node_router <-> mesh router
        self._int_links.append(self._IntLink( \
                                    link_id = self._link_count,
                                    src_node = node_router,
                                    dst_node = mesh_router,
                            latency = self._router_link_latency))
        self._link_count += 1

        self._int_links.append(self._IntLink( \
                                    link_id = self._link_count,
                                    src_node = mesh_router,
                                    dst_node = node_router,
                            latency = self._router_link_latency))
        self._link_count += 1

        return node_router

    def distributeNodes(self, num_nodes_per_router, router_idx_list,
                        node_list):

        if num_nodes_per_router:
            # evenly distribute nodes to all listed routers
            assert(len(router_idx_list)*num_nodes_per_router == len(node_list))

            for idx, node in enumerate(node_list):
                mesh_router_idx = router_idx_list[idx // num_nodes_per_router]
                router = self._routers[mesh_router_idx]

                # Create another router bridging RNF node controllers
                # and the mesh router
                # for non-RNF nodes, node router is mesh router
                if isinstance(node, CHI.CHI_RNF):
                    router = self._createRNFRouter(router)

                # connect all ctrls in the node to node_router
                ctrls = node.getNetworkSideControllers()
                for c in ctrls:
                    self._ext_links.append(self._ExtLink(
                                    link_id = self._link_count,
                                    ext_node = c,
                                    int_node = router,
                                    latency = self._node_link_latency))
                    self._link_count += 1
        else:
            # try to circulate all nodes to all routers, some routers may be
            # connected to zero or more than one node.
            idx = 0
            for node in node_list:
                ridx = router_idx_list[idx]
                router = self._routers[ridx]

                if isinstance(node, CHI.CHI_RNF):
                    router = self._createRNFRouter(router)
                ctrls = node.getNetworkSideControllers()
                for c in ctrls:
                    self._ext_links.append(self._ExtLink( \
                                                 link_id = self._link_count,
                                                 ext_node = c,
                                                 int_node = router,
                                            latency = self._node_link_latency))
                    self._link_count += 1
                idx = (idx + 1) % len(router_idx_list)

    #--------------------------------------------------------------------------
    # makeTopology
    #--------------------------------------------------------------------------

    def makeTopology(self, options, network, IntLink, ExtLink, Router):
        """Build the mesh, attach every node to it and hand it to network.

        Creates num_rows * num_cols mesh routers and their internal
        links, distributes the nodes in self.nodes over the routers
        listed in the per-node-type options, and assigns the resulting
        internal links, external links and routers to network. When
        options has a 'pairing' attribute that is not None, the HNF and
        main memory controllers are paired accordingly.

        options: must provide num_rows, num_cols, router_latency,
            cross_links, cross_link_latency, either router_link_latency
            together with node_link_latency or link_latency, and one
            entry per node type (CHI_RNF, CHI_HNF, CHI_SNF_MainMem,
            CHI_SNF_IO, CHI_RNI_IO) holding 'router_list' and,
            optionally, 'num_nodes_per_router'.
        network: object receiving int_links, ext_links and routers.
        IntLink, ExtLink, Router: callables used to build the
            corresponding network components.

        Calls fatal() when self.nodes contains an unsupported node type.
        """
        assert(buildEnv['PROTOCOL'] == 'CHI')

        num_rows = options.num_rows
        num_cols = options.num_cols
        num_mesh_routers = num_rows * num_cols

        self._IntLink = IntLink
        self._ExtLink = ExtLink
        self._Router = Router

        if hasattr(options, 'router_link_latency'):
            self._router_link_latency = options.router_link_latency
            self._node_link_latency = options.node_link_latency
        else:
            print("WARNING: router/node link latencies not provided")
            self._router_link_latency = options.link_latency
            self._node_link_latency = options.link_latency

        # classify nodes into different types
        rnf_list = []
        hnf_list = []
        mem_ctrls = []
        io_mem_ctrls = []
        io_rni_ctrls = []

        for n in self.nodes:
            if isinstance(n, CHI.CHI_RNF):
                rnf_list.append(n)
            elif isinstance(n, CHI.CHI_HNF):
                hnf_list.append(n)
            elif isinstance(n, CHI.CHI_SNF_MainMem):
                mem_ctrls.append(n)
            elif isinstance(n, CHI.CHI_SNF_BootMem):
                io_mem_ctrls.append(n)
            elif isinstance(n, (CHI.CHI_RNI_DMA, CHI.CHI_RNI_IO)):
                # DMA and IO request nodes are placed together
                io_rni_ctrls.append(n)
            else:
                fatal('topologies.CustomMesh: {} not supported'
                            .format(n.__class__.__name__))

        # Create all mesh routers
        self._routers = [Router(router_id=i, latency=options.router_latency)
                         for i in range(num_mesh_routers)]

        self._link_count = 0
        self._int_links = []
        self._ext_links = []

        # Create all the mesh internal links.
        self._makeMesh(IntLink, self._router_link_latency, num_rows, num_cols,
                       options.cross_links, options.cross_link_latency)

        # Place the nodes on the mesh, one node type after the other.
        # The order matters: link ids and the ids of the routers created
        # for RNF nodes are handed out sequentially.
        placement = (
            ('CHI_RNF', rnf_list),
            ('CHI_HNF', hnf_list),
            ('CHI_SNF_MainMem', mem_ctrls),
            # all IO mem nodes
            ('CHI_SNF_IO', io_mem_ctrls),
            # all IO request nodes
            ('CHI_RNI_IO', io_rni_ctrls),
        )
        for params_name, node_list in placement:
            params = getattr(options, params_name)
            num_nodes_per_router = params['num_nodes_per_router'] \
                    if 'num_nodes_per_router' in params else None
            self.distributeNodes(num_nodes_per_router,
                                 params['router_list'],
                                 node_list)

        # Set up
        network.int_links = self._int_links
        network.ext_links = self._ext_links
        network.routers = self._routers

        # identity test: None is the only value meaning "no pairing"
        pairing = getattr(options, 'pairing', None)
        if pairing is not None:
            self._autoPairHNFandSNF(hnf_list, mem_ctrls, pairing)

    #--------------------------------------------------------------------------
    # _autoPairHNFandSNF
    #--------------------------------------------------------------------------
    def _autoPairHNFandSNF(self, cache_ctrls, mem_ctrls, pairing):
        """Pair cache controllers with memory controllers and re-interleave.

        cache_ctrls, mem_ctrls: nodes whose getNetworkSideControllers()
            provide the cache / memory controllers to pair.
        pairing: indexable with one entry per cache controller; each
            entry is an index, or a list of indices, into the flattened
            list of memory controllers.

        Every controller gets a _pairing list holding its partners. The
        intlvMatch and masks of the controllers' addr_ranges are then
        rewritten in place; when there are at least as many memory
        controllers as cache controllers, the range of the object at the
        other end of each memory controller's memory_out_port is updated
        as well. Calls fatal() if the memory ranges do not all start at
        the same address or if a pair has no range in common.
        """
        # Use the pairing defined by the configuration to reassign the
        # memory ranges
        pair_debug = False

        print("Pairing HNFs to SNFs")
        print(pairing)

        # flatten the nodes into plain lists of controllers
        all_cache = []
        for c in cache_ctrls: all_cache.extend(c.getNetworkSideControllers())
        all_mem = []
        for c in mem_ctrls: all_mem.extend(c.getNetworkSideControllers())

        # checks and maps index from pairing map to component
        assert(len(pairing) == len(all_cache))

        # a pairing entry may be a single index or a list of indices
        def _tolist(val): return val if isinstance(val, list) else [val]

        for m in all_mem: m._pairing = []

        # number of memory controllers each cache controller must pair with
        # NOTE(review): under Python 3 '/' is true division, so this is a
        # float when there are more memory than cache controllers and the
        # equality asserts below can only hold for a whole-number ratio;
        # it also divides by zero if all_cache is empty - confirm intent.
        pairing_check = max(1, len(all_mem) / len(all_cache))
        for cidx,c in enumerate(all_cache):
            c._pairing = []
            for midx in _tolist(pairing[cidx]):
                c._pairing.append(all_mem[midx])
                # record the reverse direction only once per cache
                if c not in all_mem[midx]._pairing:
                    all_mem[midx]._pairing.append(c)
            assert(len(c._pairing) == pairing_check)
            if pair_debug:
                # dump the ranges before they are reassigned
                print(c.path())
                for r in c.addr_ranges:
                    print("%s" % r)
                for p in c._pairing:
                    print("\t"+p.path())
                    for r in p.addr_ranges:
                        print("\t%s" % r)

        # all must be paired
        for c in all_cache: assert(len(c._pairing) > 0)
        for m in all_mem: assert(len(m._pairing) > 0)

        # only support a single range for the main memory controllers
        # (every range of every memory controller must start where the
        # first range of the first memory controller starts)
        tgt_range_start = all_mem[0].addr_ranges[0].start.value
        for mem in all_mem:
            for r in mem.addr_ranges:
                if r.start.value != tgt_range_start:
                    fatal('topologies.CustomMesh: not supporting pairing of '\
                          'main memory with multiple ranges')

        # reassign ranges for a 1 -> N pairing
        # src_cntrls: the smaller side, each paired with one or more
        #             controllers of the other side through _pairing
        # tgt_cntrls: the larger side (only used for the size check)
        # fix_tgt_peer: also update the range of the object connected to
        #               each target's memory_out_port
        def _rerange(src_cntrls, tgt_cntrls, fix_tgt_peer):
            assert(len(tgt_cntrls) >= len(src_cntrls))

            # returns the intlvMatch shared by all the given ranges
            # (None if there are no ranges)
            # NOTE(review): 'bit == None' would read better as
            # 'bit is None'; left untouched here.
            def _rangeToBit(addr_ranges):
                bit = None
                for r in addr_ranges:
                    if bit == None:
                        bit = r.intlvMatch
                    else:
                        assert(bit == r.intlvMatch)
                return bit

            # returns the object at the other end of memory_out_port
            def _getPeer(cntrl):
                return cntrl.memory_out_port.peer.simobj

            # process the sources in the order of their current
            # interleaving match value
            sorted_src = list(src_cntrls)
            sorted_src.sort(key = lambda x: _rangeToBit(x.addr_ranges))

            # paired controllers need to have seq. interleaving match values
            intlvMatch = 0
            for src in sorted_src:
                for tgt in src._pairing:
                    for r in tgt.addr_ranges:
                        r.intlvMatch = intlvMatch
                    if fix_tgt_peer:
                        _getPeer(tgt).range.intlvMatch = intlvMatch
                    intlvMatch = intlvMatch + 1

            # recreate masks
            for src in sorted_src:
                for src_range in src.addr_ranges:
                    # only the main memory range is reassigned
                    if src_range.start.value != tgt_range_start:
                        continue
                    # widen each source mask by OR-ing it with itself
                    # shifted left by one to four bit positions
                    new_src_mask = []
                    for m in src_range.masks:
                        # TODO should mask all the way to the max range size
                        new_src_mask.append(m | (m*2) | (m*4) |
                                                  (m*8) | (m*16))
                    for tgt in src._pairing:
                        paired = False
                        for tgt_range in tgt.addr_ranges:
                            if tgt_range.start.value == \
                               src_range.start.value:
                                src_range.masks = new_src_mask
                                # the target keeps its leading masks that
                                # exceed the source mask count and gets
                                # the new source masks appended
                                new_tgt_mask = []
                                lsbs = len(tgt_range.masks) - \
                                       len(new_src_mask)
                                for i in range(lsbs):
                                    new_tgt_mask.append(tgt_range.masks[i])
                                for m in new_src_mask:
                                    new_tgt_mask.append(m)
                                tgt_range.masks = new_tgt_mask
                                if fix_tgt_peer:
                                    _getPeer(tgt).range.masks = new_tgt_mask
                                paired = True
                        # no target range starts at the source range start
                        if not paired:
                            fatal('topologies.CustomMesh: could not ' \
                                    'reassign ranges {} {}'.format(
                                    src.path(), tgt.path()))
        # the side with fewer controllers drives the reassignment; the
        # peers are only fixed when the memory controllers are the targets
        if len(all_mem) >= len(all_cache):
            _rerange(all_cache, all_mem, True)
        else:
            _rerange(all_mem, all_cache, False)

        if pair_debug:
            # dump the ranges again, after the reassignment
            print("")
            for cidx,c in enumerate(all_cache):
                assert(len(c._pairing) == pairing_check)
                print(c.path())
                for r in c.addr_ranges:
                    print("%s" % r)
                for p in c._pairing:
                    print("\t"+p.path())
                    for r in p.addr_ranges:
                        print("\t%s" % r)


