arch-riscv: Add support for fault handling

This patch adds support for handling RISC-V faults, including tracking
current and previous execution privilege, correctly switching to
the privilege mode specified by CSRs, and setting/storing the PC.  It
also includes changes introduced by patch #9821, which disables
interrupts during handling of a fault.

Change-Id: Ie9c0f29719620c20783540d3bdb2db44f6114fc9
Reviewed-on: https://gem5-review.googlesource.com/9161
Maintainer: Alec Roelke <ar4jc@virginia.edu>
Reviewed-by: Jason Lowe-Power <jason@lowepower.com>
diff --git a/src/arch/riscv/faults.cc b/src/arch/riscv/faults.cc
index ce4cb38..efab6c4 100644
--- a/src/arch/riscv/faults.cc
+++ b/src/arch/riscv/faults.cc
@@ -32,6 +32,8 @@
  */
 #include "arch/riscv/faults.hh"
 
+#include "arch/riscv/isa.hh"
+#include "arch/riscv/registers.hh"
 #include "arch/riscv/system.hh"
 #include "arch/riscv/utility.hh"
 #include "cpu/base.hh"
@@ -39,10 +41,11 @@
 #include "sim/debug.hh"
 #include "sim/full_system.hh"
 
-using namespace RiscvISA;
+namespace RiscvISA
+{
 
 void
-RiscvFault::invoke_se(ThreadContext *tc, const StaticInstPtr &inst)
+RiscvFault::invokeSE(ThreadContext *tc, const StaticInstPtr &inst)
 {
     panic("Fault %s encountered at pc 0x%016llx.", name(), tc->pcState().pc());
 }
@@ -50,14 +53,71 @@
 void
 RiscvFault::invoke(ThreadContext *tc, const StaticInstPtr &inst)
 {
+    PCState pcState = tc->pcState();
+
     if (FullSystem) {
-        panic("Full system mode not supported for RISC-V.");
+        PrivilegeMode pp = (PrivilegeMode)tc->readMiscReg(MISCREG_PRV);
+        PrivilegeMode prv = PRV_M;
+        STATUS status = tc->readMiscReg(MISCREG_STATUS);
+
+        // Set fault handler privilege mode
+        if (pp != PRV_M &&
+            bits(tc->readMiscReg(MISCREG_MEDELEG), _code) != 0) {
+            prv = PRV_S;
+        }
+        if (pp == PRV_U &&
+            bits(tc->readMiscReg(MISCREG_SEDELEG), _code) != 0) {
+            prv = PRV_U;
+        }
+
+        // Set fault registers and status
+        MiscRegIndex cause, epc, tvec;
+        switch (prv) {
+          case PRV_U:
+            cause = MISCREG_UCAUSE;
+            epc = MISCREG_UEPC;
+            tvec = MISCREG_UTVEC;
+
+            status.upie = status.uie;
+            status.uie = 0;
+            break;
+          case PRV_S:
+            cause = MISCREG_SCAUSE;
+            epc = MISCREG_SEPC;
+            tvec = MISCREG_STVEC;
+
+            status.spp = pp;
+            status.spie = status.sie;
+            status.sie = 0;
+            break;
+          case PRV_M:
+            cause = MISCREG_MCAUSE;
+            epc = MISCREG_MEPC;
+            tvec = MISCREG_MTVEC;
+
+            status.mpp = pp;
+            status.mpie = status.sie;
+            status.mie = 0;
+            break;
+          default:
+            panic("Unknown privilege mode %d.", prv);
+            break;
+        }
+
+        // Set fault cause, privilege, and return PC
+        tc->setMiscReg(cause,
+                       (isInterrupt() << (sizeof(MiscReg) * 4 - 1)) | _code);
+        tc->setMiscReg(epc, tc->instAddr());
+        tc->setMiscReg(MISCREG_PRV, prv);
+        tc->setMiscReg(MISCREG_STATUS, status);
+
+        // Set PC to fault handler address
+        pcState.set(tc->readMiscReg(tvec) >> 2);
     } else {
-        invoke_se(tc, inst);
-        PCState pcState = tc->pcState();
+        invokeSE(tc, inst);
         advancePC(pcState, inst);
-        tc->pcState(pcState);
     }
+    tc->pcState(pcState);
 }
 
 void Reset::invoke(ThreadContext *tc, const StaticInstPtr &inst)
@@ -73,21 +133,21 @@
 }
 
 void
-UnknownInstFault::invoke_se(ThreadContext *tc, const StaticInstPtr &inst)
+UnknownInstFault::invokeSE(ThreadContext *tc, const StaticInstPtr &inst)
 {
     panic("Unknown instruction 0x%08x at pc 0x%016llx", inst->machInst,
         tc->pcState().pc());
 }
 
 void
