arch-riscv: Add support for trap value register

RISC-V has a set of CSRs that contain information about a trap that was
taken into each privilegel level, such as illegal instruction bytes or
faulting address.  This patch adds that register, modifies existing
faults to make use of it, and adds a new fault for future use with
handling page faults and bad addresses.

Change-Id: I3004bd7b907e7dc75e5f1a8452a1d74796a7a551
Reviewed-on: https://gem5-review.googlesource.com/11135
Reviewed-by: Jason Lowe-Power <jason@lowepower.com>
Maintainer: Alec Roelke <alec.roelke@gmail.com>
diff --git a/src/arch/riscv/faults.cc b/src/arch/riscv/faults.cc
index efab6c4..b5f3d07 100644
--- a/src/arch/riscv/faults.cc
+++ b/src/arch/riscv/faults.cc
@@ -71,12 +71,13 @@
         }
 
         // Set fault registers and status
-        MiscRegIndex cause, epc, tvec;
+        MiscRegIndex cause, epc, tvec, tval;
         switch (prv) {
           case PRV_U:
             cause = MISCREG_UCAUSE;
             epc = MISCREG_UEPC;
             tvec = MISCREG_UTVEC;
+            tval = MISCREG_UTVAL;
 
             status.upie = status.uie;
             status.uie = 0;
@@ -85,6 +86,7 @@
             cause = MISCREG_SCAUSE;
             epc = MISCREG_SEPC;
             tvec = MISCREG_STVEC;
+            tval = MISCREG_STVAL;
 
             status.spp = pp;
             status.spie = status.sie;
@@ -94,6 +96,7 @@
             cause = MISCREG_MCAUSE;
             epc = MISCREG_MEPC;
             tvec = MISCREG_MTVEC;
+            tval = MISCREG_MTVAL;
 
             status.mpp = pp;
             status.mpie = status.sie;
@@ -108,6 +111,7 @@
         tc->setMiscReg(cause,
                        (isInterrupt() << (sizeof(MiscReg) * 4 - 1)) | _code);
         tc->setMiscReg(epc, tc->instAddr());
+        tc->setMiscReg(tval, trap_value());
         tc->setMiscReg(MISCREG_PRV, prv);
         tc->setMiscReg(MISCREG_STATUS, status);
 
diff --git a/src/arch/riscv/faults.hh b/src/arch/riscv/faults.hh
index ef0fdb6..6d3fdeb 100644
--- a/src/arch/riscv/faults.hh
+++ b/src/arch/riscv/faults.hh
@@ -129,9 +129,10 @@
         : _name(n), _interrupt(i), _code(c)
     {}
 
-    FaultName name() const { return _name; }
+    FaultName name() const override { return _name; }
     bool isInterrupt() const { return _interrupt; }
     ExceptionCode exception() const { return _code; }
+    virtual MiscReg trap_value() const { return 0; }
 
     virtual void invokeSE(ThreadContext *tc, const StaticInstPtr &inst);
     void invoke(ThreadContext *tc, const StaticInstPtr &inst) override;
@@ -159,61 +160,94 @@
         const FaultName _name;
 };
 
