arch-arm: Allow TLB to be used as a WalkCache

This patch allows partial translation entries (intermediate PAs obtained
from a table walk) to be stored in an ArmTLB. This effectively means
reserving a fraction of the TLB entries to cache table walks

JIRA: https://gem5.atlassian.net/browse/GEM5-1108

Change-Id: Id0efb7d75dd017366c4c3b74de7b57355a53a01a
Signed-off-by: Giacomo Travaglini <giacomo.travaglini@arm.com>
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/52124
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
Maintainer: Andreas Sandberg <andreas.sandberg@arm.com>
Tested-by: kokoro <noreply+kokoro@google.com>
diff --git a/src/arch/arm/ArmTLB.py b/src/arch/arm/ArmTLB.py
index 4c86f72..10ed48b 100644
--- a/src/arch/arm/ArmTLB.py
+++ b/src/arch/arm/ArmTLB.py
@@ -51,6 +51,11 @@
     size = Param.Int(64, "TLB size")
     is_stage2 = Param.Bool(False, "Is this a stage 2 TLB?")
 
+    partial_levels = VectorParam.ArmLookupLevel([],
+        "List of intermediate lookup levels allowed to be cached in the TLB "
+        "(=holding intermediate PAs obtained during a table walk")
+
+
 class ArmStage2TLB(ArmTLB):
     size = 32
     is_stage2 = True
diff --git a/src/arch/arm/mmu.cc b/src/arch/arm/mmu.cc
index cffbb20..f4a2114 100644
--- a/src/arch/arm/mmu.cc
+++ b/src/arch/arm/mmu.cc
@@ -67,6 +67,7 @@
     s1State(this, false), s2State(this, true),
     _attr(0),
     _release(nullptr),
+    _hasWalkCache(false),
     stats(this)
 {
     // Cache system-level properties
@@ -102,6 +103,27 @@
     getDTBPtr()->setTableWalker(dtbWalker);
 
     BaseMMU::init();
+
+    _hasWalkCache = checkWalkCache();
+}
+
+bool
+MMU::checkWalkCache() const
+{
+    for (auto tlb : instruction) {
+        if (static_cast<TLB*>(tlb)->walkCache())
+            return true;
+    }
+    for (auto tlb : data) {
+        if (static_cast<TLB*>(tlb)->walkCache())
+            return true;
+    }
+    for (auto tlb : unified) {
+        if (static_cast<TLB*>(tlb)->walkCache())
+            return true;
+    }
+
+    return false;
 }
 
 void
@@ -858,11 +880,11 @@
     Fault fault = getResultTe(&te, req, tc, mode, translation, timing,
                               functional, &mergeTe, state);
     // only proceed if we have a valid table entry
-    if ((te == NULL) && (fault == NoFault)) delay = true;
+    if (!isCompleteTranslation(te) && (fault == NoFault)) delay = true;
 
     // If we have the table entry transfer some of the attributes to the
     // request that triggered the translation
