arch-arm: Allowing table descriptor to be inserted in TLB

This patch is modifying both TableWalker and MMU to effectively
store/use partial translations

* TableWalker changes: If there is a TLB supporting partial
translations (implemented with previous patch), the TableWalker will
craft partial entries and forward them to the TLB as walks are performed

* MMU changes: We now instruct the table walker to start a page
table traversal even if we hit in the TLB, if the matching entry
holds a partial translation

JIRA: https://gem5.atlassian.net/browse/GEM5-1108

Change-Id: Id20aaf4ea02960d50d8345f3e174c698af21ad1c
Signed-off-by: Giacomo Travaglini <giacomo.travaglini@arm.com>
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/52125
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
Maintainer: Andreas Sandberg <andreas.sandberg@arm.com>
Tested-by: kokoro <noreply+kokoro@google.com>
diff --git a/src/arch/arm/mmu.cc b/src/arch/arm/mmu.cc
index f4a2114..79a5240 100644
--- a/src/arch/arm/mmu.cc
+++ b/src/arch/arm/mmu.cc
@@ -1471,7 +1471,7 @@
         fault = getTableWalker(mode, state.isStage2)->walk(
             req, tc, state.asid, state.vmid, state.isHyp, mode,
             translation, timing, functional, is_secure,
-            tran_type, state.stage2DescReq);
+            tran_type, state.stage2DescReq, *te);
 
         // for timing mode, return and wait for table walk,
         if (timing || fault != NoFault) {
diff --git a/src/arch/arm/table_walker.cc b/src/arch/arm/table_walker.cc
index c70d856..f94e765 100644
--- a/src/arch/arm/table_walker.cc
+++ b/src/arch/arm/table_walker.cc
@@ -290,7 +290,7 @@
                   vmid_t _vmid, bool _isHyp, MMU::Mode _mode,
                   MMU::Translation *_trans, bool _timing, bool _functional,
                   bool secure, MMU::ArmTranslationType tranType,
-                  bool _stage2Req)
+                  bool _stage2Req, const TlbEntry *walk_entry)
 {
     assert(!(_functional && _timing));
     ++stats.walks;
@@ -344,6 +344,11 @@
     }
     currState->transState = _trans;
     currState->req = _req;
+    if (walk_entry) {
+        currState->walkEntry = *walk_entry;
+    } else {
+        currState->walkEntry = TlbEntry();
+    }
     currState->fault = NoFault;
     currState->asid = _asid;
     currState->vmid = _vmid;
@@ -494,11 +499,15 @@
     // a match, we just want to get rid of the walk. The latter could happen
     // when there are multiple outstanding misses to a single page and a
     // previous request has been successfully translated.
