// Copyright (c) 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package edwards25519

// This file contains additional functionality that is not included in the
// upstream crypto/internal/edwards25519 package.

import (
	

	
)

// ExtendedCoordinates returns v in extended coordinates (X:Y:Z:T) where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
func ( *Point) () (, , ,  *field.Element) {
	// This function is outlined to make the allocations inline in the caller
	// rather than happen on the heap. Don't change the style without making
	// sure it doesn't increase the inliner cost.
	var  [4]field.Element
	, , ,  = .extendedCoordinates(&)
	return
}

func ( *Point) ( *[4]field.Element) (, , ,  *field.Element) {
	checkInitialized()
	 = [0].Set(&.x)
	 = [1].Set(&.y)
	 = [2].Set(&.z)
	 = [3].Set(&.t)
	return
}

// SetExtendedCoordinates sets v = (X:Y:Z:T) in extended coordinates where
// x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
//
// If the coordinates are invalid or don't represent a valid point on the curve,
// SetExtendedCoordinates returns nil and an error and the receiver is
// unchanged. Otherwise, SetExtendedCoordinates returns v.
func ( *Point) (, , ,  *field.Element) (*Point, error) {
	if !isOnCurve(, , , ) {
		return nil, errors.New("edwards25519: invalid point coordinates")
	}
	.x.Set()
	.y.Set()
	.z.Set()
	.t.Set()
	return , nil
}

func isOnCurve(, , ,  *field.Element) bool {
	var ,  field.Element
	 := new(field.Element).Square()
	 := new(field.Element).Square()
	 := new(field.Element).Square()
	 := new(field.Element).Square()
	// -x² + y² = 1 + dx²y²
	// -(X/Z)² + (Y/Z)² = 1 + d(T/Z)²
	// -X² + Y² = Z² + dT²
	.Subtract(, )
	.Multiply(d, ).Add(&, )
	if .Equal(&) != 1 {
		return false
	}
	// xy = T/Z
	// XY/Z² = T/Z
	// XY = TZ
	.Multiply(, )
	.Multiply(, )
	return .Equal(&) == 1
}

// BytesMontgomery converts v to a point on the birationally-equivalent
// Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
// according to RFC 7748.
//
// Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
// to the same value. If v is the identity point, BytesMontgomery returns 32
// zero bytes, analogously to the X25519 function.
//
// The lack of an inverse operation (such as SetMontgomeryBytes) is deliberate:
// while every valid edwards25519 point has a unique u-coordinate Montgomery
// encoding, X25519 accepts inputs on the quadratic twist, which don't correspond
// to any edwards25519 point, and every other X25519 input corresponds to two
// edwards25519 points.
func ( *Point) () []byte {
	// This function is outlined to make the allocations inline in the caller
	// rather than happen on the heap.
	var  [32]byte
	return .bytesMontgomery(&)
}

func ( *Point) ( *[32]byte) []byte {
	checkInitialized()

	// RFC 7748, Section 4.1 provides the bilinear map to calculate the
	// Montgomery u-coordinate
	//
	//              u = (1 + y) / (1 - y)
	//
	// where y = Y / Z.

	var , ,  field.Element

	.Multiply(&.y, .Invert(&.z))        // y = Y / Z
	.Invert(.Subtract(feOne, &)) // r = 1/(1 - y)
	.Multiply(.Add(feOne, &), &)    // u = (1 + y)*r

	return copyFieldElement(, &)
}

// MultByCofactor sets v = 8 * p, and returns v.
func ( *Point) ( *Point) *Point {
	checkInitialized()
	 := projP1xP1{}
	 := (&projP2{}).FromP3()
	.Double()
	.FromP1xP1(&)
	.Double()
	.FromP1xP1(&)
	.Double()
	return .fromP1xP1(&)
}

// Given k > 0, set s = s**(2*i).
func ( *Scalar) ( int) {
	for  := 0;  < ; ++ {
		.Multiply(, )
	}
}