-    if (te != NULL) {
+    if (isCompleteTranslation(te)) {
         // Set memory attributes
         DPRINTF(TLBVerbose,
                 "Setting memory attributes: shareable: %d, innerAttrs: %d, "
@@ -1429,7 +1451,7 @@
     *te = lookup(vaddr, state.asid, state.vmid, state.isHyp, is_secure, false,
                  false, target_el, false, state.isStage2, mode);
 
-    if (*te == NULL) {
+    if (!isCompleteTranslation(*te)) {
         if (req->isPrefetch()) {
             // if the request is a prefetch don't attempt to fill the TLB or go
             // any further with the memory access (here we can safely use the
@@ -1474,12 +1496,12 @@
     if (state.isStage2) {
         // We are already in the stage 2 TLB. Grab the table entry for stage
         // 2 only. We are here because stage 1 translation is disabled.
-        TlbEntry *s2_te = NULL;
+        TlbEntry *s2_te = nullptr;
         // Get the stage 2 table entry
         fault = getTE(&s2_te, req, tc, mode, translation, timing, functional,
                       state.isSecure, state.curTranType, state);
         // Check permissions of stage 2
-        if ((s2_te != NULL) && (fault == NoFault)) {
+        if (isCompleteTranslation(s2_te) && (fault == NoFault)) {
             if (state.aarch64)
                 fault = checkPermissions64(s2_te, req, mode, tc, state);
             else
@@ -1489,22 +1511,22 @@
         return fault;
     }
 
-    TlbEntry *s1Te = NULL;
+    TlbEntry *s1_te = nullptr;
 
     Addr vaddr_tainted = req->getVaddr();
 
     // Get the stage 1 table entry
-    fault = getTE(&s1Te, req, tc, mode, translation, timing, functional,
+    fault = getTE(&s1_te, req, tc, mode, translation, timing, functional,
                   state.isSecure, state.curTranType, state);
     // only proceed if we have a valid table entry
-    if ((s1Te != NULL) && (fault == NoFault)) {
+    if (isCompleteTranslation(s1_te) && (fault == NoFault)) {
         // Check stage 1 permissions before checking stage 2
         if (state.aarch64)
-            fault = checkPermissions64(s1Te, req, mode, tc, state);
+            fault = checkPermissions64(s1_te, req, mode, tc, state);
         else
-            fault = checkPermissions(s1Te, req, mode, state);
+            fault = checkPermissions(s1_te, req, mode, state);
         if (state.stage2Req & (fault == NoFault)) {
-            Stage2LookUp *s2_lookup = new Stage2LookUp(this, *s1Te,
+            Stage2LookUp *s2_lookup = new Stage2LookUp(this, *s1_te,
                 req, translation, mode, timing, functional, state.isSecure,
                 state.curTranType);
             fault = s2_lookup->getTe(tc, mergeTe);
@@ -1531,12 +1553,18 @@
                     arm_fault->annotate(ArmFault::OVA, vaddr_tainted);
                 }
             }
-            *te = s1Te;
+            *te = s1_te;
         }
     }
     return fault;
 }
 
+bool
+MMU::isCompleteTranslation(TlbEntry *entry) const
+{
+    return entry && !entry->partial;
+}
+
 void
 MMU::takeOverFrom(BaseMMU *old_mmu)
 {
diff --git a/src/arch/arm/mmu.hh b/src/arch/arm/mmu.hh
index 2f353c7..fc5b307 100644
--- a/src/arch/arm/mmu.hh
+++ b/src/arch/arm/mmu.hh
@@ -344,6 +344,8 @@
 
     const ArmRelease* release() const { return _release; }
 
+    bool hasWalkCache() const { return _hasWalkCache; }
+
     /**
      * Determine the EL to use for the purpose of a translation given
      * a specific translation type. If the translation type doesn't
@@ -425,6 +427,15 @@
                    LookupLevel lookup_level, CachedState &state);
 
   protected:
+    bool checkWalkCache() const;
+
+    bool isCompleteTranslation(TlbEntry *te) const;
+
+    CachedState& updateMiscReg(
+        ThreadContext *tc, ArmTranslationType tran_type,
+        bool stage2);
+
+  protected:
     ContextID miscRegContext;
 
   public:
@@ -440,9 +451,7 @@
 
     AddrRange m5opRange;
 
-    CachedState& updateMiscReg(
-        ThreadContext *tc, ArmTranslationType tran_type,
-        bool stage2);
+    bool _hasWalkCache;
 
     struct Stats : public statistics::Group
     {
diff --git a/src/arch/arm/tlb.cc b/src/arch/arm/tlb.cc
index e2897f8..a7c3f12 100644
--- a/src/arch/arm/tlb.cc
+++ b/src/arch/arm/tlb.cc
@@ -61,9 +61,33 @@
 TLB::TLB(const ArmTLBParams &p)
     : BaseTLB(p), table(new TlbEntry[p.size]), size(p.size),
       isStage2(p.is_stage2),
+      _walkCache(false),
       tableWalker(nullptr),
       stats(*this), rangeMRU(1), vmid(0)
 {
+    for (int lvl = LookupLevel::L0;
+         lvl < LookupLevel::Num_ArmLookupLevel; lvl++) {
+
+        auto it = std::find(
+            p.partial_levels.begin(),
+            p.partial_levels.end(),
+            lvl);
+
+        auto lookup_lvl = static_cast<LookupLevel>(lvl);
+
+        if (it != p.partial_levels.end()) {
+            // A partial entry from of the current LookupLevel can be
+            // cached within the TLB
+            partialLevels[lookup_lvl] = true;
+
+            // Make sure this is not the last level (complete translation)
+            if (lvl != LookupLevel::Num_ArmLookupLevel - 1) {
+                _walkCache = true;
+            }
+        } else {
+            partialLevels[lookup_lvl] = false;
+        }
+    }
 }
 
 TLB::~TLB()
@@ -79,32 +103,63 @@
 }
 
 TlbEntry*
-TLB::lookup(const Lookup &lookup_data)
+TLB::match(const Lookup &lookup_data)
 {
-    TlbEntry *retval = NULL;
-    const auto functional = lookup_data.functional;
-    const auto mode = lookup_data.mode;
+    // Vector of TLB entry candidates.
+    // Only one of them will be assigned to retval and will
+    // be returned to the MMU (in case of a hit)
+    // The vector has one entry per lookup level as it stores
+    // both complete and partial matches
+    std::vector<std::pair<int, const TlbEntry*>> hits{
+        LookupLevel::Num_ArmLookupLevel, {0, nullptr}};
 
-    // Maintaining LRU array
     int x = 0;
-    while (retval == NULL && x < size) {
+    while (x < size) {
         if (table[x].match(lookup_data)) {
-            // We only move the hit entry ahead when the position is higher
-            // than rangeMRU
-            if (x > rangeMRU && !functional) {
-                TlbEntry tmp_entry = table[x];
-                for (int i = x; i > 0; i--)
-                    table[i] = table[i - 1];
-                table[0] = tmp_entry;
-                retval = &table[0];
-            } else {
-                retval = &table[x];
-            }
-            break;
+            const TlbEntry &entry = table[x];
+            hits[entry.lookupLevel] = std::make_pair(x, &entry);
+
+            // This is a complete translation, no need to loop further
+            if (!entry.partial)
+                break;
         }
         ++x;
     }
 
+    // Loop over the list of TLB entries matching our translation
+    // request, starting from the highest lookup level (complete
+    // translation) and iterating backwards (using reverse iterators)
+    for (auto it = hits.rbegin(); it != hits.rend(); it++) {
+        const auto& [idx, entry] = *it;
+        if (!entry) {
+            // No match for the current LookupLevel
+            continue;
+        }
+
+        // Maintaining LRU array
+        // We only move the hit entry ahead when the position is higher
+        // than rangeMRU
+        if (idx > rangeMRU && !lookup_data.functional) {
+            TlbEntry tmp_entry = *entry;
+            for (int i = idx; i > 0; i--)
+                table[i] = table[i - 1];
+            table[0] = tmp_entry;
+            return &table[0];
+        } else {
+            return &table[idx];
+        }
+    }
+
+    return nullptr;
+}
+
+TlbEntry*
+TLB::lookup(const Lookup &lookup_data)
+{
+    const auto mode = lookup_data.mode;
+
+    TlbEntry *retval = match(lookup_data);
+
     DPRINTF(TLBVerbose, "Lookup %#x, asn %#x -> %s vmn 0x%x hyp %d secure %d "
             "ppn %#x size: %#x pa: %#x ap:%d ns:%d nstid:%d g:%d asid: %d "
             "el: %d\n",
@@ -118,21 +173,27 @@
             retval ? retval->el        : 0);
 
     // Updating stats if this was not a functional lookup
-    if (!functional) {
+    if (!lookup_data.functional) {
         if (!retval) {
-            if (mode == BaseMMU::Execute)
+            if (mode == BaseMMU::Execute) {
                 stats.instMisses++;
-            else if (mode == BaseMMU::Write)
+            } else if (mode == BaseMMU::Write) {
                 stats.writeMisses++;
-            else
+            } else {
                 stats.readMisses++;
+            }
         } else {
-            if (mode == BaseMMU::Execute)
+            if (retval->partial) {
+                stats.partialHits++;
+            }
+
+            if (mode == BaseMMU::Execute) {
                 stats.instHits++;
-            else if (mode == BaseMMU::Write)
+            } else if (mode == BaseMMU::Write) {
                stats.writeHits++;
-            else
+            } else {
                 stats.readHits++;
+            }
         }
     }
 
@@ -149,8 +210,13 @@
     } else {
         if (auto tlb = static_cast<TLB*>(nextLevel())) {
             te = tlb->multiLookup(lookup_data);
-            if (te && !lookup_data.functional)
+            if (te && !lookup_data.functional &&
+                (!te->partial || partialLevels[te->lookupLevel])) {
+                // Insert entry only if this is not a functional
+                // lookup and if the translation is complete (unless this
+                // TLB caches partial translations)
                 insert(*te);
+            }
         }
     }
 
@@ -203,7 +269,11 @@
 void
 TLB::multiInsert(TlbEntry &entry)
 {
-    insert(entry);
+    // Insert a partial translation only if the TLB is configured
+    // as a walk cache
+    if (!entry.partial || partialLevels[entry.lookupLevel]) {
+        insert(entry);
+    }
 
     if (auto next_level = static_cast<TLB*>(nextLevel())) {
         next_level->multiInsert(entry);
@@ -233,9 +303,11 @@
     while (x < size) {
         te = &table[x];
 
-        DPRINTF(TLB, " -  %s\n", te->print());
-        te->valid = false;
-        stats.flushedEntries++;
+        if (te->valid) {
+            DPRINTF(TLB, " -  %s\n", te->print());
+            te->valid = false;
+            stats.flushedEntries++;
+        }
         ++x;
     }
 
@@ -559,6 +631,8 @@
 
 TLB::TlbStats::TlbStats(TLB &parent)
   : statistics::Group(&parent), tlb(parent),
+    ADD_STAT(partialHits, statistics::units::Count::get(),
+             "partial translation hits"),
     ADD_STAT(instHits, statistics::units::Count::get(), "Inst hits"),
     ADD_STAT(instMisses, statistics::units::Count::get(), "Inst misses"),
     ADD_STAT(readHits, statistics::units::Count::get(), "Read hits"),
@@ -615,6 +689,8 @@
         readAccesses.flags(statistics::nozero);
         writeAccesses.flags(statistics::nozero);
     }
+
+    partialHits.flags(statistics::nozero);
 }
 
 void
diff --git a/src/arch/arm/tlb.hh b/src/arch/arm/tlb.hh
index 8eb7d01..40ad76c 100644
--- a/src/arch/arm/tlb.hh
+++ b/src/arch/arm/tlb.hh
@@ -115,9 +115,25 @@
 class TLB : public BaseTLB
 {
   protected:
-    TlbEntry* table;     // the Page Table
-    int size;            // TLB Size
-    bool isStage2;       // Indicates this TLB is part of the second stage MMU
+    TlbEntry* table;
+
+    /** TLB Size */
+    int size;
+
+    /** Indicates this TLB caches IPA->PA translations */
+    bool isStage2;
+
+    /**
+     * Hash map containing one entry per lookup level
+     * The TLB is caching partial translations from the key lookup level
+     * if the matching value is true.
+     */
+    std::unordered_map<enums::ArmLookupLevel, bool> partialLevels;
+
+    /**
+     * True if the TLB caches partial translations
+     */
+    bool _walkCache;
 
     TableWalker *tableWalker;
 
@@ -128,6 +144,7 @@
         const TLB &tlb;
 
         // Access Stats
+        mutable statistics::Scalar partialHits;
         mutable statistics::Scalar instHits;
         mutable statistics::Scalar instMisses;
         mutable statistics::Scalar readHits;
@@ -186,6 +203,8 @@
 
     int getsize() const { return size; }
 
+    bool walkCache() const { return _walkCache; }
+
     void setVMID(vmid_t _vmid) { vmid = _vmid; }
 
     /** Insert a PTE in the current TLB */
@@ -313,6 +332,10 @@
      * data access or a data TLB entry on an instruction access:
      */
     void checkPromotion(TlbEntry *entry, BaseMMU::Mode mode);
+
+    /** Helper function looking up for a matching TLB entry
+     * Does not update stats; see lookup method instead */
+    TlbEntry *match(const Lookup &lookup_data);
 };
 
 } // namespace ArmISA