-IllegalInstFault::invoke_se(ThreadContext *tc, const StaticInstPtr &inst)
+IllegalInstFault::invokeSE(ThreadContext *tc, const StaticInstPtr &inst)
 {
     panic("Illegal instruction 0x%08x at pc 0x%016llx: %s", inst->machInst,
         tc->pcState().pc(), reason.c_str());
 }
 
 void
-UnimplementedFault::invoke_se(ThreadContext *tc,
+UnimplementedFault::invokeSE(ThreadContext *tc,
         const StaticInstPtr &inst)
 {
     panic("Unimplemented instruction %s at pc 0x%016llx", instName,
@@ -95,21 +155,23 @@
 }
 
 void
-IllegalFrmFault::invoke_se(ThreadContext *tc, const StaticInstPtr &inst)
+IllegalFrmFault::invokeSE(ThreadContext *tc, const StaticInstPtr &inst)
 {
     panic("Illegal floating-point rounding mode 0x%x at pc 0x%016llx.",
             frm, tc->pcState().pc());
 }
 
 void
-BreakpointFault::invoke_se(ThreadContext *tc, const StaticInstPtr &inst)
+BreakpointFault::invokeSE(ThreadContext *tc, const StaticInstPtr &inst)
 {
     schedRelBreak(0);
 }
 
 void
-SyscallFault::invoke_se(ThreadContext *tc, const StaticInstPtr &inst)
+SyscallFault::invokeSE(ThreadContext *tc, const StaticInstPtr &inst)
 {
     Fault *fault = NoFault;
     tc->syscall(tc->readIntReg(SyscallNumReg), fault);
 }
+
+} // namespace RiscvISA
\ No newline at end of file
diff --git a/src/arch/riscv/faults.hh b/src/arch/riscv/faults.hh
index 478bfd2..ef0fdb6 100644
--- a/src/arch/riscv/faults.hh
+++ b/src/arch/riscv/faults.hh
@@ -36,19 +36,22 @@
 
 #include <string>
 
+#include "arch/riscv/registers.hh"
 #include "cpu/thread_context.hh"
 #include "sim/faults.hh"
 
 namespace RiscvISA
 {
 
-const uint32_t FloatInexact = 1 << 0;
-const uint32_t FloatUnderflow = 1 << 1;
-const uint32_t FloatOverflow = 1 << 2;
-const uint32_t FloatDivZero = 1 << 3;
-const uint32_t FloatInvalid = 1 << 4;
+enum FloatException : MiscReg {
+    FloatInexact = 0x1,
+    FloatUnderflow = 0x2,
+    FloatOverflow = 0x4,
+    FloatDivZero = 0x8,
+    FloatInvalid = 0x10
+};
 
-enum ExceptionCode {
+enum ExceptionCode : MiscReg {
     INST_ADDR_MISALIGNED = 0,
     INST_ACCESS = 1,
     INST_ILLEGAL = 2,
@@ -61,49 +64,77 @@
     AMO_ACCESS = 7,
     ECALL_USER = 8,
     ECALL_SUPER = 9,
-    ECALL_HYPER = 10,
-    ECALL_MACH = 11
+    ECALL_MACHINE = 11,
+    INST_PAGE = 12,
+    LOAD_PAGE = 13,
+    STORE_PAGE = 15,
+    AMO_PAGE = 15
 };
 
-enum InterruptCode {
-    SOFTWARE,
-    TIMER
-};
+/**
+ * These fields are specified in the RISC-V Instruction Set Manual, Volume II,
+ * v1.10, accessible at www.riscv.org. in Figure 3.7. The main register that
+ * uses these fields is the MSTATUS register, which is shadowed by two others
+ * accessible at lower privilege levels (SSTATUS and USTATUS) that can't see
+ * the fields for higher privileges.
+ */
+BitUnion64(STATUS)
+    Bitfield<63> sd;
+    Bitfield<35, 34> sxl;
+    Bitfield<33, 32> uxl;
+    Bitfield<22> tsr;
+    Bitfield<21> tw;
+    Bitfield<20> tvm;
+    Bitfield<19> mxr;
+    Bitfield<18> sum;
+    Bitfield<17> mprv;
+    Bitfield<16, 15> xs;
+    Bitfield<14, 13> fs;
+    Bitfield<12, 11> mpp;
+    Bitfield<8> spp;
+    Bitfield<7> mpie;
+    Bitfield<5> spie;
+    Bitfield<4> upie;
+    Bitfield<3> mie;
+    Bitfield<1> sie;
+    Bitfield<0> uie;
+EndBitUnion(STATUS)
+
+/**
+ * These fields are specified in the RISC-V Instruction Set Manual, Volume II,
+ * v1.10 in Figures 3.11 and 3.12, accessible at www.riscv.org. Both the MIP
+ * and MIE registers have the same fields, so accesses to either should use
+ * this bit union.
+ */
+BitUnion64(INTERRUPT)
+    Bitfield<11> mei;
+    Bitfield<9> sei;
+    Bitfield<8> uei;
+    Bitfield<7> mti;
+    Bitfield<5> sti;
+    Bitfield<4> uti;
+    Bitfield<3> msi;
+    Bitfield<1> ssi;
+    Bitfield<0> usi;
+EndBitUnion(INTERRUPT)
 
 class RiscvFault : public FaultBase
 {
   protected:
     const FaultName _name;
+    bool _interrupt;
     const ExceptionCode _code;
-    const InterruptCode _int;
 
-    RiscvFault(FaultName n, ExceptionCode c, InterruptCode i)
-        : _name(n), _code(c), _int(i)
+    RiscvFault(FaultName n, bool i, ExceptionCode c)
+        : _name(n), _interrupt(i), _code(c)
     {}
 
-    FaultName
-    name() const
-    {
-        return _name;
-    }
+    FaultName name() const { return _name; }
+    bool isInterrupt() const { return _interrupt; }
+    ExceptionCode exception() const { return _code; }
 
-    ExceptionCode
-    exception() const
-    {
-        return _code;
-    }
-
-    InterruptCode
-    interrupt() const
-    {
-        return _int;
-    }
-
-    virtual void
-    invoke_se(ThreadContext *tc, const StaticInstPtr &inst);
-
-    void
-    invoke(ThreadContext *tc, const StaticInstPtr &inst);
+    virtual void invokeSE(ThreadContext *tc, const StaticInstPtr &inst);
+    void invoke(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
 class Reset : public FaultBase
@@ -131,63 +162,59 @@
 class UnknownInstFault : public RiscvFault
 {
   public:
-    UnknownInstFault() : RiscvFault("Unknown instruction", INST_ILLEGAL,
-            SOFTWARE)
+    UnknownInstFault() : RiscvFault("Unknown instruction", false, INST_ILLEGAL)
     {}
 
-    void
-    invoke_se(ThreadContext *tc, const StaticInstPtr &inst);
+    void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
 class IllegalInstFault : public RiscvFault
 {
   private:
     const std::string reason;
+
   public:
     IllegalInstFault(std::string r)
-        : RiscvFault("Illegal instruction", INST_ILLEGAL, SOFTWARE),
-          reason(r)
+        : RiscvFault("Illegal instruction", false, INST_ILLEGAL)
     {}
 
-    void invoke_se(ThreadContext *tc, const StaticInstPtr &inst);
+    void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
 class UnimplementedFault : public RiscvFault
 {
   private:
     const std::string instName;
+
   public:
     UnimplementedFault(std::string name)
-        : RiscvFault("Unimplemented instruction", INST_ILLEGAL, SOFTWARE),
-        instName(name)
+        : RiscvFault("Unimplemented instruction", false, INST_ILLEGAL),
+          instName(name)
     {}
 
-    void
-    invoke_se(ThreadContext *tc, const StaticInstPtr &inst);
+    void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
 class IllegalFrmFault: public RiscvFault
 {
   private:
     const uint8_t frm;
+
   public:
     IllegalFrmFault(uint8_t r)
-        : RiscvFault("Illegal floating-point rounding mode", INST_ILLEGAL,
-                SOFTWARE),
-        frm(r)
+        : RiscvFault("Illegal floating-point rounding mode", false,
+                     INST_ILLEGAL),
+          frm(r)
     {}
 
-    void invoke_se(ThreadContext *tc, const StaticInstPtr &inst);
+    void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
 class BreakpointFault : public RiscvFault
 {
   public:
-    BreakpointFault() : RiscvFault("Breakpoint", BREAKPOINT, SOFTWARE)
-    {}
-
-    void
-    invoke_se(ThreadContext *tc, const StaticInstPtr &inst);
+    BreakpointFault() : RiscvFault("Breakpoint", false, BREAKPOINT) {}
+    void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
 class SyscallFault : public RiscvFault
@@ -195,11 +222,8 @@
   public:
     // TODO: replace ECALL_USER with the appropriate privilege level of the
     // caller
-    SyscallFault() : RiscvFault("System call", ECALL_USER, SOFTWARE)
-    {}
-
-    void
-    invoke_se(ThreadContext *tc, const StaticInstPtr &inst);
+    SyscallFault() : RiscvFault("System call", false, ECALL_USER) {}
+    void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
 } // namespace RiscvISA