// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package bpf

import "fmt"

// An Instruction is one instruction executed by the BPF virtual
// machine.
type Instruction interface {
	// Assemble assembles the Instruction into a RawInstruction.
	Assemble() (RawInstruction, error)
}

// A RawInstruction is a raw BPF virtual machine instruction, i.e. the
// already-encoded form every Instruction assembles into.
type RawInstruction struct {
	// Operation to execute.
	Op uint16
	// For conditional jump instructions, the number of instructions
	// to skip if the condition is true/false.
	Jt uint8
	Jf uint8
	// Constant parameter. The meaning depends on the Op.
	K uint32
}

// Assemble implements the Instruction Assemble method. A RawInstruction
// is already in encoded form, so it is returned unchanged and assembly
// never fails.
func (raw RawInstruction) Assemble() (RawInstruction, error) {
	return raw, nil
}

// Disassemble parses ri into an Instruction and returns it. If ri is
// not recognized by this package, ri itself is returned.
func ( RawInstruction) () Instruction {
	switch .Op & opMaskCls {
	case opClsLoadA, opClsLoadX:
		 := Register(.Op & opMaskLoadDest)
		 := 0
		switch .Op & opMaskLoadWidth {
		case opLoadWidth4:
			 = 4
		case opLoadWidth2:
			 = 2
		case opLoadWidth1:
			 = 1
		default:
			return 
		}
		switch .Op & opMaskLoadMode {
		case opAddrModeImmediate:
			if  != 4 {
				return 
			}
			return LoadConstant{Dst: , Val: .K}
		case opAddrModeScratch:
			if  != 4 || .K > 15 {
				return 
			}
			return LoadScratch{Dst: , N: int(.K)}
		case opAddrModeAbsolute:
			if .K > extOffset+0xffffffff {
				return LoadExtension{Num: Extension(-extOffset + .K)}
			}
			return LoadAbsolute{Size: , Off: .K}
		case opAddrModeIndirect:
			return LoadIndirect{Size: , Off: .K}
		case opAddrModePacketLen:
			if  != 4 {
				return 
			}
			return LoadExtension{Num: ExtLen}
		case opAddrModeMemShift:
			return LoadMemShift{Off: .K}
		default:
			return 
		}

	case opClsStoreA:
		if .Op != opClsStoreA || .K > 15 {
			return 
		}
		return StoreScratch{Src: RegA, N: int(.K)}

	case opClsStoreX:
		if .Op != opClsStoreX || .K > 15 {
			return 
		}
		return StoreScratch{Src: RegX, N: int(.K)}

	case opClsALU:
		switch  := ALUOp(.Op & opMaskOperator);  {
		case ALUOpAdd, ALUOpSub, ALUOpMul, ALUOpDiv, ALUOpOr, ALUOpAnd, ALUOpShiftLeft, ALUOpShiftRight, ALUOpMod, ALUOpXor:
			switch  := opOperand(.Op & opMaskOperand);  {
			case opOperandX:
				return ALUOpX{Op: }
			case opOperandConstant:
				return ALUOpConstant{Op: , Val: .K}
			default:
				return 
			}
		case aluOpNeg:
			return NegateA{}
		default:
			return 
		}

	case opClsJump:
		switch  := jumpOp(.Op & opMaskOperator);  {
		case opJumpAlways:
			return Jump{Skip: .K}
		case opJumpEqual, opJumpGT, opJumpGE, opJumpSet:
			, ,  := jumpOpToTest(, .Jt, .Jf)
			switch  := opOperand(.Op & opMaskOperand);  {
			case opOperandX:
				return JumpIfX{Cond: , SkipTrue: , SkipFalse: }
			case opOperandConstant:
				return JumpIf{Cond: , Val: .K, SkipTrue: , SkipFalse: }
			default:
				return 
			}
		default:
			return 
		}

	case opClsReturn:
		switch .Op {
		case opClsReturn | opRetSrcA:
			return RetA{}
		case opClsReturn | opRetSrcConstant:
			return RetConstant{Val: .K}
		default:
			return 
		}

	case opClsMisc:
		switch .Op {
		case opClsMisc | opMiscTAX:
			return TAX{}
		case opClsMisc | opMiscTXA:
			return TXA{}
		default:
			return 
		}

	default:
		panic("unreachable") // switch is exhaustive on the bit pattern
	}
}