-class UnknownInstFault : public RiscvFault
+class InstFault : public RiscvFault
+{
+  protected:
+    const ExtMachInst _inst;
+
+  public:
+    InstFault(FaultName n, const ExtMachInst inst)
+        : RiscvFault(n, false, INST_ILLEGAL), _inst(inst)
+    {}
+
+    MiscReg trap_value() const override { return _inst; }
+};
+
+class UnknownInstFault : public InstFault
 {
   public:
-    UnknownInstFault() : RiscvFault("Unknown instruction", false, INST_ILLEGAL)
+    UnknownInstFault(const ExtMachInst inst)
+        : InstFault("Unknown instruction", inst)
     {}
 
     void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
-class IllegalInstFault : public RiscvFault
+class IllegalInstFault : public InstFault
 {
   private:
     const std::string reason;
 
   public:
-    IllegalInstFault(std::string r)
-        : RiscvFault("Illegal instruction", false, INST_ILLEGAL)
+    IllegalInstFault(std::string r, const ExtMachInst inst)
+        : InstFault("Illegal instruction", inst)
     {}
 
     void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
-class UnimplementedFault : public RiscvFault
+class UnimplementedFault : public InstFault
 {
   private:
     const std::string instName;
 
   public:
-    UnimplementedFault(std::string name)
-        : RiscvFault("Unimplemented instruction", false, INST_ILLEGAL),
+    UnimplementedFault(std::string name, const ExtMachInst inst)
+        : InstFault("Unimplemented instruction", inst),
           instName(name)
     {}
 
     void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
-class IllegalFrmFault: public RiscvFault
+class IllegalFrmFault: public InstFault
 {
   private:
     const uint8_t frm;
 
   public:
-    IllegalFrmFault(uint8_t r)
-        : RiscvFault("Illegal floating-point rounding mode", false,
-                     INST_ILLEGAL),
+    IllegalFrmFault(uint8_t r, const ExtMachInst inst)
+        : InstFault("Illegal floating-point rounding mode", inst),
           frm(r)
     {}
 
     void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
+class AddressFault : public RiscvFault
+{
+  private:
+    const Addr _addr;
+
+  public:
+    AddressFault(const Addr addr, ExceptionCode code)
+        : RiscvFault("Address", false, code), _addr(addr)
+    {}
+
+    MiscReg trap_value() const override { return _addr; }
+};
+
 class BreakpointFault : public RiscvFault
 {
+  private:
+    const PCState pcState;
+
   public:
-    BreakpointFault() : RiscvFault("Breakpoint", false, BREAKPOINT) {}
+    BreakpointFault(const PCState &pc)
+        : RiscvFault("Breakpoint", false, BREAKPOINT), pcState(pc)
+    {}
+
+    MiscReg trap_value() const override { return pcState.pc(); }
     void invokeSE(ThreadContext *tc, const StaticInstPtr &inst) override;
 };
 
@@ -228,4 +262,4 @@
 
 } // namespace RiscvISA
 
-#endif // __ARCH_RISCV_FAULTS_HH__
+#endif // __ARCH_RISCV_FAULTS_HH__
\ No newline at end of file
diff --git a/src/arch/riscv/insts/unknown.hh b/src/arch/riscv/insts/unknown.hh
index 049f879..a96474a 100644
--- a/src/arch/riscv/insts/unknown.hh
+++ b/src/arch/riscv/insts/unknown.hh
@@ -59,7 +59,7 @@
     Fault
     execute(ExecContext *, Trace::InstRecord *) const override
     {
-        return std::make_shared<UnknownInstFault>();
+        return std::make_shared<UnknownInstFault>(machInst);
     }
 
     std::string
diff --git a/src/arch/riscv/isa/decoder.isa b/src/arch/riscv/isa/decoder.isa
index b4bf385..a0e3ad1 100644
--- a/src/arch/riscv/isa/decoder.isa
+++ b/src/arch/riscv/isa/decoder.isa
@@ -43,7 +43,8 @@
                   CIMM8<5:2> << 6;
         }}, {{
             if (machInst == 0)
-                fault = make_shared<IllegalInstFault>("zero instruction");
+                fault = make_shared<IllegalInstFault>("zero instruction",
+                                                      machInst);
             Rp2 = sp + imm;
         }}, uint64_t);
         format CompressedLoad {
@@ -106,9 +107,11 @@
             }}, {{
                 if ((RC1 == 0) != (imm == 0)) {
                     if (RC1 == 0) {
-                        fault = make_shared<IllegalInstFault>("source reg x0");
+                        fault = make_shared<IllegalInstFault>("source reg x0",
+                                                              machInst);
                     } else // imm == 0
-                        fault = make_shared<IllegalInstFault>("immediate = 0");
+                        fault = make_shared<IllegalInstFault>("immediate = 0",
+                                                              machInst);
                 }
                 Rc1_sd = Rc1_sd + imm;
             }});
@@ -118,7 +121,8 @@
                     imm |= ~((uint64_t)0x1F);
             }}, {{
                 if (RC1 == 0) {
-                    fault = make_shared<IllegalInstFault>("source reg x0");
+                    fault = make_shared<IllegalInstFault>("source reg x0",
+                                                          machInst);
                 }
                 Rc1_sd = (int32_t)Rc1_sd + imm;
             }});
@@ -128,7 +132,8 @@
                     imm |= ~((uint64_t)0x1F);
             }}, {{
                 if (RC1 == 0) {
-                    fault = make_shared<IllegalInstFault>("source reg x0");
+                    fault = make_shared<IllegalInstFault>("source reg x0",
+                                                          machInst);
                 }
                 Rc1_sd = imm;
             }});
@@ -142,7 +147,8 @@
                         imm |= ~((int64_t)0x1FF);
                 }}, {{
                     if (imm == 0) {
-                        fault = make_shared<IllegalInstFault>("immediate = 0");
+                        fault = make_shared<IllegalInstFault>("immediate = 0",
+                                                              machInst);
                     }
                     sp_sd = sp_sd + imm;
                 }});
