/*
 * Copyright (c) 2014 The btcsuite developers
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

// This file is intended to be copied into each backend driver directory.  Each
// driver should have its own driver_test.go file which creates a database and
// invokes the testInterface function in this file to ensure the driver properly
// implements the interface.  See the bdb backend driver for a working example.
//
// NOTE: When copying this file into the backend driver folder, the package name
// will need to be changed accordingly.

package walletdb_test

import (
	"fmt"
	"reflect"
	"testing"

	"github.com/btcsuite/btcwallet/walletdb"
)

var (
	// subTestFailError is a sentinel returned from inside managed
	// transaction callbacks to signal that a helper sub test returned false.
	// The helper has already reported the failure through the testing.T, so
	// callers that receive this value fail without logging it a second time.
	subTestFailError = fmt.Errorf("sub test failure")
)

// testContext is used to store context information about a running test which
// is passed into helper functions.  A single instance is shared (by pointer)
// between all of the helpers for one testInterface run.
//
// Fields:
//   - t is the test being run; helpers report failures through it.
//   - db is the database under test; namespaces are created from it.
//   - bucketDepth counts how many nested buckets deep testBucketInterface
//     currently is, and is used by testNestedBucket to bound recursion.
//   - isWritable is the writable state the bucket under test is expected to
//     report, and selects which set of checks testBucketInterface performs.
type testContext struct {
	t           *testing.T
	db          walletdb.DB
	bucketDepth int
	isWritable  bool
}

// rollbackValues builds a fresh map holding the same keys as values, with
// every key mapped to the empty string.  The helpers in this file treat an
// empty string as "no value", so the result describes the expected state of
// the keys after they have been rolled back or deleted.  The input map is left
// untouched.
func rollbackValues(values map[string]string) (emptied map[string]string) {
	emptied = make(map[string]string, len(values))
	for key := range values {
		emptied[key] = ""
	}
	return emptied
}

// testGetValues reads every key in values back from bucket and compares the
// result with the expected value.  An empty string in values means Get is
// expected to return a nil slice for that key.  The first mismatch is reported
// through tc.t and false is returned; true means every key matched.
func testGetValues(tc *testContext, bucket walletdb.Bucket, values map[string]string) bool {
	expectedBytes := func(s string) []byte {
		if s == "" {
			return nil
		}
		return []byte(s)
	}

	for key, want := range values {
		wantBytes := expectedBytes(want)
		got := bucket.Get([]byte(key))
		if reflect.DeepEqual(got, wantBytes) {
			continue
		}
		tc.t.Errorf("Get: unexpected value - got %s, want %s",
			got, wantBytes)
		return false
	}
	return true
}

// testPutValues writes every key/value pair in values into bucket.  An
// empty-string value is stored as a nil byte slice.  The first Put error is
// reported through tc.t and false is returned; true means every Put succeeded.
func testPutValues(tc *testContext, bucket walletdb.Bucket, values map[string]string) bool {
	for key, val := range values {
		var payload []byte
		if len(val) > 0 {
			payload = []byte(val)
		}

		err := bucket.Put([]byte(key), payload)
		if err == nil {
			continue
		}
		tc.t.Errorf("Put: unexpected error: %v", err)
		return false
	}
	return true
}

// testDeleteValues deletes every key present in values from bucket; the map's
// values are ignored.  The first Delete error is reported through tc.t and
// false is returned; true means every Delete succeeded.
func testDeleteValues(tc *testContext, bucket walletdb.Bucket, values map[string]string) bool {
	for key := range values {
		err := bucket.Delete([]byte(key))
		if err == nil {
			continue
		}
		tc.t.Errorf("Delete: unexpected error: %v", err)
		return false
	}
	return true
}

// testNestedBucket runs testBucketInterface against a nested bucket, using
// tc.bucketDepth to limit how deep the recursion goes.  Once two levels of
// nesting are reached it returns true without testing anything.  The depth
// counter is restored before returning.
func testNestedBucket(tc *testContext, testBucket walletdb.Bucket) bool {
	// Stop recursing after two nested levels have been exercised.
	if tc.bucketDepth >= 2 {
		return true
	}

	tc.bucketDepth++
	defer func() { tc.bucketDepth-- }()
	return testBucketInterface(tc, testBucket)
}

// testBucketInterface ensures the bucket interface is working properly by
// exercising all of its functions.
//
// The bucket's Writable state must match tc.isWritable, which also selects the
// checks that are performed:
//
//   - For a writable bucket, a fixed set of key/value pairs is put, read back,
//     iterated with ForEach, and deleted.  A nested bucket named "testbucket"
//     is then created, fetched, deleted and recreated, and each time the
//     nested bucket is tested recursively through testNestedBucket.  On
//     success, the keys and the nested bucket have been removed again.
//   - For a read-only bucket, every mutating call must fail with
//     walletdb.ErrTxNotWritable.
//
// The first failure is reported through tc.t and false is returned; true means
// every check passed.
func testBucketInterface(tc *testContext, bucket walletdb.Bucket) bool {
	if bucket.Writable() != tc.isWritable {
		tc.t.Errorf("Bucket writable state does not match.")
		return false
	}

	if tc.isWritable {
		// keyValues holds the keys and values to use when putting
		// values into the bucket.
		var keyValues = map[string]string{
			"bucketkey1": "foo1",
			"bucketkey2": "foo2",
			"bucketkey3": "foo3",
		}
		if !testPutValues(tc, bucket, keyValues) {
			return false
		}

		if !testGetValues(tc, bucket, keyValues) {
			return false
		}

		// Iterate all of the keys using ForEach while making sure the
		// stored values are the expected values.
		keysFound := make(map[string]struct{}, len(keyValues))
		err := bucket.ForEach(func(k, v []byte) error {
			kString := string(k)
			wantV, ok := keyValues[kString]
			if !ok {
				return fmt.Errorf("ForEach: key '%s' should "+
					"exist", kString)
			}

			if !reflect.DeepEqual(v, []byte(wantV)) {
				return fmt.Errorf("ForEach: value for key '%s' "+
					"does not match - got %s, want %s",
					kString, v, wantV)
			}

			keysFound[kString] = struct{}{}
			return nil
		})
		if err != nil {
			tc.t.Errorf("%v", err)
			return false
		}

		// Ensure all keys were iterated.
		for k := range keyValues {
			if _, ok := keysFound[k]; !ok {
				tc.t.Errorf("ForEach: key '%s' was not iterated "+
					"when it should have been", k)
				return false
			}
		}

		// Delete the keys and ensure they were deleted.
		if !testDeleteValues(tc, bucket, keyValues) {
			return false
		}
		if !testGetValues(tc, bucket, rollbackValues(keyValues)) {
			return false
		}

		// Ensure creating a new bucket works as expected.
		testBucketName := []byte("testbucket")
		testBucket, err := bucket.CreateBucket(testBucketName)
		if err != nil {
			tc.t.Errorf("CreateBucket: unexpected error: %v", err)
			return false
		}
		if !testNestedBucket(tc, testBucket) {
			return false
		}

		// Ensure creating a bucket that already exists fails with the
		// expected error.
		wantErr := walletdb.ErrBucketExists
		if _, err := bucket.CreateBucket(testBucketName); err != wantErr {
			tc.t.Errorf("CreateBucket: unexpected error - got %v, "+
				"want %v", err, wantErr)
			return false
		}

		// Ensure CreateBucketIfNotExists returns an existing bucket.
		testBucket, err = bucket.CreateBucketIfNotExists(testBucketName)
		if err != nil {
			tc.t.Errorf("CreateBucketIfNotExists: unexpected "+
				"error: %v", err)
			return false
		}
		if !testNestedBucket(tc, testBucket) {
			return false
		}

		// Ensure retrieving an existing bucket works as expected.  A nil
		// result must be caught here: passing it on would make
		// testNestedBucket call methods on a nil interface and panic
		// instead of reporting a test failure.
		testBucket = bucket.Bucket(testBucketName)
		if testBucket == nil {
			tc.t.Errorf("Bucket: bucket '%s' does not exist",
				testBucketName)
			return false
		}
		if !testNestedBucket(tc, testBucket) {
			return false
		}

		// Ensure deleting a bucket works as intended.
		if err := bucket.DeleteBucket(testBucketName); err != nil {
			tc.t.Errorf("DeleteBucket: unexpected error: %v", err)
			return false
		}
		if b := bucket.Bucket(testBucketName); b != nil {
			tc.t.Errorf("DeleteBucket: bucket '%s' still exists",
				testBucketName)
			return false
		}

		// Ensure deleting a bucket that doesn't exist returns the
		// expected error.
		wantErr = walletdb.ErrBucketNotFound
		if err := bucket.DeleteBucket(testBucketName); err != wantErr {
			tc.t.Errorf("DeleteBucket: unexpected error - got %v, "+
				"want %v", err, wantErr)
			return false
		}

		// Ensure CreateBucketIfNotExists creates a new bucket when
		// it doesn't already exist.
		testBucket, err = bucket.CreateBucketIfNotExists(testBucketName)
		if err != nil {
			tc.t.Errorf("CreateBucketIfNotExists: unexpected "+
				"error: %v", err)
			return false
		}
		if !testNestedBucket(tc, testBucket) {
			return false
		}

		// Delete the test bucket to avoid leaving it around for future
		// calls.
		if err := bucket.DeleteBucket(testBucketName); err != nil {
			tc.t.Errorf("DeleteBucket: unexpected error: %v", err)
			return false
		}
		if b := bucket.Bucket(testBucketName); b != nil {
			tc.t.Errorf("DeleteBucket: bucket '%s' still exists",
				testBucketName)
			return false
		}
	} else {
		// Every mutating call below must fail with ErrTxNotWritable.
		// The actual error is included in each message so an
		// unexpected result is diagnosable.
		wantErr := walletdb.ErrTxNotWritable
		failBytes := []byte("fail")

		// Put should fail with bucket that is not writable.
		if err := bucket.Put(failBytes, failBytes); err != wantErr {
			tc.t.Errorf("Put did not fail with unwritable bucket - "+
				"got %v, want %v", err, wantErr)
			return false
		}

		// Delete should fail with bucket that is not writable.
		if err := bucket.Delete(failBytes); err != wantErr {
			tc.t.Errorf("Delete did not fail with unwritable "+
				"bucket - got %v, want %v", err, wantErr)
			return false
		}

		// CreateBucket should fail with bucket that is not writable.
		if _, err := bucket.CreateBucket(failBytes); err != wantErr {
			tc.t.Errorf("CreateBucket did not fail with unwritable "+
				"bucket - got %v, want %v", err, wantErr)
			return false
		}

		// CreateBucketIfNotExists should fail with bucket that is not
		// writable.
		if _, err := bucket.CreateBucketIfNotExists(failBytes); err != wantErr {
			tc.t.Errorf("CreateBucketIfNotExists did not fail with "+
				"unwritable bucket - got %v, want %v", err, wantErr)
			return false
		}

		// DeleteBucket should fail with bucket that is not writable.
		if err := bucket.DeleteBucket(failBytes); err != wantErr {
			tc.t.Errorf("DeleteBucket did not fail with unwritable "+
				"bucket - got %v, want %v", err, wantErr)
			return false
		}
	}

	return true
}

// testManualTxInterface ensures that manual transactions work as expected.
//
// It drives transactions started directly with namespace.Begin, rather than
// through View/Update, and checks three scenarios with a fixed set of keys:
//
//   - Writing through a read-only transaction is rejected, and nothing is
//     stored.
//   - Writes in a read-write transaction that is rolled back are discarded.
//   - Writes in a read-write transaction that is committed are visible to a
//     later read-only transaction.
//
// The keys are deleted again at the end.  When one of the helper checks fails
// while a transaction is open, that transaction is rolled back before
// returning so it is not leaked into later tests.  Failures are reported
// through tc.t and cause a false return.
func testManualTxInterface(tc *testContext, namespace walletdb.Namespace) bool {
	// populateValues tests that populating values works as expected.
	//
	// When the writable flag is false, a read-only transaction is created,
	// standard bucket tests for read-only transactions are performed, and
	// the Commit function is checked to ensure it fails as expected.
	//
	// Otherwise, a read-write transaction is created, standard bucket
	// tests for read-write transactions are performed, the values are
	// written, and then the transaction is either committed or rolled
	// back depending on the flag.
	populateValues := func(writable, rollback bool, putValues map[string]string) bool {
		tx, err := namespace.Begin(writable)
		if err != nil {
			tc.t.Errorf("Begin: unexpected error %v", err)
			return false
		}

		rootBucket := tx.RootBucket()
		if rootBucket == nil {
			tc.t.Errorf("RootBucket: unexpected nil root bucket")
			_ = tx.Rollback()
			return false
		}

		tc.isWritable = writable
		if !testBucketInterface(tc, rootBucket) {
			_ = tx.Rollback()
			return false
		}

		if !writable {
			// The transaction is not writable, so it should fail
			// the commit.
			if err := tx.Commit(); err != walletdb.ErrTxNotWritable {
				tc.t.Errorf("Commit: unexpected error %v, "+
					"want %v", err, walletdb.ErrTxNotWritable)
				_ = tx.Rollback()
				return false
			}

			// Rollback the transaction.
			if err := tx.Rollback(); err != nil {
				tc.t.Errorf("Rollback: unexpected error %v", err)
				return false
			}
		} else {
			if !testPutValues(tc, rootBucket, putValues) {
				// Close the transaction so it is not leaked.
				// The put failure has already been reported,
				// so a rollback error here is secondary.
				_ = tx.Rollback()
				return false
			}

			if rollback {
				// Rollback the transaction.
				if err := tx.Rollback(); err != nil {
					tc.t.Errorf("Rollback: unexpected "+
						"error %v", err)
					return false
				}
			} else {
				// The commit should succeed.
				if err := tx.Commit(); err != nil {
					tc.t.Errorf("Commit: unexpected error "+
						"%v", err)
					return false
				}
			}
		}

		return true
	}

	// checkValues starts a read-only transaction and checks that all of
	// the key/value pairs specified in the expectedValues parameter match
	// what's in the database.
	checkValues := func(expectedValues map[string]string) bool {
		// Begin a read-only transaction to inspect the stored state.
		tx, err := namespace.Begin(false)
		if err != nil {
			tc.t.Errorf("Begin: unexpected error %v", err)
			return false
		}

		rootBucket := tx.RootBucket()
		if rootBucket == nil {
			tc.t.Errorf("RootBucket: unexpected nil root bucket")
			_ = tx.Rollback()
			return false
		}

		if !testGetValues(tc, rootBucket, expectedValues) {
			_ = tx.Rollback()
			return false
		}

		// Rollback the read-only transaction.
		if err := tx.Rollback(); err != nil {
			tc.t.Errorf("Rollback: unexpected error %v", err)
			return false
		}

		return true
	}

	// deleteValues starts a read-write transaction and deletes the keys
	// in the passed key/value pairs.
	deleteValues := func(values map[string]string) bool {
		tx, err := namespace.Begin(true)
		if err != nil {
			// Without this check a failed Begin would leave tx
			// nil and the RootBucket call below would panic.
			tc.t.Errorf("Begin: unexpected error %v", err)
			return false
		}

		rootBucket := tx.RootBucket()
		if rootBucket == nil {
			tc.t.Errorf("RootBucket: unexpected nil root bucket")
			_ = tx.Rollback()
			return false
		}

		// Delete the keys and ensure they were deleted.
		if !testDeleteValues(tc, rootBucket, values) {
			_ = tx.Rollback()
			return false
		}
		if !testGetValues(tc, rootBucket, rollbackValues(values)) {
			_ = tx.Rollback()
			return false
		}

		// Commit the changes and ensure it was successful.
		if err := tx.Commit(); err != nil {
			tc.t.Errorf("Commit: unexpected error %v", err)
			return false
		}

		return true
	}

	// keyValues holds the keys and values to use when putting values
	// into a bucket.
	var keyValues = map[string]string{
		"umtxkey1": "foo1",
		"umtxkey2": "foo2",
		"umtxkey3": "foo3",
	}

	// Ensure that attempting populating the values using a read-only
	// transaction fails as expected.
	if !populateValues(false, true, keyValues) {
		return false
	}
	if !checkValues(rollbackValues(keyValues)) {
		return false
	}

	// Ensure that attempting populating the values using a read-write
	// transaction and then rolling it back yields the expected values.
	if !populateValues(true, true, keyValues) {
		return false
	}
	if !checkValues(rollbackValues(keyValues)) {
		return false
	}

	// Ensure that attempting populating the values using a read-write
	// transaction and then committing it stores the expected values.
	if !populateValues(true, false, keyValues) {
		return false
	}
	if !checkValues(keyValues) {
		return false
	}

	// Clean up the keys.
	if !deleteValues(keyValues) {
		return false
	}

	return true
}

// testNamespaceAndTxInterfaces creates a namespace using the provided key and
// tests all facets of its interface as well as the transaction and bucket
// interfaces under it.
//
// The manual transaction tests run first.  Managed View and Update
// transactions then exercise the bucket interface and check that errors from
// the user-supplied callbacks are propagated.  They also verify that writes
// are discarded when an Update callback returns an error and committed when it
// returns nil.  The namespace is deleted when the function returns.  Failures
// are reported through tc.t and cause a false return.
func testNamespaceAndTxInterfaces(tc *testContext, namespaceKey string) bool {
	nsKey := []byte(namespaceKey)
	ns, err := tc.db.Namespace(nsKey)
	if err != nil {
		tc.t.Errorf("Namespace: unexpected error: %v", err)
		return false
	}
	// Remove the namespace once the tests for it are finished, whether or
	// not they passed.
	defer func() {
		if err := tc.db.DeleteNamespace(nsKey); err != nil {
			tc.t.Errorf("DeleteNamespace: unexpected error: %v", err)
		}
	}()

	if !testManualTxInterface(tc, ns) {
		return false
	}

	// managedOK reports whether a managed transaction finished without
	// error.  subTestFailError means a helper already logged the failure,
	// so only other errors are logged here.
	managedOK := func(err error) bool {
		if err == nil {
			return true
		}
		if err != subTestFailError {
			tc.t.Errorf("%v", err)
		}
		return false
	}

	// withRoot turns a function on the root bucket into a managed
	// transaction callback that first checks the root bucket is present.
	withRoot := func(fn func(root walletdb.Bucket) error) func(walletdb.Tx) error {
		return func(tx walletdb.Tx) error {
			root := tx.RootBucket()
			if root == nil {
				return fmt.Errorf("RootBucket: unexpected nil root bucket")
			}
			return fn(root)
		}
	}

	// keyValues holds the keys and values to use when putting values
	// into a bucket.
	keyValues := map[string]string{
		"mtxkey1": "foo1",
		"mtxkey2": "foo2",
		"mtxkey3": "foo3",
	}

	// Run the bucket interface tests inside a managed read-only
	// transaction.
	err = ns.View(withRoot(func(root walletdb.Bucket) error {
		tc.isWritable = false
		if !testBucketInterface(tc, root) {
			return subTestFailError
		}
		return nil
	}))
	if !managedOK(err) {
		return false
	}

	// View must hand back whatever error its callback returned.
	viewError := fmt.Errorf("example view error")
	err = ns.View(func(walletdb.Tx) error { return viewError })
	if err != viewError {
		tc.t.Errorf("View: inner function error not returned - got "+
			"%v, want %v", err, viewError)
		return false
	}

	// Run the bucket interface tests inside a managed read-write
	// transaction, write the key/value pairs, and then return an error so
	// the transaction is rolled back.
	forceRollbackError := fmt.Errorf("force rollback")
	err = ns.Update(withRoot(func(root walletdb.Bucket) error {
		tc.isWritable = true
		if !testBucketInterface(tc, root) || !testPutValues(tc, root, keyValues) {
			return subTestFailError
		}
		return forceRollbackError
	}))
	switch err {
	case forceRollbackError:
		// Expected: the callback's error was propagated.
	case subTestFailError:
		return false
	default:
		tc.t.Errorf("Update: inner function error not returned - got "+
			"%v, want %v", err, forceRollbackError)
		return false
	}

	// None of the values written before the forced rollback may be
	// present.
	err = ns.View(withRoot(func(root walletdb.Bucket) error {
		if !testGetValues(tc, root, rollbackValues(keyValues)) {
			return subTestFailError
		}
		return nil
	}))
	if !managedOK(err) {
		return false
	}

	// Write the key/value pairs in a managed read-write transaction that
	// is allowed to commit.
	err = ns.Update(withRoot(func(root walletdb.Bucket) error {
		if !testPutValues(tc, root, keyValues) {
			return subTestFailError
		}
		return nil
	}))
	if !managedOK(err) {
		return false
	}

	// The committed values must now be readable.
	err = ns.View(withRoot(func(root walletdb.Bucket) error {
		if !testGetValues(tc, root, keyValues) {
			return subTestFailError
		}
		return nil
	}))
	if !managedOK(err) {
		return false
	}

	// Remove the committed values again in a managed read-write
	// transaction.
	err = ns.Update(withRoot(func(root walletdb.Bucket) error {
		if !testDeleteValues(tc, root, keyValues) {
			return subTestFailError
		}
		return nil
	}))
	return managedOK(err)
}

// testAdditionalErrors performs some tests for error cases not covered
// elsewhere in the tests and therefore improves negative test coverage.
//
// It creates the "ns3" namespace and deletes it out from under the returned
// handle.  It then checks that Begin, View and Update on that handle fail with
// walletdb.ErrBucketNotFound.  After recreating the namespace, it checks:
//
//   - the errors returned for nil bucket names and nil keys;
//   - that Rollback/Commit on an already-closed transaction return
//     walletdb.ErrTxClosed.
//
// The recreated namespace is deleted when the function returns.  Failures are
// reported through tc.t and cause a false return.
func testAdditionalErrors(tc *testContext) bool {
	// Create a new namespace and then intentionally delete the namespace
	// bucket out from under it to force errors.
	ns3Key := []byte("ns3")
	ns3, err := tc.db.Namespace(ns3Key)
	if err != nil {
		tc.t.Errorf("Namespace: unexpected error: %v", err)
		return false
	}
	if err := tc.db.DeleteNamespace(ns3Key); err != nil {
		tc.t.Errorf("DeleteNamespace: unexpected error: %v", err)
		return false
	}

	// Ensure Begin fails when the namespace bucket does not exist.
	wantErr := walletdb.ErrBucketNotFound
	if _, err := ns3.Begin(false); err != wantErr {
		tc.t.Errorf("Begin: did not receive expected error - got %v, "+
			"want %v", err, wantErr)
		return false
	}

	// Ensure View fails when the namespace bucket does not exist.
	err = ns3.View(func(tx walletdb.Tx) error {
		return nil
	})
	if err != wantErr {
		tc.t.Errorf("View: did not receive expected error - got %v, "+
			"want %v", err, wantErr)
		return false
	}

	// Ensure Update fails when the namespace bucket does not exist.
	err = ns3.Update(func(tx walletdb.Tx) error {
		return nil
	})
	if err != wantErr {
		tc.t.Errorf("Update: did not receive expected error - got %v, "+
			"want %v", err, wantErr)
		return false
	}

	// Recreate the namespace to bring the bucket back.
	ns3, err = tc.db.Namespace(ns3Key)
	if err != nil {
		tc.t.Errorf("Namespace: unexpected error: %v", err)
		return false
	}
	defer func() {
		// Remove the namespace now that the tests are done for it.
		if err := tc.db.DeleteNamespace(ns3Key); err != nil {
			tc.t.Errorf("DeleteNamespace: unexpected error: %v", err)
		}
	}()

	err = ns3.Update(func(tx walletdb.Tx) error {
		rootBucket := tx.RootBucket()
		if rootBucket == nil {
			return fmt.Errorf("RootBucket: unexpected nil root bucket")
		}

		// Ensure CreateBucket returns the expected error when no bucket
		// key is specified.
		wantErr := walletdb.ErrBucketNameRequired
		if _, err := rootBucket.CreateBucket(nil); err != wantErr {
			return fmt.Errorf("CreateBucket: unexpected error - "+
				"got %v, want %v", err, wantErr)
		}

		// Ensure DeleteBucket returns the expected error when no bucket
		// key is specified.
		wantErr = walletdb.ErrIncompatibleValue
		if err := rootBucket.DeleteBucket(nil); err != wantErr {
			return fmt.Errorf("DeleteBucket: unexpected error - "+
				"got %v, want %v", err, wantErr)
		}

		// Ensure Put returns the expected error when no key is
		// specified.
		wantErr = walletdb.ErrKeyRequired
		if err := rootBucket.Put(nil, nil); err != wantErr {
			return fmt.Errorf("Put: unexpected error - got %v, "+
				"want %v", err, wantErr)
		}

		return nil
	})
	if err != nil {
		if err != subTestFailError {
			tc.t.Errorf("%v", err)
		}
		return false
	}

	// Ensure that attempting to rollback or commit a transaction that is
	// already closed returns the expected error.
	tx, err := ns3.Begin(false)
	if err != nil {
		tc.t.Errorf("Begin: unexpected error: %v", err)
		return false
	}
	if err := tx.Rollback(); err != nil {
		tc.t.Errorf("Rollback: unexpected error: %v", err)
		return false
	}
	wantErr = walletdb.ErrTxClosed
	if err := tx.Rollback(); err != wantErr {
		tc.t.Errorf("Rollback: unexpected error - got %v, want %v", err,
			wantErr)
		return false
	}
	if err := tx.Commit(); err != wantErr {
		tc.t.Errorf("Commit: unexpected error - got %v, want %v", err,
			wantErr)
		return false
	}

	return true
}

// testInterface performs tests for the various interfaces of walletdb which
// require state in the database for the given database type.
//
// It tests two namespaces ("ns1" and "ns2") in turn, then checks a set of
// additional error paths.  The run stops at the first group that fails; that
// failure has already been reported through t.
func testInterface(t *testing.T, db walletdb.DB) {
	// One shared context is passed to every helper.
	tc := &testContext{t: t, db: db}

	suite := []func() bool{
		// Create a namespace and test the interface for it.
		func() bool { return testNamespaceAndTxInterfaces(tc, "ns1") },
		// Create a second namespace and test the interface for it.
		func() bool { return testNamespaceAndTxInterfaces(tc, "ns2") },
		// Check a few more error conditions not covered elsewhere.
		func() bool { return testAdditionalErrors(tc) },
	}
	for _, run := range suite {
		if !run() {
			return
		}
	}
}