func jumpOpToTest( jumpOp,  uint8,  uint8) (JumpTest, uint8, uint8) {
	var  JumpTest

	// Decode "fake" jump conditions that don't appear in machine code
	// Ensures the Assemble -> Disassemble stage recreates the same instructions
	// See https://github.com/golang/go/issues/18470
	if  == 0 {
		switch  {
		case opJumpEqual:
			 = JumpNotEqual
		case opJumpGT:
			 = JumpLessOrEqual
		case opJumpGE:
			 = JumpLessThan
		case opJumpSet:
			 = JumpBitsNotSet
		}

		return , , 0
	}

	switch  {
	case opJumpEqual:
		 = JumpEqual
	case opJumpGT:
		 = JumpGreaterThan
	case opJumpGE:
		 = JumpGreaterOrEqual
	case opJumpSet:
		 = JumpBitsSet
	}

	return , , 
}

// LoadConstant loads Val into register Dst.
type LoadConstant struct {
	Dst Register
	Val uint32
}

// Assemble implements the Instruction Assemble method.
func ( LoadConstant) () (RawInstruction, error) {
	return assembleLoad(.Dst, 4, opAddrModeImmediate, .Val)
}

// String returns the instruction in assembler notation.
func ( LoadConstant) () string {
	switch .Dst {
	case RegA:
		return fmt.Sprintf("ld #%d", .Val)
	case RegX:
		return fmt.Sprintf("ldx #%d", .Val)
	default:
		return fmt.Sprintf("unknown instruction: %#v", )
	}
}

// LoadScratch loads scratch[N] into register Dst.
type LoadScratch struct {
	Dst Register
	N   int // 0-15
}

// Assemble implements the Instruction Assemble method.
func ( LoadScratch) () (RawInstruction, error) {
	if .N < 0 || .N > 15 {
		return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", .N)
	}
	return assembleLoad(.Dst, 4, opAddrModeScratch, uint32(.N))
}

// String returns the instruction in assembler notation.
func ( LoadScratch) () string {
	switch .Dst {
	case RegA:
		return fmt.Sprintf("ld M[%d]", .N)
	case RegX:
		return fmt.Sprintf("ldx M[%d]", .N)
	default:
		return fmt.Sprintf("unknown instruction: %#v", )
	}
}

// LoadAbsolute loads packet[Off:Off+Size] as an integer value into
// register A.
type LoadAbsolute struct {
	Off  uint32
	Size int // 1, 2 or 4
}

// Assemble implements the Instruction Assemble method.
func ( LoadAbsolute) () (RawInstruction, error) {
	return assembleLoad(RegA, .Size, opAddrModeAbsolute, .Off)
}

// String returns the instruction in assembler notation.
func ( LoadAbsolute) () string {
	switch .Size {
	case 1: // byte
		return fmt.Sprintf("ldb [%d]", .Off)
	case 2: // half word
		return fmt.Sprintf("ldh [%d]", .Off)
	case 4: // word
		if .Off > extOffset+0xffffffff {
			return LoadExtension{Num: Extension(.Off + 0x1000)}.String()
		}
		return fmt.Sprintf("ld [%d]", .Off)
	default:
		return fmt.Sprintf("unknown instruction: %#v", )
	}
}

// LoadIndirect loads packet[X+Off:X+Off+Size] as an integer value
// into register A.
type LoadIndirect struct {
	Off  uint32
	Size int // 1, 2 or 4
}

// Assemble implements the Instruction Assemble method.
func ( LoadIndirect) () (RawInstruction, error) {
	return assembleLoad(RegA, .Size, opAddrModeIndirect, .Off)
}

// String returns the instruction in assembler notation.
func ( LoadIndirect) () string {
	switch .Size {
	case 1: // byte
		return fmt.Sprintf("ldb [x + %d]", .Off)
	case 2: // half word
		return fmt.Sprintf("ldh [x + %d]", .Off)
	case 4: // word
		return fmt.Sprintf("ld [x + %d]", .Off)
	default:
		return fmt.Sprintf("unknown instruction: %#v", )
	}
}

// LoadMemShift multiplies the first 4 bits of the byte at packet[Off]
// by 4 and stores the result in register X.
//
// This instruction is mainly useful to load into X the length of an
// IPv4 packet header in a single instruction, rather than have to do
// the arithmetic on the header's first byte by hand.
type LoadMemShift struct {
	Off uint32
}