@@ -152,10 +158,12 @@
                         imm |= ~((uint64_t)0x1FFFF);
                 }}, {{
                     if (RC1 == 0 || RC1 == 2) {
-                        fault = make_shared<IllegalInstFault>("source reg x0");
+                        fault = make_shared<IllegalInstFault>("source reg x0",
+                                                              machInst);
                     }
                     if (imm == 0) {
-                        fault = make_shared<IllegalInstFault>("immediate = 0");
+                        fault = make_shared<IllegalInstFault>("immediate = 0",
+                                                              machInst);
                     }
                     Rc1_sd = imm;
                 }});
@@ -167,7 +175,8 @@
                     imm = CIMM5 | (CIMM1 << 5);
                 }}, {{
                     if (imm == 0) {
-                        fault = make_shared<IllegalInstFault>("immediate = 0");
+                        fault = make_shared<IllegalInstFault>("immediate = 0",
+                                                              machInst);
                     }
                     Rp1 = Rp1 >> imm;
                 }}, uint64_t);
@@ -175,7 +184,8 @@
                     imm = CIMM5 | (CIMM1 << 5);
                 }}, {{
                     if (imm == 0) {
-                        fault = make_shared<IllegalInstFault>("immediate = 0");
+                        fault = make_shared<IllegalInstFault>("immediate = 0",
+                                                              machInst);
                     }
                     Rp1_sd = Rp1_sd >> imm;
                 }}, uint64_t);
@@ -246,10 +256,12 @@
             imm = CIMM5 | (CIMM1 << 5);
         }}, {{
             if (imm == 0) {
-                fault = make_shared<IllegalInstFault>("immediate = 0");
+                fault = make_shared<IllegalInstFault>("immediate = 0",
+                                                      machInst);
             }
             if (RC1 == 0) {
-                fault = make_shared<IllegalInstFault>("source reg x0");
+                fault = make_shared<IllegalInstFault>("source reg x0",
+                                                      machInst);
             }
             Rc1 = Rc1 << imm;
         }}, uint64_t);
