package boil

import (
	"bytes"
	"fmt"
	"reflect"
	"sort"
	"strings"
	"unicode"

	"github.com/nullbio/sqlboiler/strmangle"
)

// SetComplement returns the elements of a that do not appear in b (a minus b).
// The order of a is preserved and duplicates within a are kept. The result
// is never nil; it is an empty slice when nothing survives.
func SetComplement(a []string, b []string) []string {
	exclude := make(map[string]struct{}, len(b))
	for _, s := range b {
		exclude[s] = struct{}{}
	}

	out := make([]string, 0, len(a))
	for _, s := range a {
		if _, skip := exclude[s]; skip {
			continue
		}
		out = append(out, s)
	}

	return out
}

// SetIntersect returns the elements of a that also appear in b, in the
// order they occur in a. Duplicates within a are kept. The result is never
// nil; it is an empty slice when there is no overlap.
func SetIntersect(a []string, b []string) []string {
	inB := make(map[string]bool, len(b))
	for _, s := range b {
		inB[s] = true
	}

	common := make([]string, 0, len(a))
	for _, s := range a {
		if inB[s] {
			common = append(common, s)
		}
	}

	return common
}

// SetMerge returns the union of a and b without duplicates. Elements keep
// their first-seen order, scanning all of a and then all of b. When both
// inputs are empty the result is nil.
func SetMerge(a []string, b []string) []string {
	// Left nil (not pre-sized) so an empty union stays nil, as callers may
	// have observed before.
	var merged []string
	seen := make(map[string]struct{}, len(a)+len(b))

	// Walk both inputs in place rather than concatenating them into a
	// temporary slice first.
	for _, list := range [][]string{a, b} {
		for _, v := range list {
			if _, dup := seen[v]; dup {
				continue
			}
			seen[v] = struct{}{}
			merged = append(merged, v)
		}
	}

	return merged
}

// NonZeroDefaultSet filters defaults down to the column names whose matching
// struct field in obj does not hold its type's zero value.
//
// obj must be a struct or a pointer to one. Each column name is mapped to a
// field name with strmangle.TitleCase, and the field is compared against
// reflect.Zero of its type using reflect.DeepEqual. If no field with that
// name exists, the function panics. The result keeps the order of defaults.
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
	structVal := reflect.Indirect(reflect.ValueOf(obj))
	nonZero := make([]string, 0, len(defaults))

	for _, col := range defaults {
		name := strmangle.TitleCase(col)
		fv := structVal.FieldByName(name)
		if !fv.IsValid() {
			panic(fmt.Sprintf("Could not find field name %s in type %T", name, obj))
		}

		zeroVal := reflect.Zero(fv.Type()).Interface()
		if reflect.DeepEqual(zeroVal, fv.Interface()) {
			continue
		}
		nonZero = append(nonZero, col)
	}

	return nonZero
}

// SortByKeys returns a new slice of length len(strs) whose leading entries
// are the keys that occur in strs, in the order they appear in keys.
//
// A key is emitted once for every time it appears in keys, provided it is a
// member of strs, and emission stops once the result is full. Members of strs
// that never appear in keys are not emitted, so their slots at the tail of the
// result are left as empty strings.
//
// NOTE(review): that "" tail padding could look like real column names to
// callers; confirm every caller passes strs that is a subset of keys.
func SortByKeys(keys []string, strs []string) []string {
	present := make(map[string]bool, len(strs))
	for _, s := range strs {
		present[s] = true
	}

	sorted := make([]string, len(strs))
	n := 0
	for _, key := range keys {
		if n == len(sorted) {
			break
		}
		if !present[key] {
			continue
		}
		sorted[n] = key
		n++
	}

	return sorted
}

// GenerateParamFlags returns the positional SQL parameter placeholders for
// colCount columns, for example $1,$2,$3, numbering from startAt. It is a
// pass-through to strmangle.GenerateParamFlags.
func GenerateParamFlags(colCount int, startAt int) string {
	flags := strmangle.GenerateParamFlags(colCount, startAt)
	return flags
}

// WherePrimaryKeyIn builds an IN expression matching numRows rows by the
// given key columns, numbering placeholders from $1 row by row.
//
// With a composite key each row's placeholders are wrapped in parentheses:
//
//	("col1","col2") IN (($1,$2),($3,$4))
//
// With a single key they are listed bare: ("col1") IN ($1,$2).
// It returns "" when no key names are given. When numRows <= 0 the value
// list is empty, e.g. ("col1") IN ().
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
	if len(keyNames) == 0 {
		return ""
	}

	var buf bytes.Buffer

	// Quoted column list.
	buf.WriteByte('(')
	for i, name := range keyNames {
		if i > 0 {
			buf.WriteByte(',')
		}
		buf.WriteByte('"')
		buf.WriteString(name)
		buf.WriteByte('"')
	}
	buf.WriteString(") IN (")

	// Only composite keys need a tuple per row.
	composite := len(keyNames) > 1
	placeholder := 1
	for row := 0; row < numRows; row++ {
		if row > 0 {
			buf.WriteByte(',')
		}
		if composite {
			buf.WriteByte('(')
		}
		for col := range keyNames {
			if col > 0 {
				buf.WriteByte(',')
			}
			fmt.Fprintf(&buf, "$%d", placeholder)
			placeholder++
		}
		if composite {
			buf.WriteByte(')')
		}
	}
	buf.WriteByte(')')

	return buf.String()
}