-    if (!currState->transState->squashed() && !te) {
+    if (!currState->transState->squashed() && (!te || te->partial)) {
         // We've got a valid request, lets process it
         pending = true;
         pendingQueue.pop_front();
         // Keep currState in case one of the processWalk... calls NULLs it
+
+        if (te && te->partial) {
+            currState->walkEntry = *te;
+        }
         WalkerState *curr_state_copy = currState;
         Fault f;
         if (currState->aarch64)
@@ -897,7 +906,6 @@
         currState->tcr,
         currState->el);
 
-    LookupLevel start_lookup_level = LookupLevel::Num_ArmLookupLevel;
     bool vaddr_fault = false;
     switch (currState->el) {
       case EL0:
@@ -959,11 +967,6 @@
             tsz = 64 - currState->vtcr.t0sz64;
             tg = GrainMap_tg0[currState->vtcr.tg0];
 
-            start_lookup_level = getPageTableOps(tg)->firstS2Level(
-                currState->vtcr.sl0);
-
-            panic_if(start_lookup_level == LookupLevel::Num_ArmLookupLevel,
-                     "Cannot discern lookup level from vtcr.{sl0,tg0}");
             ps = currState->vtcr.ps;
             currState->isUncacheable = currState->vtcr.irgn0 == 0;
         } else {
@@ -1096,15 +1099,6 @@
         tg = Grain4KB;
     }
 
-    // Determine starting lookup level
-    if (start_lookup_level == LookupLevel::Num_ArmLookupLevel) {
-        const auto* ptops = getPageTableOps(tg);
-
-        start_lookup_level = ptops->firstLevel(64 - tsz);
-        panic_if(start_lookup_level == LookupLevel::Num_ArmLookupLevel,
-                 "Table walker couldn't find lookup level\n");
-    }
-
     // Clamp to lower limit
     int pa_range = decodePhysAddrRange64(ps);
     if (pa_range > _physAddrRange) {
@@ -1113,22 +1107,12 @@
         currState->physAddrRange = pa_range;
     }
 
-    // Determine table base address
-    int stride = tg - 3;
-    int base_addr_lo = 3 + tsz - stride * (3 - start_lookup_level) - tg;
-    Addr base_addr = 0;
-
-    if (pa_range == 52) {
-        int z = (base_addr_lo < 6) ? 6 : base_addr_lo;
-        base_addr = mbits(ttbr, 47, z);
-        base_addr |= (bits(ttbr, 5, 2) << 48);
-    } else {
-        base_addr = mbits(ttbr, 47, base_addr_lo);
-    }
+    auto [table_addr, desc_addr, start_lookup_level] = walkAddresses(
+        ttbr, tg, tsz, pa_range);
 
     // Determine physical address size and raise an Address Size Fault if
     // necessary
-    if (checkAddrSizeFaultAArch64(base_addr, currState->physAddrRange)) {
+    if (checkAddrSizeFaultAArch64(table_addr, currState->physAddrRange)) {
         DPRINTF(TLB, "Address size fault before any lookup\n");
         Fault f;
         if (currState->isFetch)
@@ -1159,11 +1143,6 @@
 
     }
 
-    // Determine descriptor address
-    Addr desc_addr = base_addr |
-        (bits(currState->vaddr, tsz - 1,
-              stride * (3 - start_lookup_level) + tg) << 3);
-
     // Trickbox address check
     Fault f = testWalk(desc_addr, sizeof(uint64_t),
                        TlbEntry::DomainType::NoAccess, start_lookup_level, isStage2);
@@ -1208,6 +1187,54 @@
     return f;
 }
 
+std::tuple<Addr, Addr, TableWalker::LookupLevel>
+TableWalker::walkAddresses(Addr ttbr, GrainSize tg, int tsz, int pa_range)
+{
+    const auto* ptops = getPageTableOps(tg);
+
+    LookupLevel first_level = LookupLevel::Num_ArmLookupLevel;
+    Addr table_addr = 0;
+    Addr desc_addr = 0;
+
+    if (currState->walkEntry.valid) {
+        // WalkCache hit
+        TlbEntry* entry = &currState->walkEntry;
+        DPRINTF(PageTableWalker,
+                "Walk Cache hit: va=%#x, level=%d, table address=%#x\n",
+                currState->vaddr, entry->lookupLevel, entry->pfn);
+
+        currState->xnTable = entry->xn;
+        currState->pxnTable = entry->pxn;
+        currState->rwTable = bits(entry->ap, 1);
+        currState->userTable = bits(entry->ap, 0);
+
+        table_addr = entry->pfn;
+        first_level = (LookupLevel)(entry->lookupLevel + 1);
+    } else {
+        // WalkCache miss
+        first_level = isStage2 ?
+            ptops->firstS2Level(currState->vtcr.sl0) :
+            ptops->firstLevel(64 - tsz);
+        panic_if(first_level == LookupLevel::Num_ArmLookupLevel,
+                 "Table walker couldn't find lookup level\n");
+
+        int stride = tg - 3;
+        int base_addr_lo = 3 + tsz - stride * (3 - first_level) - tg;
+
+        if (pa_range == 52) {
+            int z = (base_addr_lo < 6) ? 6 : base_addr_lo;
+            table_addr = mbits(ttbr, 47, z);
+            table_addr |= (bits(ttbr, 5, 2) << 48);
+        } else {
+            table_addr = mbits(ttbr, 47, base_addr_lo);
+        }
+    }
+
+    desc_addr = table_addr + ptops->index(currState->vaddr, first_level, tsz);
+
+    return std::make_tuple(table_addr, desc_addr, first_level);
+}
+
 void
 TableWalker::memAttrs(ThreadContext *tc, TlbEntry &te, SCTLR sctlr,
                       uint8_t texcb, bool s)
@@ -1890,6 +1917,11 @@
                 return;
             }
 
+            if (mmu->hasWalkCache()) {
+                insertPartialTableEntry(currState->longDesc);
+            }
+
+
             Request::Flags flag = Request::PT_WALK;
             if (currState->secureLookup)
                 flag.set(Request::SECURE);
@@ -2253,6 +2285,55 @@
 }
 
 void