@@ -269,7 +281,8 @@
                          CIMM5<1:0> << 6;
             }}, {{
                 if (RC1 == 0) {
-                    fault = make_shared<IllegalInstFault>("source reg x0");
+                    fault = make_shared<IllegalInstFault>("source reg x0",
+                                                          machInst);
                 }
                 Rc1_sd = Mem_sw;
             }}, {{
@@ -281,7 +294,8 @@
                          CIMM5<2:0> << 6;
             }}, {{
                 if (RC1 == 0) {
-                    fault = make_shared<IllegalInstFault>("source reg x0");
+                    fault = make_shared<IllegalInstFault>("source reg x0",
+                                                          machInst);
                 }
                 Rc1_sd = Mem_sd;
             }}, {{
@@ -292,13 +306,15 @@
             0x0: decode RC2 {
                 0x0: Jump::c_jr({{
                     if (RC1 == 0) {
-                        fault = make_shared<IllegalInstFault>("source reg x0");
+                        fault = make_shared<IllegalInstFault>("source reg x0",
+                                                              machInst);
                     }
                     NPC = Rc1;
                 }}, IsIndirectControl, IsUncondControl, IsCall);
                 default: CROp::c_mv({{
                     if (RC1 == 0) {
-                        fault = make_shared<IllegalInstFault>("source reg x0");
+                        fault = make_shared<IllegalInstFault>("source reg x0",
+                                                              machInst);
                     }
                     Rc1 = Rc2;
                 }});
@@ -306,15 +322,17 @@
             0x1: decode RC1 {
                 0x0: SystemOp::c_ebreak({{
                     if (RC2 != 0) {
-                        fault = make_shared<IllegalInstFault>("source reg x1");
+                        fault = make_shared<IllegalInstFault>("source reg x1",
+                                                              machInst);
                     }
-                    fault = make_shared<BreakpointFault>();
+                    fault = make_shared<BreakpointFault>(xc->pcState());
                 }}, IsSerializeAfter, IsNonSpeculative, No_OpClass);
                 default: decode RC2 {
                     0x0: Jump::c_jalr({{
                         if (RC1 == 0) {
                             fault = make_shared<IllegalInstFault>
-                                                        ("source reg x0");
+                                                        ("source reg x0",
+                                                         machInst);
                         }
                         ra = NPC;
                         NPC = Rc1;
@@ -1250,7 +1268,8 @@
                 }
                 0x20: fcvt_s_d({{
                     if (CONV_SGN != 1) {
-                        fault = make_shared<IllegalInstFault>("CONV_SGN != 1");
+                        fault = make_shared<IllegalInstFault>("CONV_SGN != 1",
+                                                              machInst);
                     }
                     float fd;
                     if (issignalingnan(Fs1)) {
@@ -1263,7 +1282,8 @@
                 }}, FloatCvtOp);
                 0x21: fcvt_d_s({{
                     if (CONV_SGN != 0) {
-                        fault = make_shared<IllegalInstFault>("CONV_SGN != 0");
+                        fault = make_shared<IllegalInstFault>("CONV_SGN != 0",
+                                                              machInst);
                     }
                     uint32_t temp;
                     float fs1 = reinterpret_cast<float&>(temp = Fs1_bits);
@@ -1277,7 +1297,8 @@
                 }}, FloatCvtOp);
                 0x2c: fsqrt_s({{
                     if (RS2 != 0) {
-                        fault = make_shared<IllegalInstFault>("source reg x1");
+                        fault = make_shared<IllegalInstFault>("source reg x1",
+                                                              machInst);
                     }
                     uint32_t temp;
                     float fs1 = reinterpret_cast<float&>(temp = Fs1_bits);
@@ -1291,7 +1312,8 @@
                 }}, FloatSqrtOp);
                 0x2d: fsqrt_d({{
                     if (RS2 != 0) {
-                        fault = make_shared<IllegalInstFault>("source reg x1");
+                        fault = make_shared<IllegalInstFault>("source reg x1",
+                                                              machInst);
                     }
                     Fd = sqrt(Fs1);
                 }}, FloatSqrtOp);
@@ -1690,10 +1712,11 @@
                     }}, IsSerializeAfter, IsNonSpeculative, IsSyscall,
                         No_OpClass);
                     0x1: ebreak({{
-                        fault = make_shared<BreakpointFault>();
+                        fault = make_shared<BreakpointFault>(xc->pcState());
                     }}, IsSerializeAfter, IsNonSpeculative, No_OpClass);
                     0x100: eret({{
-                        fault = make_shared<UnimplementedFault>("eret");
+                        fault = make_shared<UnimplementedFault>("eret",
+                                                                machInst);
                     }}, No_OpClass);
                 }
             }
diff --git a/src/arch/riscv/isa/formats/fp.isa b/src/arch/riscv/isa/formats/fp.isa
index 1f08ca5..5f06721 100644
--- a/src/arch/riscv/isa/formats/fp.isa
+++ b/src/arch/riscv/isa/formats/fp.isa
@@ -57,7 +57,7 @@
                 break;
             case 0x4:
                 // Round to nearest, ties to max magnitude not implemented
-                fault = make_shared<IllegalFrmFault>(ROUND_MODE);
+                fault = make_shared<IllegalFrmFault>(ROUND_MODE, machInst);
                 break;
             case 0x7: {
                 uint8_t frm = xc->readMiscReg(MISCREG_FRM);
@@ -76,16 +76,17 @@
                     break;
                 case 0x4:
                     // Round to nearest, ties to max magnitude not implemented
-                    fault = make_shared<IllegalFrmFault>(ROUND_MODE);
+                    fault = make_shared<IllegalFrmFault>(ROUND_MODE, machInst);
                     break;
                 default:
-                    fault = std::make_shared<IllegalFrmFault>(frm);
+                    fault = std::make_shared<IllegalFrmFault>(frm, machInst);
                     break;
                 }
                 break;
             }
             default:
-                fault = std::make_shared<IllegalFrmFault>(ROUND_MODE);
+                fault = std::make_shared<IllegalFrmFault>(ROUND_MODE,
+                                                          machInst);
                 break;
             }
 
diff --git a/src/arch/riscv/isa/formats/standard.isa b/src/arch/riscv/isa/formats/standard.isa
index e69ad7e..e9539fe 100644
--- a/src/arch/riscv/isa/formats/standard.isa
+++ b/src/arch/riscv/isa/formats/standard.isa
@@ -231,7 +231,7 @@
                 olddata = xc->readMiscReg(CSRData.at(csr).physIndex);
             } else {
                 std::string error = csprintf("Illegal CSR index %#x\n", csr);
-                fault = make_shared<IllegalInstFault>(error);
+                fault = make_shared<IllegalInstFault>(error, machInst);
                 olddata = 0;
             }
             break;
@@ -252,7 +252,7 @@
                     if (bits(csr, 11, 10) == 0x3) {
                         std::string error = csprintf("CSR %s is read-only\n",
                                                      CSRData.at(csr).name);
-                        fault = make_shared<IllegalInstFault>(error);
+                        fault = make_shared<IllegalInstFault>(error, machInst);
                     } else {
                         DPRINTF(RiscvMisc, "Writing %#x to CSR %s.\n", data,
                                 CSRData.at(csr).name);