package graph

import (
	
	
	
)

// MinimumSpanningTree returns a minimum spanning tree within the given graph.
//
// The MST contains all vertices from the given graph as well as the required
// edges for building the MST. The original graph remains unchanged.
//
// An error is returned if the given graph is directed, or if reading from the
// given graph or writing to the resulting graph fails.
func [ comparable,  any]( Graph[, ]) (Graph[, ], error) {
	// false selects the ascending-weight edge order in spanningTree.
	return spanningTree(, false)
}

// MaximumSpanningTree returns a maximum spanning tree within the given graph.
//
// The resulting tree contains all vertices from the given graph as well as the
// required edges for building the maximum spanning tree. The original graph
// remains unchanged.
//
// An error is returned if the given graph is directed, or if reading from the
// given graph or writing to the resulting graph fails.
func [ comparable,  any]( Graph[, ]) (Graph[, ], error) {
	// true selects the descending-weight edge order in spanningTree.
	return spanningTree(, true)
}

// spanningTree builds a spanning tree of the given graph and returns it as a
// new graph created with NewLike. It is the shared implementation behind the
// minimum and maximum variants.
//
// The bool parameter chooses the edge order: when true, edges are considered
// from highest to lowest weight (maximum spanning tree); when false, from
// lowest to highest weight (minimum spanning tree).
//
// The procedure is Kruskal-style: collect all edges, sort them by weight, and
// accept an edge only if its endpoints are still in different union-find sets.
//
// It returns an error if the graph is directed, if the adjacency map or a
// vertex cannot be read, or if a vertex or edge cannot be added to the result.
//
// NOTE(review): no connectivity check is performed here; for a disconnected
// input the result is presumably a spanning forest - confirm against callers.
func spanningTree[ comparable,  any]( Graph[, ],  bool) (Graph[, ], error) {
	// Spanning trees are only computed for undirected graphs.
	if .Traits().IsDirected {
		return nil, errors.New("spanning trees can only be determined for undirected graphs")
	}

	,  := .AdjacencyMap()
	if  != nil {
		return nil, fmt.Errorf("failed to get adjacency map: %w", )
	}

	// Candidate edges (filled below) and the union-find structure used to
	// detect whether two vertices are already connected in the result.
	 := make([]Edge[], 0)
	 := newUnionFind[]()

	// The result graph is created via NewLike from the given graph.
	 := NewLike()

	// First pass: copy every vertex (with its properties) into the result,
	// register it in the union-find structure and collect its adjacent edges.
	for ,  := range  {
		, ,  := .VertexWithProperties() //nolint:govet
		if  != nil {
			return nil, fmt.Errorf("failed to get vertex %v: %w", , )
		}

		 = .AddVertex(, copyVertexProperties())
		if  != nil {
			return nil, fmt.Errorf("failed to add vertex %v: %w", , )
		}

		.add()

		// NOTE(review): for an undirected graph the adjacency map presumably
		// lists each edge once per endpoint; the duplicate would be rejected
		// by the union-find check in the second pass - confirm.
		for ,  := range  {
			 = append(, )
		}
	}

	// Order the candidate edges by weight: descending for a maximum spanning
	// tree, ascending for a minimum spanning tree. sort.Slice is not stable,
	// so the relative order of equal-weight edges is unspecified.
	if  {
		sort.Slice(, func(,  int) bool {
			return [].Properties.Weight > [].Properties.Weight
		})
	} else {
		sort.Slice(, func(,  int) bool {
			return [].Properties.Weight < [].Properties.Weight
		})
	}

	// Second pass: walk the sorted edges and keep only those whose endpoints
	// belong to different sets, i.e. edges that do not close a cycle.
	for ,  := range  {
		 := .find(.Source)
		 := .find(.Target)

		if  !=  {
			// Merge the two sets before adding a copy of the edge to the result.
			.union(, )

			if  = .AddEdge(copyEdge());  != nil {
				return nil, fmt.Errorf("failed to add edge (%v, %v): %w", .Source, .Target, )
			}
		}
	}

	return , nil
}