// Invert sets s to the inverse of a nonzero scalar v, and returns s.
//
// If t is zero, Invert returns zero.
func ( *Scalar) ( *Scalar) *Scalar {
	// Uses a hardcoded sliding window of width 4.
	var  [8]Scalar
	var  Scalar
	.Multiply(, )
	[0] = *
	for  := 0;  < 7; ++ {
		[+1].Multiply(&[], &)
	}
	// Now table = [t**1, t**3, t**5, t**7, t**9, t**11, t**13, t**15]
	// so t**k = t[k/2] for odd k

	// To compute the sliding window digits, use the following Sage script:

	// sage: import itertools
	// sage: def sliding_window(w,k):
	// ....:     digits = []
	// ....:     while k > 0:
	// ....:         if k % 2 == 1:
	// ....:             kmod = k % (2**w)
	// ....:             digits.append(kmod)
	// ....:             k = k - kmod
	// ....:         else:
	// ....:             digits.append(0)
	// ....:         k = k // 2
	// ....:     return digits

	// Now we can compute s roughly as follows:

	// sage: s = 1
	// sage: for coeff in reversed(sliding_window(4,l-2)):
	// ....:     s = s*s
	// ....:     if coeff > 0 :
	// ....:         s = s*t**coeff

	// This works on one bit at a time, with many runs of zeros.
	// The digits can be collapsed into [(count, coeff)] as follows:

	// sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]

	// Entries of the form (k, 0) turn into pow2k(k)
	// Entries of the form (1, coeff) turn into a squaring and then a table lookup.
	// We can fold the squaring into the previous pow2k(k) as pow2k(k+1).

	* = [1/2]
	.pow2k(127 + 1)
	.Multiply(, &[1/2])
	.pow2k(4 + 1)
	.Multiply(, &[9/2])
	.pow2k(3 + 1)
	.Multiply(, &[11/2])
	.pow2k(3 + 1)
	.Multiply(, &[13/2])
	.pow2k(3 + 1)
	.Multiply(, &[15/2])
	.pow2k(4 + 1)
	.Multiply(, &[7/2])
	.pow2k(4 + 1)
	.Multiply(, &[15/2])
	.pow2k(3 + 1)
	.Multiply(, &[5/2])
	.pow2k(3 + 1)
	.Multiply(, &[1/2])
	.pow2k(4 + 1)
	.Multiply(, &[15/2])
	.pow2k(4 + 1)
	.Multiply(, &[15/2])
	.pow2k(4 + 1)
	.Multiply(, &[7/2])
	.pow2k(3 + 1)
	.Multiply(, &[3/2])
	.pow2k(4 + 1)
	.Multiply(, &[11/2])
	.pow2k(5 + 1)
	.Multiply(, &[11/2])
	.pow2k(9 + 1)
	.Multiply(, &[9/2])
	.pow2k(3 + 1)
	.Multiply(, &[3/2])
	.pow2k(4 + 1)
	.Multiply(, &[3/2])
	.pow2k(4 + 1)
	.Multiply(, &[3/2])
	.pow2k(4 + 1)
	.Multiply(, &[9/2])
	.pow2k(3 + 1)
	.Multiply(, &[7/2])
	.pow2k(3 + 1)
	.Multiply(, &[3/2])
	.pow2k(3 + 1)
	.Multiply(, &[13/2])
	.pow2k(3 + 1)
	.Multiply(, &[7/2])
	.pow2k(4 + 1)
	.Multiply(, &[9/2])
	.pow2k(3 + 1)
	.Multiply(, &[15/2])
	.pow2k(4 + 1)
	.Multiply(, &[11/2])

	return 
}

// MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends only on the lengths of the two slices, which must match.
func ( *Point) ( []*Scalar,  []*Point) *Point {
	if len() != len() {
		panic("edwards25519: called MultiScalarMult with different size inputs")
	}
	checkInitialized(...)

	// Proceed as in the single-base case, but share doublings
	// between each point in the multiscalar equation.

	// Build lookup tables for each point
	 := make([]projLookupTable, len())
	for  := range  {
		[].FromP3([])
	}
	// Compute signed radix-16 digits for each scalar
	 := make([][64]int8, len())
	for  := range  {
		[] = [].signedRadix16()
	}

	// Unwrap first loop iteration to save computing 16*identity
	 := &projCached{}
	 := &projP1xP1{}
	 := &projP2{}
	// Lookup-and-add the appropriate multiple of each input point
	for  := range  {
		[].SelectInto(, [][63])
		.Add(, ) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
		.fromP1xP1()     // update v
	}
	.FromP3() // set up tmp2 = v in P2 coords for next iteration
	for  := 62;  >= 0; -- {
		.Double()    // tmp1 =  2*(prev) in P1xP1 coords
		.FromP1xP1() // tmp2 =  2*(prev) in P2 coords
		.Double()    // tmp1 =  4*(prev) in P1xP1 coords
		.FromP1xP1() // tmp2 =  4*(prev) in P2 coords
		.Double()    // tmp1 =  8*(prev) in P1xP1 coords
		.FromP1xP1() // tmp2 =  8*(prev) in P2 coords
		.Double()    // tmp1 = 16*(prev) in P1xP1 coords
		.fromP1xP1()    //    v = 16*(prev) in P3 coords
		// Lookup-and-add the appropriate multiple of each input point
		for  := range  {
			[].SelectInto(, [][])
			.Add(, ) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
			.fromP1xP1()     // update v
		}
		.FromP3() // set up tmp2 = v in P2 coords for next iteration
	}
	return 
}

// VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
//
// Execution time depends on the inputs.
func ( *Point) ( []*Scalar,  []*Point) *Point {
	if len() != len() {
		panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
	}
	checkInitialized(...)

	// Generalize double-base NAF computation to arbitrary sizes.
	// Here all the points are dynamic, so we only use the smaller
	// tables.

	// Build lookup tables for each point
	 := make([]nafLookupTable5, len())
	for  := range  {
		[].FromP3([])
	}
	// Compute a NAF for each scalar
	 := make([][256]int8, len())
	for  := range  {
		[] = [].nonAdjacentForm(5)
	}

	 := &projCached{}
	 := &projP1xP1{}
	 := &projP2{}
	.Zero()

	// Move from high to low bits, doubling the accumulator
	// at each iteration and checking whether there is a nonzero
	// coefficient to look up a multiple of.
	//
	// Skip trying to find the first nonzero coefficent, because
	// searching might be more work than a few extra doublings.
	for  := 255;  >= 0; -- {
		.Double()

		for  := range  {
			if [][] > 0 {
				.fromP1xP1()
				[].SelectInto(, [][])
				.Add(, )
			} else if [][] < 0 {
				.fromP1xP1()
				[].SelectInto(, -[][])
				.Sub(, )
			}
		}

		.FromP1xP1()
	}

	.fromP2()
	return 
}