arch-riscv: use sext rather than manual masks

Replace manual creation of masks for sign extension of immediates with
the sext<N> function.

Change-Id: Ief2df91a25500c64f5bcae0dcd437c1e3bb95e6c
Reviewed-on: https://gem5-review.googlesource.com/6182
Reviewed-by: Alec Roelke <ar4jc@virginia.edu>
Maintainer: Alec Roelke <ar4jc@virginia.edu>
diff --git a/src/arch/riscv/isa/bitfields.isa b/src/arch/riscv/isa/bitfields.isa
index 8372ed9..903fce3 100644
--- a/src/arch/riscv/isa/bitfields.isa
+++ b/src/arch/riscv/isa/bitfields.isa
@@ -64,8 +64,8 @@
 // SB-Type
 def bitfield BIMM12BIT11 <7>;
 def bitfield BIMM12BITS4TO1<11:8>;
-def bitfield IMMSIGN <31>;
 def bitfield BIMM12BITS10TO5 <30:25>;
+def bitfield IMMSIGN <31>;
 
 // UJ-Type
 def bitfield UJIMMBITS10TO1 <30:21>;
diff --git a/src/arch/riscv/isa/formats/mem.isa b/src/arch/riscv/isa/formats/mem.isa
index 2cb2f18..11b6c42 100644
--- a/src/arch/riscv/isa/formats/mem.isa
+++ b/src/arch/riscv/isa/formats/mem.isa
@@ -254,25 +254,17 @@
     }
 }};
 
-def format Load(memacc_code, ea_code={{EA = Rs1 + offset;}}, mem_flags=[],
-        inst_flags=[]) {{
-    offset_code = """
-                    offset = IMM12;
-                    if (IMMSIGN > 0)
-                        offset |= ~((uint64_t)0xFFF);
-                  """
+def format Load(memacc_code, ea_code = {{EA = Rs1 + offset;}},
+        offset_code={{offset = sext<12>(IMM12);}},
+        mem_flags=[], inst_flags=[]) {{
     (header_output, decoder_output, decode_block, exec_output) = \
         LoadStoreBase(name, Name, offset_code, ea_code, memacc_code, mem_flags,
         inst_flags, 'Load', exec_template_base='Load')
 }};
 
-def format Store(memacc_code, ea_code={{EA = Rs1 + offset;}}, mem_flags=[],
-        inst_flags=[]) {{
-    offset_code = """
-                    offset = IMM5 | (IMM7 << 5);
-                    if (IMMSIGN > 0)
-                        offset |= ~((uint64_t)0xFFF);
-                  """
+def format Store(memacc_code, ea_code={{EA = Rs1 + offset;}},
+        offset_code={{offset = sext<12>(IMM5 | (IMM7 << 5));}},
+        mem_flags=[], inst_flags=[]) {{
     (header_output, decoder_output, decode_block, exec_output) = \
         LoadStoreBase(name, Name, offset_code, ea_code, memacc_code, mem_flags,
         inst_flags, 'Store', exec_template_base='Store')
diff --git a/src/arch/riscv/isa/formats/standard.isa b/src/arch/riscv/isa/formats/standard.isa
index e68cedf..517313d 100644
--- a/src/arch/riscv/isa/formats/standard.isa
+++ b/src/arch/riscv/isa/formats/standard.isa
@@ -219,10 +219,9 @@
 }};
 
 def format IOp(code, *opt_flags) {{
-    imm_code = 'imm = IMM12; if (IMMSIGN > 0) imm |= ~((uint64_t)0x7FF);'
     regs = ['_destRegIdx[0]','_srcRegIdx[0]']
     iop = InstObjParams(name, Name, 'ImmOp<int64_t>',
-        {'code': code, 'imm_code': imm_code,
+        {'code': code, 'imm_code': 'imm = sext<12>(IMM12);',
          'regs': ','.join(regs)}, opt_flags)
     header_output = ImmDeclare.subst(iop)
     decoder_output = ImmConstructor.subst(iop)
@@ -232,11 +231,11 @@
 
 def format BOp(code, *opt_flags) {{
     imm_code = """
-                imm |= BIMM12BIT11 << 11;
-                imm |= BIMM12BITS4TO1 << 1;
-                imm |= BIMM12BITS10TO5 << 5;
-                if (IMMSIGN > 0)
-                    imm |= ~((uint64_t)0xFFF);
+                imm = BIMM12BITS4TO1 << 1  |
+                      BIMM12BITS10TO5 << 5 |
+                      BIMM12BIT11 << 11    |
+                      IMMSIGN << 12;
+                imm = sext<13>(imm);
                """
     regs = ['_srcRegIdx[0]','_srcRegIdx[1]']
     iop = InstObjParams(name, Name, 'ImmOp<int64_t>',
@@ -249,10 +248,9 @@
 }};
 
 def format Jump(code, *opt_flags) {{
-    imm_code = 'imm = IMM12; if (IMMSIGN > 0) imm |= ~((uint64_t)0x7FF);'
-    regs = ['_destRegIdx[0]','_srcRegIdx[0]']
+    regs = ['_destRegIdx[0]', '_srcRegIdx[0]']
     iop = InstObjParams(name, Name, 'ImmOp<int64_t>',
-        {'code': code, 'imm_code': imm_code,
+        {'code': code, 'imm_code': 'imm = sext<12>(IMM12);',
          'regs': ','.join(regs)}, opt_flags)
     header_output = JumpDeclare.subst(iop)
     decoder_output = ImmConstructor.subst(iop)
@@ -261,10 +259,9 @@
 }};
 
 def format UOp(code, *opt_flags) {{
-    imm_code = 'imm = (int32_t)(IMM20 << 12);'
     regs = ['_destRegIdx[0]']
     iop = InstObjParams(name, Name, 'ImmOp<int64_t>',
-        {'code': code, 'imm_code': imm_code,
+        {'code': code, 'imm_code': 'imm = sext<20>(IMM20) << 12;',
          'regs': ','.join(regs)}, opt_flags)
     header_output = ImmDeclare.subst(iop)
     decoder_output = ImmConstructor.subst(iop)
@@ -274,11 +271,11 @@
 
 def format JOp(code, *opt_flags) {{
     imm_code = """
-                imm |= UJIMMBITS19TO12 << 12;
-                imm |= UJIMMBIT11 << 11;
-                imm |= UJIMMBITS10TO1 << 1;
-                if (IMMSIGN > 0)
-                    imm |= ~((uint64_t)0xFFFFF);
+                imm = UJIMMBITS10TO1 << 1   |
+                      UJIMMBIT11 << 11      |
+                      UJIMMBITS19TO12 << 12 |
+                      IMMSIGN << 20;
+                imm = sext<21>(imm);
                """
     pc = 'pc.set(pc.pc() + imm);'
     regs = ['_destRegIdx[0]']