// Assemble implements the Instruction Assemble method.
func ( LoadMemShift) () (RawInstruction, error) {
	return assembleLoad(RegX, 1, opAddrModeMemShift, .Off)
}

// String returns the instruction in assembler notation.
func ( LoadMemShift) () string {
	return fmt.Sprintf("ldx 4*([%d]&0xf)", .Off)
}

// LoadExtension invokes a linux-specific extension and stores the
// result in register A.
type LoadExtension struct {
	Num Extension
}

// Assemble implements the Instruction Assemble method.
func ( LoadExtension) () (RawInstruction, error) {
	if .Num == ExtLen {
		return assembleLoad(RegA, 4, opAddrModePacketLen, 0)
	}
	return assembleLoad(RegA, 4, opAddrModeAbsolute, uint32(extOffset+.Num))
}

// String returns the instruction in assembler notation.
func ( LoadExtension) () string {
	switch .Num {
	case ExtLen:
		return "ld #len"
	case ExtProto:
		return "ld #proto"
	case ExtType:
		return "ld #type"
	case ExtPayloadOffset:
		return "ld #poff"
	case ExtInterfaceIndex:
		return "ld #ifidx"
	case ExtNetlinkAttr:
		return "ld #nla"
	case ExtNetlinkAttrNested:
		return "ld #nlan"
	case ExtMark:
		return "ld #mark"
	case ExtQueue:
		return "ld #queue"
	case ExtLinkLayerType:
		return "ld #hatype"
	case ExtRXHash:
		return "ld #rxhash"
	case ExtCPUID:
		return "ld #cpu"
	case ExtVLANTag:
		return "ld #vlan_tci"
	case ExtVLANTagPresent:
		return "ld #vlan_avail"
	case ExtVLANProto:
		return "ld #vlan_tpid"
	case ExtRand:
		return "ld #rand"
	default:
		return fmt.Sprintf("unknown instruction: %#v", )
	}
}

// StoreScratch stores register Src into scratch[N].
type StoreScratch struct {
	Src Register
	N   int // 0-15
}

// Assemble implements the Instruction Assemble method.
func ( StoreScratch) () (RawInstruction, error) {
	if .N < 0 || .N > 15 {
		return RawInstruction{}, fmt.Errorf("invalid scratch slot %d", .N)
	}
	var  uint16
	switch .Src {
	case RegA:
		 = opClsStoreA
	case RegX:
		 = opClsStoreX
	default:
		return RawInstruction{}, fmt.Errorf("invalid source register %v", .Src)
	}

	return RawInstruction{
		Op: ,
		K:  uint32(.N),
	}, nil
}

// String returns the instruction in assembler notation.
func ( StoreScratch) () string {
	switch .Src {
	case RegA:
		return fmt.Sprintf("st M[%d]", .N)
	case RegX:
		return fmt.Sprintf("stx M[%d]", .N)
	default:
		return fmt.Sprintf("unknown instruction: %#v", )
	}
}

// ALUOpConstant executes A = A <Op> Val.
type ALUOpConstant struct {
	Op  ALUOp
	Val uint32
}

// Assemble implements the Instruction Assemble method.
func ( ALUOpConstant) () (RawInstruction, error) {
	return RawInstruction{
		Op: opClsALU | uint16(opOperandConstant) | uint16(.Op),
		K:  .Val,
	}, nil
}

// String returns the instruction in assembler notation.
func ( ALUOpConstant) () string {
	switch .Op {
	case ALUOpAdd:
		return fmt.Sprintf("add #%d", .Val)
	case ALUOpSub:
		return fmt.Sprintf("sub #%d", .Val)
	case ALUOpMul:
		return fmt.Sprintf("mul #%d", .Val)
	case ALUOpDiv:
		return fmt.Sprintf("div #%d", .Val)
	case ALUOpMod:
		return fmt.Sprintf("mod #%d", .Val)
	case ALUOpAnd:
		return fmt.Sprintf("and #%d", .Val)
	case ALUOpOr:
		return fmt.Sprintf("or #%d", .Val)
	case ALUOpXor:
		return fmt.Sprintf("xor #%d", .Val)
	case ALUOpShiftLeft:
		return fmt.Sprintf("lsh #%d", .Val)
	case ALUOpShiftRight:
		return fmt.Sprintf("rsh #%d", .Val)
	default:
		return fmt.Sprintf("unknown instruction: %#v", )
	}
}