// SelectNames returns the quoted, comma-separated column list for a SELECT
// statement built from every field of results, e.g. "col1", "col2", "col3".
//
// results must be a struct or a pointer to a struct; reflect's NumField
// panics otherwise. A field's non-empty `db` struct tag is used as its column
// name. Otherwise the Go field name is converted with goVarToSQLName.
func SelectNames(results interface{}) string {
	val := reflect.Indirect(reflect.ValueOf(results))
	typ := val.Type()
	count := val.NumField()

	cols := make([]string, 0, count)
	for i := 0; i < count; i++ {
		f := typ.Field(i)

		col := f.Tag.Get("db")
		if col == "" {
			col = goVarToSQLName(f.Name)
		}

		cols = append(cols, fmt.Sprintf(`"%s"`, col))
	}

	return strings.Join(cols, ", ")
}

// WhereClause returns the where clause for an sql statement: one
// `"col"=$N` equality per column joined with AND, with placeholders
// numbered from $1 in column order,
// eg: "col1"=$1 AND "col2"=$2 AND "col3"=$3.
// An empty column list yields an empty string. The input slice is not
// modified.
func WhereClause(columns []string) string {
	// Format directly into a fresh slice; no need to copy columns first.
	conds := make([]string, len(columns))
	for i, c := range columns {
		conds[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
	}

	return strings.Join(conds, " AND ")
}

// Update returns the column list for an UPDATE statement's SET clause,
// eg: "col1"=$1, "col2"=$2, "col3"=$3.
//
// Only the map keys (column names) are used; the values are ignored. Keys
// are sorted first so placeholder numbering is deterministic regardless of
// map iteration order.
func Update(columns map[string]interface{}) string {
	keys := make([]string, 0, len(columns))
	for k := range columns {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	assignments := make([]string, len(keys))
	for i := range keys {
		assignments[i] = fmt.Sprintf(`"%s"=$%d`, keys[i], i+1)
	}

	return strings.Join(assignments, ", ")
}

// SetParamNames takes a slice of columns and returns a comma separated
// list of parameter assignments for a template statement SET clause, with
// placeholders numbered from $1 in column order.
// eg: "col1"=$1, "col2"=$2, "col3"=$3
func SetParamNames(columns []string) string {
	parts := make([]string, len(columns))
	for i, col := range columns {
		parts[i] = fmt.Sprintf(`"%s"=$%d`, col, i+1)
	}

	return strings.Join(parts, ", ")
}

// WherePrimaryKey returns a where clause that ANDs together one equality per
// primary key column, numbering placeholders from start in column order.
// For example, if start was 2 the output would be:
//
//	"colthing"=$2 AND "colstuff"=$3
//
// An empty pkeys list yields an empty string.
func WherePrimaryKey(start int, pkeys ...string) string {
	// Build the parts once and join, instead of re-formatting the whole
	// growing string on every iteration.
	conds := make([]string, len(pkeys))
	for i, c := range pkeys {
		conds[i] = fmt.Sprintf(`"%s"=$%d`, c, start+i)
	}

	return strings.Join(conds, " AND ")
}

// goVarToSQLName converts a Go identifier to a snake_case column name,
// e.g. HelloFriendID -> hello_friend_id and IDName -> id_name.
//
// The name is processed rune by rune, not byte by byte. Go identifiers may
// contain non-ASCII letters, and splitting their multi-byte UTF-8 encoding
// into separate "runes" would corrupt the output. Lowercase letters and
// digits are copied as-is. Every other rune is lowercased and treated as an
// uppercase letter for word-boundary purposes.
func goVarToSQLName(name string) string {
	runes := []rune(name)
	str := &bytes.Buffer{}
	isUpper, upperStreak := false, false

	for i, c := range runes {
		if unicode.IsDigit(c) || unicode.IsLower(c) {
			isUpper = false
			upperStreak = false

			str.WriteRune(c)
			continue
		}

		if isUpper {
			// Consecutive uppercase runes belong to one word (e.g. "ID").
			upperStreak = true
		} else if i != 0 {
			// An uppercase rune after a lowercase letter/digit starts a new word.
			str.WriteByte('_')
		}
		isUpper = true

		// The last uppercase rune of a streak that is followed by a lowercase
		// letter begins the next word: "IDName" -> "id_name".
		if j := i + 1; j < len(runes) && upperStreak && unicode.IsLower(runes[j]) {
			str.WriteByte('_')
		}

		str.WriteRune(unicode.ToLower(c))
	}

	return str.String()
}