+TableWalker::insertPartialTableEntry(LongDescriptor &descriptor)
+{
+    const bool have_security = release->has(ArmExtension::SECURITY);
+    TlbEntry te;
+
+    // Create and fill a new page table entry
+    te.valid          = true;
+    te.longDescFormat = true;
+    te.partial        = true;
+    te.global         = false;
+    te.isHyp          = currState->isHyp;
+    te.asid           = currState->asid;
+    te.vmid           = currState->vmid;
+    te.N              = descriptor.offsetBits();
+    te.vpn            = currState->vaddr >> te.N;
+    te.size           = (1ULL << te.N) - 1;
+    te.pfn            = descriptor.nextTableAddr();
+    te.domain         = descriptor.domain();
+    te.lookupLevel    = descriptor.lookupLevel;
+    te.ns             = !descriptor.secure(have_security, currState);
+    te.nstid          = !currState->isSecure;
+    te.type           = TypeTLB::unified;
+
+    if (currState->aarch64)
+        te.el         = currState->el;
+    else
+        te.el         = EL1;
+
+    te.xn = currState->xnTable;
+    te.pxn = currState->pxnTable;
+    te.ap = (currState->rwTable << 1) | (currState->userTable);
+
+    // Debug output
+    DPRINTF(TLB, descriptor.dbgHeader().c_str());
+    DPRINTF(TLB, " - N:%d pfn:%#x size:%#x global:%d valid:%d\n",
+            te.N, te.pfn, te.size, te.global, te.valid);
+    DPRINTF(TLB, " - vpn:%#x xn:%d pxn:%d ap:%d domain:%d asid:%d "
+            "vmid:%d hyp:%d nc:%d ns:%d\n", te.vpn, te.xn, te.pxn,
+            te.ap, static_cast<uint8_t>(te.domain), te.asid, te.vmid, te.isHyp,
+            te.nonCacheable, te.ns);
+    DPRINTF(TLB, " - domain from L%d desc:%d data:%#x\n",
+            descriptor.lookupLevel, static_cast<uint8_t>(descriptor.domain()),
+            descriptor.getRawData());
+
+    // Insert the entry into the TLBs
+    tlb->multiInsert(te);
+}
+
+void
 TableWalker::insertTableEntry(DescriptorBase &descriptor, bool long_descriptor)
 {
     const bool have_security = release->has(ArmExtension::SECURITY);
diff --git a/src/arch/arm/table_walker.hh b/src/arch/arm/table_walker.hh
index 554bc2b..28d3d4d 100644
--- a/src/arch/arm/table_walker.hh
+++ b/src/arch/arm/table_walker.hh
@@ -450,12 +450,17 @@
         std::string
         dbgHeader() const override
         {
-            if (type() == LongDescriptor::Page) {
+            switch (type()) {
+              case LongDescriptor::Page:
                 assert(lookupLevel == LookupLevel::L3);
                 return "Inserting Page descriptor into TLB\n";
-            } else {
+              case LongDescriptor::Block:
                 assert(lookupLevel < LookupLevel::L3);
                 return "Inserting Block descriptor into TLB\n";
+              case LongDescriptor::Table:
+                return "Inserting Table descriptor into TLB\n";
+              default:
+                panic("Trying to insert and invalid descriptor\n");
             }
         }
 
@@ -466,8 +471,12 @@
         bool
         secure(bool have_security, WalkerState *currState) const override
         {
-            assert(type() == Block || type() == Page);
-            return have_security && (currState->secureLookup && !bits(data, 5));
+            if (type() == Block || type() == Page) {
+                return have_security &&
+                    (currState->secureLookup && !bits(data, 5));
+            } else {
+                return have_security && currState->secureLookup;
+            }
         }
 
         /** Return the descriptor type */
@@ -537,9 +546,11 @@
                     default:
                         panic("Invalid AArch64 VM granule size\n");
                 }
-            } else {
-                panic("AArch64 page table entry must be block or page\n");
+            } else if (type() == Table) {
+                const auto* ptops = getPageTableOps(grainSize);
+                return ptops->walkBits(lookupLevel);
             }
+            panic("AArch64 page table entry must be block or page\n");
         }
 
         /** Return the physical frame, bits shifted right */
@@ -700,7 +711,6 @@
         domain() const override
         {
             // Long-desc. format only supports Client domain
-            assert(type() == Block || type() == Page);
             return TlbEntry::DomainType::Client;
         }
 
@@ -804,6 +814,9 @@
         /** Request that is currently being serviced */
         RequestPtr req;
 
+        /** Initial walk entry allowing to skip lookup levels */
+        TlbEntry walkEntry;
+
         /** ASID that we're servicing the request under */
         uint16_t asid;
         vmid_t vmid;
@@ -1088,9 +1101,10 @@
 
     Fault walk(const RequestPtr &req, ThreadContext *tc,
                uint16_t asid, vmid_t _vmid,
-               bool _isHyp, BaseMMU::Mode mode, BaseMMU::Translation *_trans,
+               bool hyp, BaseMMU::Mode mode, BaseMMU::Translation *_trans,
                bool timing, bool functional, bool secure,
-               MMU::ArmTranslationType tranType, bool _stage2Req);
+               MMU::ArmTranslationType tran_type, bool stage2,
+               const TlbEntry *walk_entry);
 
     void setMmu(MMU *_mmu);
     void setTlb(TLB *_tlb) { tlb = _tlb; }
@@ -1135,6 +1149,15 @@
     Fault generateLongDescFault(ArmFault::FaultSource src);
 
     void insertTableEntry(DescriptorBase &descriptor, bool longDescriptor);
+    void insertPartialTableEntry(LongDescriptor &descriptor);
+
+    /** Returns a tuple made of:
+     * 1) The address of the first page table
+     * 2) The address of the first descriptor within the table
+     * 3) The page table level
+     */
+    std::tuple<Addr, Addr, LookupLevel> walkAddresses(
+        Addr ttbr, GrainSize tg, int tsz, int pa_range);
 
     Fault processWalk();
     Fault processWalkLPAE();