// ALUOpX executes A = A <Op> X
type ALUOpX struct {
	Op ALUOp
}

// Assemble implements the Instruction Assemble method.
func ( ALUOpX) () (RawInstruction, error) {
	return RawInstruction{
		Op: opClsALU | uint16(opOperandX) | uint16(.Op),
	}, nil
}

// String returns the instruction in assembler notation.
func ( ALUOpX) () string {
	switch .Op {
	case ALUOpAdd:
		return "add x"
	case ALUOpSub:
		return "sub x"
	case ALUOpMul:
		return "mul x"
	case ALUOpDiv:
		return "div x"
	case ALUOpMod:
		return "mod x"
	case ALUOpAnd:
		return "and x"
	case ALUOpOr:
		return "or x"
	case ALUOpXor:
		return "xor x"
	case ALUOpShiftLeft:
		return "lsh x"
	case ALUOpShiftRight:
		return "rsh x"
	default:
		return fmt.Sprintf("unknown instruction: %#v", )
	}
}

// NegateA executes A = -A.
type NegateA struct{}

// Assemble implements the Instruction Assemble method.
func ( NegateA) () (RawInstruction, error) {
	return RawInstruction{
		Op: opClsALU | uint16(aluOpNeg),
	}, nil
}

// String returns the instruction in assembler notation.
func ( NegateA) () string {
	return fmt.Sprintf("neg")
}

// Jump skips the following Skip instructions in the program.
type Jump struct {
	Skip uint32
}

// Assemble implements the Instruction Assemble method.
func ( Jump) () (RawInstruction, error) {
	return RawInstruction{
		Op: opClsJump | uint16(opJumpAlways),
		K:  .Skip,
	}, nil
}

// String returns the instruction in assembler notation.
func ( Jump) () string {
	return fmt.Sprintf("ja %d", .Skip)
}

// JumpIf skips the following Skip instructions in the program if A
// <Cond> Val is true.
type JumpIf struct {
	Cond      JumpTest
	Val       uint32
	SkipTrue  uint8
	SkipFalse uint8
}

// Assemble implements the Instruction Assemble method.
func ( JumpIf) () (RawInstruction, error) {
	return jumpToRaw(.Cond, opOperandConstant, .Val, .SkipTrue, .SkipFalse)
}

// String returns the instruction in assembler notation.
func ( JumpIf) () string {
	return jumpToString(.Cond, fmt.Sprintf("#%d", .Val), .SkipTrue, .SkipFalse)
}

// JumpIfX skips the following Skip instructions in the program if A
// <Cond> X is true.
type JumpIfX struct {
	Cond      JumpTest
	SkipTrue  uint8
	SkipFalse uint8
}

// Assemble implements the Instruction Assemble method.
func ( JumpIfX) () (RawInstruction, error) {
	return jumpToRaw(.Cond, opOperandX, 0, .SkipTrue, .SkipFalse)
}

// String returns the instruction in assembler notation.
func ( JumpIfX) () string {
	return jumpToString(.Cond, "x", .SkipTrue, .SkipFalse)
}

// jumpToRaw assembles a jump instruction into a RawInstruction
func jumpToRaw( JumpTest,  opOperand,  uint32, ,  uint8) (RawInstruction, error) {
	var (
		 jumpOp
		 bool
	)
	switch  {
	case JumpEqual:
		 = opJumpEqual
	case JumpNotEqual:
		,  = opJumpEqual, true
	case JumpGreaterThan:
		 = opJumpGT
	case JumpLessThan:
		,  = opJumpGE, true
	case JumpGreaterOrEqual:
		 = opJumpGE
	case JumpLessOrEqual:
		,  = opJumpGT, true
	case JumpBitsSet:
		 = opJumpSet
	case JumpBitsNotSet:
		,  = opJumpSet, true
	default:
		return RawInstruction{}, fmt.Errorf("unknown JumpTest %v", )
	}
	,  := , 
	if  {
		,  = , 
	}
	return RawInstruction{
		Op: opClsJump | uint16() | uint16(),
		Jt: ,
		Jf: ,
		K:  ,
	}, nil
}

// jumpToString converts a jump instruction to assembler notation
func jumpToString( JumpTest,  string, ,  uint8) string {
	switch  {
	// K == A
	case JumpEqual:
		return conditionalJump(, , , "jeq", "jneq")
	// K != A
	case JumpNotEqual:
		return fmt.Sprintf("jneq %s,%d", , )
	// K > A
	case JumpGreaterThan:
		return conditionalJump(, , , "jgt", "jle")
	// K < A
	case JumpLessThan:
		return fmt.Sprintf("jlt %s,%d", , )
	// K >= A
	case JumpGreaterOrEqual:
		return conditionalJump(, , , "jge", "jlt")
	// K <= A
	case JumpLessOrEqual:
		return fmt.Sprintf("jle %s,%d", , )
	// K & A != 0
	case JumpBitsSet:
		if  > 0 {
			return fmt.Sprintf("jset %s,%d,%d", , , )
		}
		return fmt.Sprintf("jset %s,%d", , )
	// K & A == 0, there is no assembler instruction for JumpBitNotSet, use JumpBitSet and invert skips
	case JumpBitsNotSet:
		return (JumpBitsSet, , , )
	default:
		return fmt.Sprintf("unknown JumpTest %#v", )
	}
}

// conditionalJump formats a jump that has both a positive and a negated
// mnemonic. If skipTrue is zero, the negated mnemonic is printed with
// skipFalse as its only target. Otherwise the positive mnemonic is
// printed with skipTrue, followed by skipFalse when it is non-zero.
func conditionalJump(operand string, skipTrue, skipFalse uint8, positiveJump, negativeJump string) string {
	switch {
	case skipTrue == 0:
		return fmt.Sprintf("%s %s,%d", negativeJump, operand, skipFalse)
	case skipFalse == 0:
		return fmt.Sprintf("%s %s,%d", positiveJump, operand, skipTrue)
	default:
		return fmt.Sprintf("%s %s,%d,%d", positiveJump, operand, skipTrue, skipFalse)
	}
}

// RetA exits the BPF program, returning the value of register A.
type RetA struct{}

// Assemble implements the Instruction Assemble method.
func ( RetA) () (RawInstruction, error) {
	return RawInstruction{
		Op: opClsReturn | opRetSrcA,
	}, nil
}

// String returns the instruction in assembler notation.
func ( RetA) () string {
	return fmt.Sprintf("ret a")
}

// RetConstant exits the BPF program, returning a constant value.
type RetConstant struct {
	Val uint32
}

// Assemble implements the Instruction Assemble method.
func ( RetConstant) () (RawInstruction, error) {
	return RawInstruction{
		Op: opClsReturn | opRetSrcConstant,
		K:  .Val,
	}, nil
}

// String returns the instruction in assembler notation.
func ( RetConstant) () string {
	return fmt.Sprintf("ret #%d", .Val)
}

// TXA copies the value of register X to register A.
type TXA struct{}

// Assemble implements the Instruction Assemble method.
func ( TXA) () (RawInstruction, error) {
	return RawInstruction{
		Op: opClsMisc | opMiscTXA,
	}, nil
}

// String returns the instruction in assembler notation.
func ( TXA) () string {
	return fmt.Sprintf("txa")
}

// TAX copies the value of register A to register X.
type TAX struct{}

// Assemble implements the Instruction Assemble method.
func ( TAX) () (RawInstruction, error) {
	return RawInstruction{
		Op: opClsMisc | opMiscTAX,
	}, nil
}

// String returns the instruction in assembler notation.
func ( TAX) () string {
	return fmt.Sprintf("tax")
}

func assembleLoad( Register,  int,  uint16,  uint32) (RawInstruction, error) {
	var (
		 uint16
		  uint16
	)
	switch  {
	case RegA:
		 = opClsLoadA
	case RegX:
		 = opClsLoadX
	default:
		return RawInstruction{}, fmt.Errorf("invalid target register %v", )
	}
	switch  {
	case 1:
		 = opLoadWidth1
	case 2:
		 = opLoadWidth2
	case 4:
		 = opLoadWidth4
	default:
		return RawInstruction{}, fmt.Errorf("invalid load byte length %d", )
	}
	return RawInstruction{
		Op:  |  | ,
		K:  ,
	}, nil
}