From 2b83e50c925e8f151e22ce57c4604d50102726e6 Mon Sep 17 00:00:00 2001 From: Brannon King Date: Thu, 24 Oct 2019 13:09:56 -0600 Subject: [PATCH] first tests ran, working to make takeover height unnecessary --- src/claimtrie.cpp | 193 ++- src/claimtrie.h | 29 +- src/claimtrieforks.cpp | 57 +- src/miner.cpp | 3 +- src/scheduler.cpp | 2 +- src/sqlite/hdr/sqlite_modern_cpp.h | 1291 ++++++++--------- src/sqlite/hdr/sqlite_modern_cpp/errors.h | 67 +- .../hdr/sqlite_modern_cpp/type_wrapper.h | 508 +++---- src/test/claimtriecache_tests.cpp | 9 +- src/test/claimtrienormalization_tests.cpp | 8 +- src/undo.h | 2 - src/validation.cpp | 4 +- 12 files changed, 1075 insertions(+), 1098 deletions(-) diff --git a/src/claimtrie.cpp b/src/claimtrie.cpp index d005a2493..ab1c3d720 100644 --- a/src/claimtrie.cpp +++ b/src/claimtrie.cpp @@ -1,4 +1,3 @@ - #include #include #include @@ -54,8 +53,7 @@ CClaimTrie::CClaimTrie(bool fMemory, bool fWipe, int height, int proportionalDel _db.define("merkle_pair", [](const std::vector& blob1, const std::vector& blob2) { return Hash(blob1.begin(), blob1.end(), blob2.begin(), blob2.end()); }); _db.define("merkle", [](const std::vector& blob1) { return Hash(blob1.begin(), blob1.end()); }); - _db << "CREATE TABLE IF NOT EXISTS nodes (name TEXT NOT NULL PRIMARY KEY, parent TEXT, " - "lastTakeoverHeight INTEGER NOT NULL DEFAULT 0, hash BLOB)"; + _db << "CREATE TABLE IF NOT EXISTS nodes (name TEXT NOT NULL PRIMARY KEY, parent TEXT, hash BLOB)"; _db << "CREATE INDEX nodes_hash ON nodes (hash)"; _db << "CREATE INDEX nodes_parent ON nodes (parent)"; @@ -110,15 +108,17 @@ bool CClaimTrie::empty() { bool CClaimTrieCacheBase::haveClaim(const std::string& name, const COutPoint& outPoint) const { - auto query = base->_db << "SELECT 1 FROM claims WHERE nodeName = ? AND txID = ? AND txN = ? AND validHeight < ? LIMIT 1" - << name << outPoint.hash << outPoint.n << nNextHeight; + auto query = base->_db << "SELECT 1 FROM claims WHERE nodeName = ? AND txID = ? AND txN = ? " + "AND validHeight < ? AND expirationHeight >= ? LIMIT 1" + << name << outPoint.hash << outPoint.n << nNextHeight << nNextHeight; return query.begin() != query.end(); } bool CClaimTrieCacheBase::haveSupport(const std::string& name, const COutPoint& outPoint) const { - auto query = base->_db << "SELECT 1 FROM supports WHERE nodeName = ? AND txID = ? AND txN = ? AND validHeight < ? LIMIT 1" - << name << outPoint.hash << outPoint.n << nNextHeight; + auto query = base->_db << "SELECT 1 FROM supports WHERE nodeName = ? AND txID = ? AND txN = ? " + "AND validHeight < ? AND expirationHeight >= ? LIMIT 1" + << name << outPoint.hash << outPoint.n << nNextHeight << nNextHeight; return query.begin() != query.end(); } @@ -126,7 +126,7 @@ supportEntryType CClaimTrieCacheBase::getSupportsForName(const std::string& name { // includes values that are not yet valid auto query = base->_db << "SELECT supportedClaimID, txID, txN, blockHeight, validHeight, amount " - "FROM supports WHERE nodeName = ?" << name; + "FROM supports WHERE nodeName = ? AND expirationHeight >= ?" << name << nNextHeight; supportEntryType ret; for (auto&& row: query) { CSupportValue value; @@ -139,8 +139,9 @@ supportEntryType CClaimTrieCacheBase::getSupportsForName(const std::string& name bool CClaimTrieCacheBase::haveClaimInQueue(const std::string& name, const COutPoint& outPoint, int& nValidAtHeight) const { - auto query = base->_db << "SELECT validHeight FROM claims WHERE nodeName = ? AND txID = ? AND txN = ? AND validHeight >= ? LIMIT 1" - << name << outPoint.hash << outPoint.n << nNextHeight; + auto query = base->_db << "SELECT validHeight FROM claims WHERE nodeName = ? AND txID = ? AND txN = ? " + "AND validHeight >= ? AND expirationHeight >= ? LIMIT 1" + << name << outPoint.hash << outPoint.n << nNextHeight << nNextHeight; for (auto&& row: query) { row >> nValidAtHeight; return true; @@ -150,8 +151,9 @@ bool CClaimTrieCacheBase::haveClaimInQueue(const std::string& name, const COutPo bool CClaimTrieCacheBase::haveSupportInQueue(const std::string& name, const COutPoint& outPoint, int& nValidAtHeight) const { - auto query = base->_db << "SELECT validHeight FROM supports WHERE nodeName = ? AND txID = ? AND txN = ? AND validHeight >= ? LIMIT 1" - << name << outPoint.hash << outPoint.n << nNextHeight; + auto query = base->_db << "SELECT validHeight FROM supports WHERE nodeName = ? AND txID = ? AND txN = ? " + "AND validHeight >= ? AND expirationHeight >= ? LIMIT 1" + << name << outPoint.hash << outPoint.n << nNextHeight << nNextHeight; for (auto&& row: query) { row >> nValidAtHeight; return true; @@ -172,7 +174,8 @@ bool CClaimTrieCacheBase::deleteNodeIfPossible(const std::string& name, std::str if (name.empty()) return false; // to remove a node it must have one or less children and no claims vector_builder claimsBuilder; - base->_db << "SELECT name FROM claims WHERE name = ?" << name >> claimsBuilder; + base->_db << "SELECT name FROM claims WHERE name = ? AND validHeight < ? AND expirationHeight >= ? " + << name << nNextHeight << nNextHeight >> claimsBuilder; claims = std::move(claimsBuilder); if (!claims.empty()) return false; // still has claims // we now know it has no claims, but we need to check its children @@ -210,9 +213,10 @@ void CClaimTrieCacheBase::ensureTreeStructureIsUpToDate() { // should we do the same to remove nodes? no; we need their last takeover height if they come back //float time = 0; + // assume parents are not set correctly here: auto parentQuery = base->_db << "SELECT name FROM nodes WHERE parent IS NOT NULL AND " "name IN (WITH RECURSIVE prefix(p) AS (VALUES(?) UNION ALL " - "SELECT SUBSTR(p, 0, LENGTH(p)) FROM prefix WHERE p != '') SELECT p FROM prefix) " + "SELECT SUBSTR(p, 1, LENGTH(p)) FROM prefix WHERE p != '') SELECT p FROM prefix) " "ORDER BY LENGTH(name) DESC LIMIT 1"; for (auto& name: names) { @@ -276,40 +280,50 @@ void CClaimTrieCacheBase::ensureTreeStructureIsUpToDate() { if (splitPos == 0) base->_db << "UPDATE nodes SET hash = NULL WHERE name = ?" << parent; } + + // now we need to percolate the nulls up the tree + // parents should all be set right + base->_db << "UPDATE nodes SET hash = NULL WHERE name IN (WITH RECURSIVE prefix(p) AS (" + "SELECT parent WHERE hash IS NULL ORDER BY name DESC UNION " + "SELECT parent FROM prefix,nodes WHERE name = p AND p IS NOT NULL)"; } std::size_t CClaimTrieCacheBase::getTotalNamesInTrie() const { // you could do this select from the nodes table, but you would have to ensure it is not dirty first std::size_t ret; - base->_db << "SELECT COUNT(DISTINCT nodeName) FROM claims" >> ret; + base->_db << "SELECT COUNT(DISTINCT nodeName) FROM claims WHERE validHeight < ? AND expirationHeight >= ?" + << nNextHeight << nNextHeight >> ret; return ret; } std::size_t CClaimTrieCacheBase::getTotalClaimsInTrie() const { std::size_t ret; - base->_db << "SELECT COUNT(*) FROM claims" >> ret; + base->_db << "SELECT COUNT(*) FROM claims WHERE validHeight < ? AND expirationHeight >= ?" + << nNextHeight << nNextHeight >> ret; return ret; } CAmount CClaimTrieCacheBase::getTotalValueOfClaimsInTrie(bool fControllingOnly) const { CAmount ret = 0; - std::string query("SELECT c.amount + SUM(SELECT s.amount FROM supports s WHERE s.supportedClaimID = c.claimID AND s.validHeight < ?)" - " FROM claims c WHERE c.validHeight < ?"); + std::string query("SELECT c.amount + SUM(SELECT s.amount FROM supports s " + "WHERE s.supportedClaimID = c.claimID AND s.validHeight < ? AND s.expirationHeight >= ?) " + "FROM claims c WHERE c.validHeight < ? AND s.expirationHeight >= ?"); if (fControllingOnly) throw std::runtime_error("not implemented yet"); // TODO: finish this - base->_db << query << nNextHeight << nNextHeight >> ret; + base->_db << query << nNextHeight << nNextHeight << nNextHeight << nNextHeight >> ret; return ret; } bool CClaimTrieCacheBase::getInfoForName(const std::string& name, CClaimValue& claim) const { - auto query = base->_db << "SELECT c.claimID, c.txID, c.txN, c.blockHeight, c.validHeight, c.amount, c.amount + " - "SUM(SELECT s.amount FROM supports s WHERE s.supportedClaimID = c.claimID AND s.validHeight < ?) as effectiveAmount" - "FROM claims c WHERE c.nodeName = ? AND c.validHeight < ? " - "ORDER BY effectiveAmount DESC, c.blockHeight, c.txID, c.txN LIMIT 1" << nNextHeight << name << nNextHeight; + auto query = base->_db << "SELECT c.claimID, c.txID, c.txN, c.blockHeight, c.validHeight, c.amount, " + "(SELECT TOTAL(s.amount)+c.amount FROM supports s WHERE s.supportedClaimID = c.claimID AND s.validHeight < ? AND s.expirationHeight >= ?) as effectiveAmount " + "FROM claims c WHERE c.nodeName = ? AND c.validHeight < ? AND c.expirationHeight >= ? " + "ORDER BY effectiveAmount DESC, c.blockHeight, c.txID, c.txN LIMIT 1" + << nNextHeight << nNextHeight << name << nNextHeight << nNextHeight; for (auto&& row: query) { row >> claim.claimId >> claim.outPoint.hash >> claim.outPoint.n >> claim.nHeight >> claim.nValidAtHeight >> claim.nAmount >> claim.nEffectiveAmount; @@ -325,8 +339,8 @@ CClaimSupportToName CClaimTrieCacheBase::getClaimsForName(const std::string& nam auto supports = getSupportsForName(name); auto query = base->_db << "SELECT claimID, txID, txN, blockHeight, validHeight, amount " - "FROM claims WHERE nodeName = ?" - << name; + "FROM claims WHERE nodeName = ? AND expirationHeight >= ?" + << name << nNextHeight; for (auto&& row: query) { CClaimValue claim; row >> claim.claimId >> claim.outPoint.hash >> claim.outPoint.n @@ -369,19 +383,18 @@ void completeHash(uint256& partialHash, const std::string& key, std::size_t to) .Finalize(partialHash.begin()); } -uint256 CClaimTrieCacheBase::recursiveComputeMerkleHash(const std::string& name, int lastTakeoverHeight, bool checkOnly) +uint256 CClaimTrieCacheBase::recursiveComputeMerkleHash(const std::string& name, bool checkOnly) { std::vector vchToHash; const auto pos = name.size(); - auto query = base->_db << "SELECT name, hash, lastTakeoverHeight FROM nodes WHERE parent = ? ORDER BY name" << name; + auto query = base->_db << "SELECT name, hash FROM nodes WHERE parent = ? ORDER BY name" << name; for (auto&& row : query) { std::string key; - int keyLastTakeoverHeight; std::unique_ptr hash; - row >> key >> hash >> keyLastTakeoverHeight; + row >> key >> hash; if (hash == nullptr) hash = std::make_unique(); if (hash->IsNull()) { - *hash = recursiveComputeMerkleHash(key, keyLastTakeoverHeight, checkOnly); + *hash = recursiveComputeMerkleHash(key, checkOnly); } completeHash(*hash, key, pos); vchToHash.push_back(key[pos]); @@ -390,11 +403,11 @@ uint256 CClaimTrieCacheBase::recursiveComputeMerkleHash(const std::string& name, CClaimValue claim; if (getInfoForName(name, claim)) { - uint256 valueHash = getValueHash(claim.outPoint, lastTakeoverHeight); + uint256 valueHash = getValueHash(claim.outPoint, claim.nValidAtHeight); vchToHash.insert(vchToHash.end(), valueHash.begin(), valueHash.end()); } - auto computedHash = Hash(vchToHash.begin(), vchToHash.end()); + auto computedHash = vchToHash.empty() ? one : Hash(vchToHash.begin(), vchToHash.end()); if (!checkOnly) base->_db << "UPDATE nodes SET hash = ? WHERE name = ?" << computedHash << name; return computedHash; @@ -404,13 +417,12 @@ bool CClaimTrieCacheBase::checkConsistency() { // verify that all claims hash to the values on the nodes - auto query = base->_db << "SELECT name, hash, lastTakeoverHeight FROM nodes"; + auto query = base->_db << "SELECT name, hash FROM nodes"; for (auto&& row: query) { std::string name; uint256 hash; - int takeoverHeight; - row >> name >> hash >> takeoverHeight; - auto computedHash = recursiveComputeMerkleHash(name, takeoverHeight, true); + row >> name >> hash; + auto computedHash = recursiveComputeMerkleHash(name, true); if (computedHash != hash) return false; } @@ -475,11 +487,10 @@ int CClaimTrieCacheBase::expirationTime() const uint256 CClaimTrieCacheBase::getMerkleHash() { ensureTreeStructureIsUpToDate(); - int lastTakeover; std::unique_ptr hash; - base->_db << "SELECT hash, lastTakeoverHeight FROM nodes WHERE name = ''" >> std::tie(hash, lastTakeover); + base->_db << "SELECT hash FROM nodes WHERE name = ''" >> hash; if (hash == nullptr || hash->IsNull()) - return recursiveComputeMerkleHash("", lastTakeover, false); + return recursiveComputeMerkleHash("", false); return *hash; } @@ -501,7 +512,9 @@ bool CClaimTrieCacheBase::addClaim(const std::string& name, const COutPoint& out auto expires = expirationTime() + nHeight; auto validHeight = nHeight + delay; base->_db << "INSERT INTO claims(claimID, name, nodeName, txID, txN, amount, blockHeight, validHeight, expirationHeight, metadata) " - "VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" << claimId << name << nodeName + "VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(claimID) DO UPDATE SET name = excluded.name, " + "nodeName = excluded.nodeName, txID = excluded.txID, txN = excluded.txN, amount = excluded.amount, " + "expirationHeight = excluded.expirationHeight, metadata = excluded.metadata" << claimId << name << nodeName << outPoint.hash << outPoint.n << nAmount << nHeight << validHeight << expires << metadata; base->_db << "INSERT INTO nodes(name) VALUES(?) ON CONFLICT(name) DO UPDATE SET hash = NULL" << nodeName; @@ -837,7 +850,7 @@ static const boost::container::flat_map, int> takeov {{ 653524, "celtic-folk-music-full-live-concert-mps" }, 589762}, }; -bool CClaimTrieCacheBase::incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, supportQueueRowType& expireSupportUndo, std::vector>& takeoverHeightUndo) +bool CClaimTrieCacheBase::incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, supportQueueRowType& expireSupportUndo) { // the plan: // for every claim and support that becomes active this block set its node hash to null (aka, dirty) @@ -863,8 +876,6 @@ bool CClaimTrieCacheBase::incrementBlock(insertUndoType& insertUndo, claimQueueR expireUndo.emplace_back(name, value); } } - base->_db << "UPDATE claims SET active = 0 WHERE expirationHeight = ?" - << nNextHeight; base->_db << "UPDATE nodes SET hash = NULL WHERE name IN (SELECT nodeName FROM claims WHERE expirationHeight = ?)" << nNextHeight; @@ -880,54 +891,26 @@ bool CClaimTrieCacheBase::incrementBlock(insertUndoType& insertUndo, claimQueueR expireSupportUndo.emplace_back(name, value); } } - base->_db << "UPDATE supports SET active = 0 WHERE expirationHeight = ?" - << nNextHeight; base->_db << "UPDATE nodes SET hash = NULL WHERE name IN (SELECT nodeName FROM supports WHERE expirationHeight = ?)" << nNextHeight; // takeover handling: - std::vector> takeovers; - base->_db << "SELECT name, lastTakeoverHeight FROM nodes WHERE hash IS NULL" >> takeovers; - - for (const auto& takeover : takeovers) { - // the plan: select the old and new bests - // if they are different, record the valid heights of the current claims that are not active - // then make them all active + vector_builder takeovers; + base->_db << "SELECT name FROM nodes WHERE hash IS NULL" >> takeovers; + for (const auto& nameWithTakeover : takeovers) { if (nNextHeight >= 496856 && nNextHeight <= 653524) { - auto wit = takeoverWorkarounds.find(std::make_pair(nNextHeight, takeover.first)); + auto wit = takeoverWorkarounds.find(std::make_pair(nNextHeight, nameWithTakeover)); if (wit != takeoverWorkarounds.end()) { - activateAllFor(insertUndo, insertSupportUndo, takeover.first); - base->_db << "UPDATE nodes SET lastTakeoverHeight = ? WHERE nodeName = ?" << wit->second << takeover.first; - takeoverHeightUndo.emplace_back(takeover.first, takeover.second); + activateAllFor(insertUndo, insertSupportUndo, nameWithTakeover); continue; } } - int lastTakeoverHeight = 0; - auto findBestValid = base->_db << "SELECT c.validHeight, c.amount + " - "SUM(SELECT s.amount FROM supports s WHERE s.supportedClaimID = c.claimID AND s.validHeight < ?) as effectiveAmount" - "FROM claims c WHERE c.nodeName = ? AND c.validHeight < ? " - "ORDER BY effectiveAmount DESC, c.blockHeight, c.txID, c.txN LIMIT 1" - << nNextHeight + 1 << takeover.first << nNextHeight + 1; - - auto lit = findBestValid.begin(); - if (lit == findBestValid.end()) { - takeoverHeightUndo.emplace_back(takeover.first, takeover.second); - continue; - } - *lit >> lastTakeoverHeight; - if (lastTakeoverHeight == takeover.second) - continue; // no takeover happened - activateAllFor(insertUndo, insertSupportUndo, takeover.first); - - // now get the best again: - findBestValid++; - lit = findBestValid.begin(); - *lit >> lastTakeoverHeight; - base->_db << "UPDATE nodes SET lastTakeoverHeight = ? WHERE nodeName = ?" << lastTakeoverHeight << takeover.first; - - takeoverHeightUndo.emplace_back(takeover.first, takeover.second); + // if somebody activates on this block and they are the new best, then everybody activates on this block + CClaimValue value; + if (getInfoForName(nameWithTakeover, value) && value.nValidAtHeight == nNextHeight - 1) + activateAllFor(insertUndo, insertSupportUndo, nameWithTakeover); } nNextHeight++; @@ -938,8 +921,8 @@ void CClaimTrieCacheBase::activateAllFor(insertUndoType& insertUndo, insertUndoT const std::string& name) { // now that we know a takeover is happening, we bring everybody in: { - auto query = base->_db << "SELECT txID, txN, validHeight FROM claims WHERE nodeName = ? AND validHeight > ?" - << name << nNextHeight; + auto query = base->_db << "SELECT txID, txN, validHeight FROM claims WHERE nodeName = ? AND validHeight > ? AND expirationHeight >= ?" + << name << nNextHeight << nNextHeight; for (auto &&row: query) { uint256 hash; uint32_t n; @@ -949,12 +932,13 @@ void CClaimTrieCacheBase::activateAllFor(insertUndoType& insertUndo, insertUndoT } } // and then update them all to activate now: - base->_db << "UPDATE claims SET validHeight = ? WHERE nodeName = ? AND validHeight > ?" << nNextHeight << name << nNextHeight; + base->_db << "UPDATE claims SET validHeight = ? WHERE nodeName = ? AND validHeight > ? AND expirationHeight >= ?" + << nNextHeight << name << nNextHeight << nNextHeight; // then do the same for supports: { - auto query = base->_db << "SELECT txID, txN, validHeight FROM supports WHERE nodeName = ? AND validHeight > ?" - << name << nNextHeight; + auto query = base->_db << "SELECT txID, txN, validHeight FROM supports WHERE nodeName = ? AND validHeight > ? AND expirationHeight >= ?" + << name << nNextHeight << nNextHeight; for (auto &&row: query) { uint256 hash; uint32_t n; @@ -964,7 +948,8 @@ void CClaimTrieCacheBase::activateAllFor(insertUndoType& insertUndo, insertUndoT } } // and then update them all to activate now: - base->_db << "UPDATE supports SET validHeight = ? WHERE nodeName = ? AND validHeight > ?" << nNextHeight << name << nNextHeight; + base->_db << "UPDATE supports SET validHeight = ? WHERE nodeName = ? AND validHeight > ? AND expirationHeight >= ?" + << nNextHeight << name << nNextHeight << nNextHeight; } bool CClaimTrieCacheBase::decrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, supportQueueRowType& expireSupportUndo) @@ -998,12 +983,8 @@ bool CClaimTrieCacheBase::decrementBlock(insertUndoType& insertUndo, claimQueueR return true; } -bool CClaimTrieCacheBase::finalizeDecrement(std::vector>& takeoverHeightUndo) +bool CClaimTrieCacheBase::finalizeDecrement() { - for (auto it = takeoverHeightUndo.crbegin(); it != takeoverHeightUndo.crend(); ++it) - base->_db << "UPDATE nodes SET lastTakeoverHeight = ?, hash = NULL WHERE name = ?" - << it->second << it->first; - return true; } @@ -1173,11 +1154,11 @@ int CClaimTrieCacheBase::getNumBlocksOfContinuousOwnership(const std::string& na if (nNextHeight <= 646584 && ownershipWorkaround.find(std::make_pair(nNextHeight, name)) != ownershipWorkaround.end()) return 0; - int lastTakeover = -1; - auto query = base->_db << "SELECT lastTakeoverHeight FROM nodes WHERE name = ?" << name; - for (auto&& row: query) - row >> lastTakeover; - return lastTakeover > 0 ? nNextHeight - lastTakeover : 0; + CClaimValue value; + if (getInfoForName(name, value)) + return nNextHeight - value.nValidAtHeight; + + return 0; } int CClaimTrieCacheBase::getDelayForName(const std::string& name) const @@ -1207,19 +1188,18 @@ bool CClaimTrieCacheBase::getProofForName(const std::string& name, const uint160 // cache the parent nodes getMerkleHash(); proof = CClaimTrieProof(); - auto nodeQuery = base->_db << "SELECT name, lastTakeoverHeight FROM nodes WHERE " + auto nodeQuery = base->_db << "SELECT name FROM nodes WHERE " "name IN (WITH RECURSIVE prefix(p) AS (VALUES(?) UNION ALL " - "SELECT SUBSTR(p, 0, LENGTH(p)) FROM prefix WHERE p != '') SELECT p FROM prefix) " + "SELECT SUBSTR(p, 1, LENGTH(p)) FROM prefix WHERE p != '') SELECT p FROM prefix) " "ORDER BY LENGTH(name)" << name; for (auto&& row: nodeQuery) { CClaimValue claim; std::string key; - int lastTakeoverHeight; - row >> key >> lastTakeoverHeight; + row >> key; bool fNodeHasValue = getInfoForName(key, claim); uint256 valueHash; if (fNodeHasValue) - valueHash = getValueHash(claim.outPoint, lastTakeoverHeight); + valueHash = getValueHash(claim.outPoint, claim.nValidAtHeight); const auto pos = key.size(); std::vector> children; @@ -1246,7 +1226,7 @@ bool CClaimTrieCacheBase::getProofForName(const std::string& name, const uint160 proof.hasValue = fNodeHasValue && claim.claimId == finalClaim; if (proof.hasValue) { proof.outPoint = claim.outPoint; - proof.nHeightOfLastTakeover = lastTakeoverHeight; + proof.nHeightOfLastTakeover = claim.nValidAtHeight; } valueHash.SetNull(); } @@ -1256,12 +1236,14 @@ bool CClaimTrieCacheBase::getProofForName(const std::string& name, const uint160 } bool CClaimTrieCacheBase::findNameForClaim(const std::vector& claim, CClaimValue& value, std::string& name) { - auto query = base->_db << "SELECT nodeName, claimId, txID, txN, amount, block_height FROM claims WHERE SUBSTR(claimID, 1, ?) = ?" << claim.size() + 1 << claim; + auto query = base->_db << "SELECT nodeName, claimId, txID, txN, amount, validHeight, block_height " + "FROM claims WHERE SUBSTR(claimID, 1, ?) = ? AND validHeight < ? AND expirationHeight >= ?" + << claim.size() + 1 << claim << nNextHeight << nNextHeight; auto hit = false; for (auto&& row: query) { if (hit) return false; row >> name >> value.claimId >> value.outPoint.hash >> value.outPoint.n - >> value.nAmount >> value.nHeight; + >> value.nAmount >> value.nValidAtHeight >> value.nHeight; hit = true; } return true; @@ -1269,7 +1251,8 @@ bool CClaimTrieCacheBase::findNameForClaim(const std::vector& cla void CClaimTrieCacheBase::getNamesInTrie(std::function callback) { - auto query = base->_db << "SELECT DISTINCT nodeName FROM claims"; + auto query = base->_db << "SELECT DISTINCT nodeName FROM claims WHERE validHeight < ? AND expirationHeight >= ?" + << nNextHeight << nNextHeight; for (auto&& row: query) { std::string name; row >> name; diff --git a/src/claimtrie.h b/src/claimtrie.h index 8c1249da3..7e0f4355c 100644 --- a/src/claimtrie.h +++ b/src/claimtrie.h @@ -49,19 +49,21 @@ namespace sqlite { struct has_sqlite_type : std::true_type {}; inline uint160 get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - int bytes = sqlite3_column_bytes(stmt, inx); uint160 ret; - assert(bytes == ret.size()); auto ptr = sqlite3_column_blob(stmt, inx); + if (!ptr) return ret; + int bytes = sqlite3_column_bytes(stmt, inx); + assert(bytes == ret.size()); std::memcpy(ret.begin(), ptr, bytes); return ret; } inline uint256 get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - int bytes = sqlite3_column_bytes(stmt, inx); uint256 ret; - assert(bytes == ret.size()); auto ptr = sqlite3_column_blob(stmt, inx); + if (!ptr) return ret; + int bytes = sqlite3_column_bytes(stmt, inx); + assert(bytes == ret.size()); std::memcpy(ret.begin(), ptr, bytes); return ret; } @@ -373,8 +375,7 @@ public: virtual bool incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, - supportQueueRowType& expireSupportUndo, - std::vector>& takeoverHeightUndo); + supportQueueRowType& expireSupportUndo); virtual bool decrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, @@ -386,7 +387,7 @@ public: virtual int expirationTime() const; - virtual bool finalizeDecrement(std::vector>& takeoverHeightUndo); + virtual bool finalizeDecrement(); virtual CClaimSupportToName getClaimsForName(const std::string& name) const; virtual std::string adjustNameForValidHeight(const std::string& name, int validHeight) const; @@ -403,7 +404,7 @@ protected: CClaimTrie* base; bool dirtyNodes; - virtual uint256 recursiveComputeMerkleHash(const std::string& name, int lastTakeoverHeight, bool checkOnly); + virtual uint256 recursiveComputeMerkleHash(const std::string& name, bool checkOnly); supportEntryType getSupportsForName(const std::string& name) const; @@ -435,13 +436,12 @@ public: int expirationTime() const override; virtual void initializeIncrement(); - bool finalizeDecrement(std::vector>& takeoverHeightUndo) override; + bool finalizeDecrement() override; bool incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, - supportQueueRowType& expireSupportUndo, - std::vector>& takeoverHeightUndo) override; + supportQueueRowType& expireSupportUndo) override; bool decrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, @@ -470,8 +470,7 @@ public: bool incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, - supportQueueRowType& expireSupportUndo, - std::vector>& takeoverHeightUndo) override; + supportQueueRowType& expireSupportUndo) override; bool decrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, @@ -497,12 +496,12 @@ public: bool getProofForName(const std::string& name, const uint160& finalClaim, CClaimTrieProof& proof) override; void initializeIncrement() override; - bool finalizeDecrement(std::vector>& takeoverHeightUndo) override; + bool finalizeDecrement() override; bool allowSupportMetadata() const; protected: - uint256 recursiveComputeMerkleHash(const std::string& name, int lastTakeoverHeight, bool checkOnly) override; + uint256 recursiveComputeMerkleHash(const std::string& name, bool checkOnly) override; }; typedef CClaimTrieCacheHashFork CClaimTrieCache; diff --git a/src/claimtrieforks.cpp b/src/claimtrieforks.cpp index 21c9b8cc3..99dccfcbc 100644 --- a/src/claimtrieforks.cpp +++ b/src/claimtrieforks.cpp @@ -25,9 +25,9 @@ int CClaimTrieCacheExpirationFork::expirationTime() const return nExpirationTime; } -bool CClaimTrieCacheExpirationFork::incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, supportQueueRowType& expireSupportUndo, std::vector>& takeoverHeightUndo) +bool CClaimTrieCacheExpirationFork::incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, supportQueueRowType& expireSupportUndo) { - if (CClaimTrieCacheBase::incrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo, takeoverHeightUndo)) { + if (CClaimTrieCacheBase::incrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo)) { setExpirationTime(Params().GetConsensus().GetExpirationTime(nNextHeight)); return true; } @@ -52,9 +52,9 @@ void CClaimTrieCacheExpirationFork::initializeIncrement() forkForExpirationChange(true); } -bool CClaimTrieCacheExpirationFork::finalizeDecrement(std::vector>& takeoverHeightUndo) +bool CClaimTrieCacheExpirationFork::finalizeDecrement() { - auto ret = CClaimTrieCacheBase::finalizeDecrement(takeoverHeightUndo); + auto ret = CClaimTrieCacheBase::finalizeDecrement(); if (ret && nNextHeight == Params().GetConsensus().nExtendedClaimExpirationForkHeight) forkForExpirationChange(false); return ret; @@ -74,7 +74,7 @@ bool CClaimTrieCacheExpirationFork::forkForExpirationChange(bool increment) if (!increment) extension = -extension; base->_db << "UPDATE claims SET expirationHeight = expirationHeight + ? WHERE expirationHeight >= ?" << extension << nNextHeight; base->_db << "UPDATE supports SET expirationHeight = expirationHeight + ? WHERE expirationHeight >= ?" << extension << nNextHeight; - base->_db << "UPDATE nodes SET hash = NULL, claimHash = NULL"; // recompute all hashes (as there aren't that many at this point) + base->_db << "UPDATE nodes SET hash = NULL"; // recompute all hashes (as there aren't that many at this point) dirtyNodes = true; return true; } @@ -135,7 +135,7 @@ bool CClaimTrieCacheNormalizationFork::normalizeAllNamesInTrieIfNecessary(bool f base->_db.define("NORMALIZED", [this](const std::string& str) { return normalizeClaimName(str, true); }); - auto query = base->_db << "SELECT NORMALIZED(name), name, claimID as nn FROM claims HAVING nodeName != nn"; + auto query = base->_db << "SELECT NORMALIZED(name) as nn, name, claimID FROM claims HAVING nodeName != nn"; for(auto&& row: query) { std::string newName, oldName; uint160 claimID; @@ -143,17 +143,17 @@ bool CClaimTrieCacheNormalizationFork::normalizeAllNamesInTrieIfNecessary(bool f if (!forward) std::swap(newName, oldName); base->_db << "UPDATE claims SET nodeName = ? WHERE claimID = ?" << newName << claimID; base->_db << "DELETE FROM nodes WHERE name = ?" << oldName; - base->_db << "INSERT INTO nodes(name) VALUES(?) ON CONFLICT DO UPDATE hash = NULL, claimHash = NULL" << newName; + base->_db << "INSERT INTO nodes(name) VALUES(?) ON CONFLICT DO UPDATE hash = NULL" << newName; } dirtyNodes = true; return true; } -bool CClaimTrieCacheNormalizationFork::incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, supportQueueRowType& expireSupportUndo, std::vector>& takeoverHeightUndo) +bool CClaimTrieCacheNormalizationFork::incrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, supportQueueRowType& expireSupportUndo) { normalizeAllNamesInTrieIfNecessary(true); - return CClaimTrieCacheExpirationFork::incrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo, takeoverHeightUndo); + return CClaimTrieCacheExpirationFork::incrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo); } bool CClaimTrieCacheNormalizationFork::decrementBlock(insertUndoType& insertUndo, claimQueueRowType& expireUndo, insertUndoType& insertSupportUndo, supportQueueRowType& expireSupportUndo) @@ -196,36 +196,37 @@ CClaimTrieCacheHashFork::CClaimTrieCacheHashFork(CClaimTrie* base) : CClaimTrieC static const uint256 leafHash = uint256S("0000000000000000000000000000000000000000000000000000000000000002"); static const uint256 emptyHash = uint256S("0000000000000000000000000000000000000000000000000000000000000003"); -uint256 CClaimTrieCacheHashFork::recursiveComputeMerkleHash(const std::string& name, int lastTakeoverHeight, bool checkOnly) +uint256 CClaimTrieCacheHashFork::recursiveComputeMerkleHash(const std::string& name, bool checkOnly) { if (nNextHeight < Params().GetConsensus().nAllClaimsInMerkleForkHeight) - return CClaimTrieCacheNormalizationFork::recursiveComputeMerkleHash(name, lastTakeoverHeight, checkOnly); + return CClaimTrieCacheNormalizationFork::recursiveComputeMerkleHash(name, checkOnly); - auto childQuery = base->_db << "SELECT name, hash, lastTakeoverHeight FROM nodes WHERE parent = ? ORDER BY name" << name; + auto childQuery = base->_db << "SELECT name, hash FROM nodes WHERE parent = ? ORDER BY name" << name; std::vector childHashes; for (auto&& row: childQuery) { std::string key; - int keyLastTakeoverHeight; std::unique_ptr hash; - row >> key >> hash >> keyLastTakeoverHeight; + row >> key >> hash; if (hash == nullptr) hash = std::make_unique(); if (hash->IsNull()) { - *hash = recursiveComputeMerkleHash(key, keyLastTakeoverHeight, checkOnly); + *hash = recursiveComputeMerkleHash(key, checkOnly); } childHashes.push_back(*hash); } auto claimQuery = base->_db << "SELECT c.txID, c.txN, c.validHeight, c.amount + " - "SUM(SELECT s.amount FROM supports s WHERE s.supportedClaimID = c.claimID AND s.validHeight < ?) as effectiveAmount" - "FROM claims c WHERE c.nodeName = ? AND c.validHeight < ? " - "ORDER BY effectiveAmount DESC, c.blockHeight, c.txID, c.txN" << nNextHeight << name << nNextHeight; + "SUM(SELECT s.amount FROM supports s WHERE s.supportedClaimID = c.claimID " + "AND s.validHeight < ? AND s.expirationHeight >= ?) as effectiveAmount" + "FROM claims c WHERE c.nodeName = ? AND c.validHeight < ? AND c.expirationHeight >= ? " + "ORDER BY effectiveAmount DESC, c.blockHeight, c.txID, c.txN" << nNextHeight << nNextHeight << name << nNextHeight << nNextHeight; std::vector claimHashes; for (auto&& row: claimQuery) { COutPoint p; - row >> p.hash >> p.n; - auto claimHash = getValueHash(p, lastTakeoverHeight); + int validHeight; + row >> p.hash >> p.n >> validHeight; + auto claimHash = getValueHash(p, validHeight); claimHashes.push_back(claimHash); } @@ -307,13 +308,13 @@ bool CClaimTrieCacheHashFork::getProofForName(const std::string& name, const uin // cache the parent nodes getMerkleHash(); proof = CClaimTrieProof(); - auto nodeQuery = base->_db << "SELECT name, lastTakeoverHeight FROM nodes WHERE " + auto nodeQuery = base->_db << "SELECT name FROM nodes WHERE " "name IN (WITH RECURSIVE prefix(p) AS (VALUES(?) UNION ALL " - "SELECT SUBSTR(p, 0, LENGTH(p)) FROM prefix WHERE p != '') SELECT p FROM prefix) " + "SELECT SUBSTR(p, 1, LENGTH(p)) FROM prefix WHERE p != '') SELECT p FROM prefix) " "ORDER BY LENGTH(name)" << name; for (auto&& row: nodeQuery) { - std::string key; int lastTakeover; - row >> key >> lastTakeover; + std::string key;; + row >> key; std::vector childHashes; uint32_t nextCurrentIdx = 0; auto childQuery = base->_db << "SELECT name, hash FROM nodes WHERE parent = ?" << key; @@ -332,7 +333,7 @@ bool CClaimTrieCacheHashFork::getProofForName(const std::string& name, const uin COutPoint finalOutPoint; for (uint32_t i = 0; i < cns.claimsNsupports.size(); ++i) { auto& child = cns.claimsNsupports[i].claim; - claimHashes.push_back(getValueHash(child.outPoint, lastTakeover)); + claimHashes.push_back(getValueHash(child.outPoint, child.nValidAtHeight)); if (child.claimId == finalClaim) { finalClaimIdx = i; finalOutPoint = child.outPoint; @@ -344,7 +345,7 @@ bool CClaimTrieCacheHashFork::getProofForName(const std::string& name, const uin // else it will be hash(x, claims) if (key == name) { proof.outPoint = finalOutPoint; - proof.nHeightOfLastTakeover = lastTakeover; + proof.nHeightOfLastTakeover = cns.nLastTakeoverHeight; proof.hasValue = true; auto hash = childHashes.empty() ? leafHash : ComputeMerkleRoot(childHashes); proof.pairs.emplace_back(true, hash); @@ -369,9 +370,9 @@ void CClaimTrieCacheHashFork::initializeIncrement() base->_db << "UPDATE nodes SET hash = NULL"; } -bool CClaimTrieCacheHashFork::finalizeDecrement(std::vector>& takeoverHeightUndo) +bool CClaimTrieCacheHashFork::finalizeDecrement() { - auto ret = CClaimTrieCacheNormalizationFork::finalizeDecrement(takeoverHeightUndo); + auto ret = CClaimTrieCacheNormalizationFork::finalizeDecrement(); if (ret && nNextHeight == Params().GetConsensus().nAllClaimsInMerkleForkHeight - 1) base->_db << "UPDATE nodes SET hash = NULL"; return ret; diff --git a/src/miner.cpp b/src/miner.cpp index 8c88c7d18..9c3dcfc20 100644 --- a/src/miner.cpp +++ b/src/miner.cpp @@ -49,7 +49,6 @@ void blockToCache(const CBlock* pblock, CClaimTrieCache& trieCache, int nHeight) claimQueueRowType dummyExpireUndo; insertUndoType dummyInsertSupportUndo; supportQueueRowType dummyExpireSupportUndo; - std::vector > dummyTakeoverHeightUndo; CUpdateCacheCallbacks callbacks = { .findScriptKey = [&pblock](const COutPoint& point) { @@ -69,7 +68,7 @@ void blockToCache(const CBlock* pblock, CClaimTrieCache& trieCache, int nHeight) if (!tx->IsCoinBase()) UpdateCache(*tx, trieCache, view, nHeight, callbacks); - trieCache.incrementBlock(dummyInsertUndo, dummyExpireUndo, dummyInsertSupportUndo, dummyExpireSupportUndo, dummyTakeoverHeightUndo); + trieCache.incrementBlock(dummyInsertUndo, dummyExpireUndo, dummyInsertSupportUndo, dummyExpireSupportUndo); } BlockAssembler::Options::Options() { diff --git a/src/scheduler.cpp b/src/scheduler.cpp index fdc859b3a..193a92abb 100644 --- a/src/scheduler.cpp +++ b/src/scheduler.cpp @@ -16,7 +16,7 @@ CScheduler::CScheduler() : nThreadsServicingQueue(0), stopRequested(false), stop CScheduler::~CScheduler() { - assert(nThreadsServicingQueue == 0); + assert(!AreThreadsServicingQueue()); } diff --git a/src/sqlite/hdr/sqlite_modern_cpp.h b/src/sqlite/hdr/sqlite_modern_cpp.h index 4703d35b3..080bf3719 100644 --- a/src/sqlite/hdr/sqlite_modern_cpp.h +++ b/src/sqlite/hdr/sqlite_modern_cpp.h @@ -19,664 +19,663 @@ namespace sqlite { - class database; - class database_binder; + class database; + class database_binder; - template class binder; + template class binder; - typedef std::shared_ptr connection_type; + typedef std::shared_ptr connection_type; - template - struct index_binding_helper { - index_binding_helper(const index_binding_helper &) = delete; + template + struct index_binding_helper { + index_binding_helper(const index_binding_helper &) = delete; #if __cplusplus < 201703 - index_binding_helper(index_binding_helper &&) = default; + index_binding_helper(index_binding_helper &&) = default; #endif - typename std::conditional::type index; - T value; - }; + typename std::conditional::type index; + T value; + }; - template - auto named_parameter(const char *name, T &&arg) { - return index_binding_helper{name, std::forward(arg)}; - } - template - auto indexed_parameter(int index, T &&arg) { - return index_binding_helper{index, std::forward(arg)}; - } + template + auto named_parameter(const char *name, T &&arg) { + return index_binding_helper{name, std::forward(arg)}; + } + template + auto indexed_parameter(int index, T &&arg) { + return index_binding_helper{index, std::forward(arg)}; + } - class row_iterator; - class database_binder { + class row_iterator; + class database_binder { - public: - // database_binder is not copyable - database_binder() = delete; - database_binder(const database_binder& other) = delete; - database_binder& operator=(const database_binder&) = delete; + public: + // database_binder is not copyable + database_binder() = delete; + database_binder(const database_binder& other) = delete; + database_binder& operator=(const database_binder&) = delete; - database_binder(database_binder&& other) : - _db(std::move(other._db)), - _stmt(std::move(other._stmt)), - _inx(other._inx), execution_started(other.execution_started) { } + database_binder(database_binder&& other) : + _db(std::move(other._db)), + _stmt(std::move(other._stmt)), + _inx(other._inx), execution_started(other.execution_started) { } - void execute(); + void execute(); - std::string sql() { + std::string sql() { #if SQLITE_VERSION_NUMBER >= 3014000 - auto sqlite_deleter = [](void *ptr) {sqlite3_free(ptr);}; - std::unique_ptr str(sqlite3_expanded_sql(_stmt.get()), sqlite_deleter); - return str ? str.get() : original_sql(); + auto sqlite_deleter = [](void *ptr) {sqlite3_free(ptr);}; + std::unique_ptr str(sqlite3_expanded_sql(_stmt.get()), sqlite_deleter); + return str ? str.get() : original_sql(); #else - return original_sql(); + return original_sql(); #endif - } - - std::string original_sql() { - return sqlite3_sql(_stmt.get()); - } - - void used(bool state) { - if(!state) { - // We may have to reset first if we haven't done so already: - _next_index(); - --_inx; - } - execution_started = state; - } - bool used() const { return execution_started; } - row_iterator begin(); - row_iterator end(); - - private: - std::shared_ptr _db; - std::unique_ptr _stmt; - utility::UncaughtExceptionDetector _has_uncaught_exception; - - int _inx; - - bool execution_started = false; - - int _next_index() { - if(execution_started && !_inx) { - sqlite3_reset(_stmt.get()); - sqlite3_clear_bindings(_stmt.get()); - } - return ++_inx; - } - - sqlite3_stmt* _prepare(u16str_ref sql) { - return _prepare(utility::utf16_to_utf8(sql)); - } - - sqlite3_stmt* _prepare(str_ref sql) { - int hresult; - sqlite3_stmt* tmp = nullptr; - const char *remaining; - hresult = sqlite3_prepare_v2(_db.get(), sql.data(), sql.length(), &tmp, &remaining); - if(hresult != SQLITE_OK) errors::throw_sqlite_error(hresult, sql); - if(!std::all_of(remaining, sql.data() + sql.size(), [](char ch) {return std::isspace(ch);})) - throw errors::more_statements("Multiple semicolon separated statements are unsupported", sql); - return tmp; - } - - template friend database_binder& operator<<(database_binder& db, T&&); - template friend database_binder& operator<<(database_binder& db, index_binding_helper); - template friend database_binder& operator<<(database_binder& db, index_binding_helper); - friend void operator++(database_binder& db, int); - - public: - - database_binder(std::shared_ptr db, u16str_ref sql): - _db(db), - _stmt(_prepare(sql), sqlite3_finalize), - _inx(0) { - } - - database_binder(std::shared_ptr db, str_ref sql): - _db(db), - _stmt(_prepare(sql), sqlite3_finalize), - _inx(0) { - } - - ~database_binder() noexcept(false) { - /* Will be executed if no >>op is found, but not if an exception - is in mid flight */ - if(!used() && !_has_uncaught_exception && _stmt) { - execute(); - } - } - - friend class row_iterator; - }; - - class row_iterator { - public: - class value_type { - public: - value_type(database_binder *_binder): _binder(_binder) {}; - template - typename std::enable_if::value, value_type &>::type operator >>(T &result) { - result = get_col_from_db(_binder->_stmt.get(), next_index++, result_type()); - return *this; - } - template - value_type &operator >>(std::tuple& values) { - values = handle_tuple::type...>>(std::index_sequence_for()); - next_index += sizeof...(Types); - return *this; - } - template - value_type &operator >>(std::tuple&& values) { - return *this >> values; - } - template - operator std::tuple() { - std::tuple value; - *this >> value; - return value; - } - explicit operator bool() { - return sqlite3_column_count(_binder->_stmt.get()) >= next_index; - } - private: - template - Tuple handle_tuple(std::index_sequence) { - return Tuple( - get_col_from_db( - _binder->_stmt.get(), - next_index + Index, - result_type::type>())...); - } - database_binder *_binder; - int next_index = 0; - }; - using difference_type = std::ptrdiff_t; - using pointer = value_type*; - using reference = value_type&; - using iterator_category = std::input_iterator_tag; - - row_iterator() = default; - explicit row_iterator(database_binder &binder): _binder(&binder) { - _binder->_next_index(); - _binder->_inx = 0; - _binder->used(true); - ++*this; - } - - reference operator*() const { return value;} - pointer operator->() const { return std::addressof(**this); } - row_iterator &operator++() { - switch(int result = sqlite3_step(_binder->_stmt.get())) { - case SQLITE_ROW: - value = {_binder}; - break; - case SQLITE_DONE: - _binder = nullptr; - break; - default: - exceptions::throw_sqlite_error(result, _binder->sql()); - } - return *this; - } - - friend inline bool operator ==(const row_iterator &a, const row_iterator &b) { - return a._binder == b._binder; - } - friend inline bool operator !=(const row_iterator &a, const row_iterator &b) { - return !(a==b); - } - - private: - database_binder *_binder = nullptr; - mutable value_type value{_binder}; // mutable, because `changing` the value is just reading it - }; - - inline row_iterator database_binder::begin() { - return row_iterator(*this); - } - - inline row_iterator database_binder::end() { - return row_iterator(); - } - - namespace detail { - template - void _extract_single_value(database_binder &binder, Callback call_back) { - auto iter = binder.begin(); - if(iter == binder.end()) - throw errors::no_rows("no rows to extract: exactly 1 row expected", binder.sql(), SQLITE_DONE); - - call_back(*iter); - - if(++iter != binder.end()) - throw errors::more_rows("not all rows extracted", binder.sql(), SQLITE_ROW); - } - } - inline void database_binder::execute() { - for(auto &&row : *this) - (void)row; - } - namespace detail { - template using void_t = void; - template - struct sqlite_direct_result : std::false_type {}; - template - struct sqlite_direct_result< - T, - void_t() >> std::declval())> - > : std::true_type {}; - } - template - inline typename std::enable_if::value>::type operator>>(database_binder &binder, Result&& value) { - detail::_extract_single_value(binder, [&value] (row_iterator::value_type &row) { - row >> std::forward(value); - }); - } - - template - inline typename std::enable_if::value>::type operator>>(database_binder &db_binder, Function&& func) { - using traits = utility::function_traits; - - for(auto &&row : db_binder) { - binder::run(row, func); - } - } - - template - inline decltype(auto) operator>>(database_binder &&binder, Result&& value) { - return binder >> std::forward(value); - } - - namespace sql_function_binder { - template< - typename ContextType, - std::size_t Count, - typename Functions - > - inline void step( - sqlite3_context* db, - int count, - sqlite3_value** vals - ); - - template< - std::size_t Count, - typename Functions, - typename... Values - > - inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( - sqlite3_context* db, - int count, - sqlite3_value** vals, - Values&&... values - ); - - template< - std::size_t Count, - typename Functions, - typename... Values - > - inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( - sqlite3_context* db, - int, - sqlite3_value**, - Values&&... values - ); - - template< - typename ContextType, - typename Functions - > - inline void final(sqlite3_context* db); - - template< - std::size_t Count, - typename Function, - typename... Values - > - inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( - sqlite3_context* db, - int count, - sqlite3_value** vals, - Values&&... values - ); - - template< - std::size_t Count, - typename Function, - typename... Values - > - inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( - sqlite3_context* db, - int, - sqlite3_value**, - Values&&... values - ); - } - - enum class OpenFlags { - READONLY = SQLITE_OPEN_READONLY, - READWRITE = SQLITE_OPEN_READWRITE, - CREATE = SQLITE_OPEN_CREATE, - NOMUTEX = SQLITE_OPEN_NOMUTEX, - FULLMUTEX = SQLITE_OPEN_FULLMUTEX, - SHAREDCACHE = SQLITE_OPEN_SHAREDCACHE, - PRIVATECACH = SQLITE_OPEN_PRIVATECACHE, - URI = SQLITE_OPEN_URI - }; - inline OpenFlags operator|(const OpenFlags& a, const OpenFlags& b) { - return static_cast(static_cast(a) | static_cast(b)); - } - enum class Encoding { - ANY = SQLITE_ANY, - UTF8 = SQLITE_UTF8, - UTF16 = SQLITE_UTF16 - }; - struct sqlite_config { - OpenFlags flags = OpenFlags::READWRITE | OpenFlags::CREATE; - const char *zVfs = nullptr; - Encoding encoding = Encoding::ANY; - }; - - class database { - protected: - std::shared_ptr _db; - - public: - database(const std::string &db_name, const sqlite_config &config = {}): _db(nullptr) { - sqlite3* tmp = nullptr; - auto ret = sqlite3_open_v2(db_name.data(), &tmp, static_cast(config.flags), config.zVfs); - _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. - if(ret != SQLITE_OK) errors::throw_sqlite_error(_db ? sqlite3_extended_errcode(_db.get()) : ret); - sqlite3_extended_result_codes(_db.get(), true); - if(config.encoding == Encoding::UTF16) - *this << R"(PRAGMA encoding = "UTF-16";)"; - } - - database(const std::u16string &db_name, const sqlite_config &config = {}): database(utility::utf16_to_utf8(db_name), config) { - if (config.encoding == Encoding::ANY) - *this << R"(PRAGMA encoding = "UTF-16";)"; - } - - database(std::shared_ptr db): - _db(db) {} - - database_binder operator<<(str_ref sql) { - return database_binder(_db, sql); - } - - database_binder operator<<(u16str_ref sql) { - return database_binder(_db, sql); - } - - connection_type connection() const { return _db; } - - sqlite3_int64 last_insert_rowid() const { - return sqlite3_last_insert_rowid(_db.get()); - } - - int rows_modified() const { - return sqlite3_changes(_db.get()); - } - - template - void define(const std::string &name, Function&& func) { - typedef utility::function_traits traits; - - auto funcPtr = new auto(std::forward(func)); - if(int result = sqlite3_create_function_v2( - _db.get(), name.data(), traits::arity, SQLITE_UTF8 | SQLITE_DETERMINISTIC, funcPtr, - sql_function_binder::scalar::type>, - nullptr, nullptr, [](void* ptr){ - delete static_cast(ptr); - })) - errors::throw_sqlite_error(result); - } - - template - void define(const std::string &name, StepFunction&& step, FinalFunction&& final) { - typedef utility::function_traits traits; - using ContextType = typename std::remove_reference>::type; - - auto funcPtr = new auto(std::make_pair(std::forward(step), std::forward(final))); - if(int result = sqlite3_create_function_v2( - _db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8 | SQLITE_DETERMINISTIC, funcPtr, nullptr, - sql_function_binder::step::type>, - sql_function_binder::final::type>, - [](void* ptr){ - delete static_cast(ptr); - })) - errors::throw_sqlite_error(result); - } - - }; - - template - class binder { - private: - template < - typename Function, - std::size_t Index - > - using nth_argument_type = typename utility::function_traits< - Function - >::template argument; - - public: - // `Boundary` needs to be defaulted to `Count` so that the `run` function - // template is not implicitly instantiated on class template instantiation. - // Look up section 14.7.1 _Implicit instantiation_ of the ISO C++14 Standard - // and the [dicussion](https://github.com/aminroosta/sqlite_modern_cpp/issues/8) - // on Github. - - template< - typename Function, - typename... Values, - std::size_t Boundary = Count - > - static typename std::enable_if<(sizeof...(Values) < Boundary), void>::type run( - row_iterator::value_type& row, - Function&& function, - Values&&... values - ) { - typename std::decay>::type value; - row >> value; - run(row, function, std::forward(values)..., std::move(value)); - } - - template< - typename Function, - typename... Values, - std::size_t Boundary = Count - > - static typename std::enable_if<(sizeof...(Values) == Boundary), void>::type run( - row_iterator::value_type&, - Function&& function, - Values&&... values - ) { - function(std::move(values)...); - } - }; - - // Some ppl are lazy so we have a operator for proper prep. statemant handling. - void inline operator++(database_binder& db, int) { db.execute(); } - - template database_binder &operator<<(database_binder& db, index_binding_helper val) { - db._next_index(); --db._inx; - int result = bind_col_in_db(db._stmt.get(), val.index, std::forward(val.value)); - if(result != SQLITE_OK) - exceptions::throw_sqlite_error(result, db.sql()); - return db; - } - - template database_binder &operator<<(database_binder& db, index_binding_helper val) { - db._next_index(); --db._inx; - int index = sqlite3_bind_parameter_index(db._stmt.get(), val.index); - if(!index) - throw errors::unknown_binding("The given binding name is not valid for this statement", db.sql()); - int result = bind_col_in_db(db._stmt.get(), index, std::forward(val.value)); - if(result != SQLITE_OK) - exceptions::throw_sqlite_error(result, db.sql()); - return db; - } - - template database_binder &operator<<(database_binder& db, T&& val) { - int result = bind_col_in_db(db._stmt.get(), db._next_index(), std::forward(val)); - if(result != SQLITE_OK) - exceptions::throw_sqlite_error(result, db.sql()); - return db; - } - // Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!) - template database_binder operator << (database_binder&& db, const T& val) { db << val; return std::move(db); } - template database_binder operator << (database_binder&& db, index_binding_helper val) { db << index_binding_helper{val.index, std::forward(val.value)}; return std::move(db); } - - namespace sql_function_binder { - template - struct AggregateCtxt { - T obj; - bool constructed = true; - }; - - template< - typename ContextType, - std::size_t Count, - typename Functions - > - inline void step( - sqlite3_context* db, - int count, - sqlite3_value** vals - ) { - auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); - if(!ctxt) return; - try { - if(!ctxt->constructed) new(ctxt) AggregateCtxt(); - step(db, count, vals, ctxt->obj); - return; - } catch(const sqlite_exception &e) { - sqlite3_result_error_code(db, e.get_code()); - sqlite3_result_error(db, e.what(), -1); - } catch(const std::exception &e) { - sqlite3_result_error(db, e.what(), -1); - } catch(...) { - sqlite3_result_error(db, "Unknown error", -1); - } - if(ctxt && ctxt->constructed) - ctxt->~AggregateCtxt(); - } - - template< - std::size_t Count, - typename Functions, - typename... Values - > - inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( - sqlite3_context* db, - int count, - sqlite3_value** vals, - Values&&... values - ) { - using arg_type = typename std::remove_cv< - typename std::remove_reference< - typename utility::function_traits< - typename Functions::first_type - >::template argument - >::type - >::type; - - step( - db, - count, - vals, - std::forward(values)..., - get_val_from_db(vals[sizeof...(Values) - 1], result_type())); - } - - template< - std::size_t Count, - typename Functions, - typename... Values - > - inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( - sqlite3_context* db, - int, - sqlite3_value**, - Values&&... values - ) { - static_cast(sqlite3_user_data(db))->first(std::forward(values)...); - } - - template< - typename ContextType, - typename Functions - > - inline void final(sqlite3_context* db) { - auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); - try { - if(!ctxt) return; - if(!ctxt->constructed) new(ctxt) AggregateCtxt(); - store_result_in_db(db, - static_cast(sqlite3_user_data(db))->second(ctxt->obj)); - } catch(const sqlite_exception &e) { - sqlite3_result_error_code(db, e.get_code()); - sqlite3_result_error(db, e.what(), -1); - } catch(const std::exception &e) { - sqlite3_result_error(db, e.what(), -1); - } catch(...) { - sqlite3_result_error(db, "Unknown error", -1); - } - if(ctxt && ctxt->constructed) - ctxt->~AggregateCtxt(); - } - - template< - std::size_t Count, - typename Function, - typename... Values - > - inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( - sqlite3_context* db, - int count, - sqlite3_value** vals, - Values&&... values - ) { - using arg_type = typename std::remove_cv< - typename std::remove_reference< - typename utility::function_traits::template argument - >::type - >::type; - - scalar( - db, - count, - vals, - std::forward(values)..., - get_val_from_db(vals[sizeof...(Values)], result_type())); - } - - template< - std::size_t Count, - typename Function, - typename... Values - > - inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( - sqlite3_context* db, - int, - sqlite3_value**, - Values&&... values - ) { - try { - store_result_in_db(db, - (*static_cast(sqlite3_user_data(db)))(std::forward(values)...)); - } catch(const sqlite_exception &e) { - sqlite3_result_error_code(db, e.get_code()); - sqlite3_result_error(db, e.what(), -1); - } catch(const std::exception &e) { - sqlite3_result_error(db, e.what(), -1); - } catch(...) { - sqlite3_result_error(db, "Unknown error", -1); - } - } - } + } + + std::string original_sql() { + return sqlite3_sql(_stmt.get()); + } + + void used(bool state) { + if(!state) { + // We may have to reset first if we haven't done so already: + _next_index(); + --_inx; + } + execution_started = state; + } + bool used() const { return execution_started; } + row_iterator begin(); + row_iterator end(); + + private: + std::shared_ptr _db; + std::unique_ptr _stmt; + utility::UncaughtExceptionDetector _has_uncaught_exception; + + int _inx; + + bool execution_started = false; + + int _next_index() { + if(execution_started && !_inx) { + sqlite3_reset(_stmt.get()); + sqlite3_clear_bindings(_stmt.get()); + } + return ++_inx; + } + + sqlite3_stmt* _prepare(u16str_ref sql) { + return _prepare(utility::utf16_to_utf8(sql)); + } + + sqlite3_stmt* _prepare(str_ref sql) { + int hresult; + sqlite3_stmt* tmp = nullptr; + const char *remaining; + hresult = sqlite3_prepare_v2(_db.get(), sql.data(), sql.length(), &tmp, &remaining); + if(hresult != SQLITE_OK) errors::throw_sqlite_error(hresult, sql, sqlite3_errmsg(_db.get())); + if(!std::all_of(remaining, sql.data() + sql.size(), [](char ch) {return std::isspace(ch);})) + throw errors::more_statements("Multiple semicolon separated statements are unsupported", sql); + return tmp; + } + + template friend database_binder& operator<<(database_binder& db, T&&); + template friend database_binder& operator<<(database_binder& db, index_binding_helper); + template friend database_binder& operator<<(database_binder& db, index_binding_helper); + friend void operator++(database_binder& db, int); + + public: + + database_binder(std::shared_ptr db, u16str_ref sql): + _db(db), + _stmt(_prepare(sql), sqlite3_finalize), + _inx(0) { + } + + database_binder(std::shared_ptr db, str_ref sql): + _db(db), + _stmt(_prepare(sql), sqlite3_finalize), + _inx(0) { + } + + ~database_binder() noexcept(false) { + /* Will be executed if no >>op is found, but not if an exception + is in mid flight */ + if(!used() && !_has_uncaught_exception && _stmt) { + execute(); + } + } + + friend class row_iterator; + }; + + class row_iterator { + public: + class value_type { + public: + value_type(database_binder *_binder): _binder(_binder) {}; + template + typename std::enable_if::value, value_type &>::type operator >>(T &result) { + result = get_col_from_db(_binder->_stmt.get(), next_index++, result_type()); + return *this; + } + template + value_type &operator >>(std::tuple& values) { + values = handle_tuple::type...>>(std::index_sequence_for()); + next_index += sizeof...(Types); + return *this; + } + template + value_type &operator >>(std::tuple&& values) { + return *this >> values; + } + template + operator std::tuple() { + std::tuple value; + *this >> value; + return value; + } + explicit operator bool() { + return sqlite3_column_count(_binder->_stmt.get()) >= next_index; + } + private: + template + Tuple handle_tuple(std::index_sequence) { + return Tuple( + get_col_from_db( + _binder->_stmt.get(), + next_index + Index, + result_type::type>())...); + } + database_binder *_binder; + int next_index = 0; + }; + using difference_type = std::ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + using iterator_category = std::input_iterator_tag; + + row_iterator() = default; + explicit row_iterator(database_binder &binder): _binder(&binder) { + _binder->_next_index(); + _binder->_inx = 0; + _binder->used(true); + ++*this; + } + + reference operator*() const { return value;} + pointer operator->() const { return std::addressof(**this); } + row_iterator &operator++() { + switch(int result = sqlite3_step(_binder->_stmt.get())) { + case SQLITE_ROW: + value = {_binder}; + break; + case SQLITE_DONE: + _binder = nullptr; + break; + default: + exceptions::throw_sqlite_error(result, _binder->sql(), sqlite3_errmsg(_binder->_db.get())); + } + return *this; + } + + friend inline bool operator ==(const row_iterator &a, const row_iterator &b) { + return a._binder == b._binder; + } + friend inline bool operator !=(const row_iterator &a, const row_iterator &b) { + return !(a==b); + } + + private: + database_binder *_binder = nullptr; + mutable value_type value{_binder}; // mutable, because `changing` the value is just reading it + }; + + inline row_iterator database_binder::begin() { + return row_iterator(*this); + } + + inline row_iterator database_binder::end() { + return row_iterator(); + } + + namespace detail { + template + void _extract_single_value(database_binder &binder, Callback call_back) { + auto iter = binder.begin(); + if(iter == binder.end()) + throw errors::no_rows("no rows to extract: exactly 1 row expected", binder.sql(), SQLITE_DONE); + + call_back(*iter); + + if(++iter != binder.end()) + throw errors::more_rows("not all rows extracted", binder.sql(), SQLITE_ROW); + } + } + inline void database_binder::execute() { + for(auto &&row : *this) + (void)row; + } + namespace detail { + template using void_t = void; + template + struct sqlite_direct_result : std::false_type {}; + template + struct sqlite_direct_result< + T, + void_t() >> std::declval())> + > : std::true_type {}; + } + template + inline typename std::enable_if::value>::type operator>>(database_binder &binder, Result&& value) { + detail::_extract_single_value(binder, [&value] (row_iterator::value_type &row) { + row >> std::forward(value); + }); + } + + template + inline typename std::enable_if::value>::type operator>>(database_binder &db_binder, Function&& func) { + using traits = utility::function_traits; + + for(auto &&row : db_binder) { + binder::run(row, func); + } + } + + template + inline decltype(auto) operator>>(database_binder &&binder, Result&& value) { + return binder >> std::forward(value); + } + + namespace sql_function_binder { + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ); + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ); + } + + enum class OpenFlags { + READONLY = SQLITE_OPEN_READONLY, + READWRITE = SQLITE_OPEN_READWRITE, + CREATE = SQLITE_OPEN_CREATE, + NOMUTEX = SQLITE_OPEN_NOMUTEX, + FULLMUTEX = SQLITE_OPEN_FULLMUTEX, + SHAREDCACHE = SQLITE_OPEN_SHAREDCACHE, + PRIVATECACH = SQLITE_OPEN_PRIVATECACHE, + URI = SQLITE_OPEN_URI + }; + inline OpenFlags operator|(const OpenFlags& a, const OpenFlags& b) { + return static_cast(static_cast(a) | static_cast(b)); + } + enum class Encoding { + ANY = SQLITE_ANY, + UTF8 = SQLITE_UTF8, + UTF16 = SQLITE_UTF16 + }; + struct sqlite_config { + OpenFlags flags = OpenFlags::READWRITE | OpenFlags::CREATE; + const char *zVfs = nullptr; + Encoding encoding = Encoding::ANY; + }; + + class database { + protected: + std::shared_ptr _db; + + public: + database(const std::string &db_name, const sqlite_config &config = {}): _db(nullptr) { + sqlite3* tmp = nullptr; + auto ret = sqlite3_open_v2(db_name.data(), &tmp, static_cast(config.flags), config.zVfs); + _db = std::shared_ptr(tmp, [=](sqlite3* ptr) { sqlite3_close_v2(ptr); }); // this will close the connection eventually when no longer needed. + if(ret != SQLITE_OK) errors::throw_sqlite_error(_db ? sqlite3_extended_errcode(_db.get()) : ret, {}, sqlite3_errmsg(_db.get())); + sqlite3_extended_result_codes(_db.get(), true); + if(config.encoding == Encoding::UTF16) + *this << R"(PRAGMA encoding = "UTF-16";)"; + } + + database(const std::u16string &db_name, const sqlite_config &config = {}): database(utility::utf16_to_utf8(db_name), config) { + if (config.encoding == Encoding::ANY) + *this << R"(PRAGMA encoding = "UTF-16";)"; + } + + database(std::shared_ptr db): + _db(db) {} + + database_binder operator<<(str_ref sql) { + return database_binder(_db, sql); + } + + database_binder operator<<(u16str_ref sql) { + return database_binder(_db, sql); + } + + connection_type connection() const { return _db; } + + sqlite3_int64 last_insert_rowid() const { + return sqlite3_last_insert_rowid(_db.get()); + } + + int rows_modified() const { + return sqlite3_changes(_db.get()); + } + + template + void define(const std::string &name, Function&& func) { + typedef utility::function_traits traits; + + auto funcPtr = new auto(std::forward(func)); + if(int result = sqlite3_create_function_v2( + _db.get(), name.data(), traits::arity, SQLITE_UTF8, funcPtr, + sql_function_binder::scalar::type>, + nullptr, nullptr, [](void* ptr){ + delete static_cast(ptr); + })) + errors::throw_sqlite_error(result, {}, sqlite3_errmsg(_db.get())); + } + + template + void define(const std::string &name, StepFunction&& step, FinalFunction&& final) { + typedef utility::function_traits traits; + using ContextType = typename std::remove_reference>::type; + + auto funcPtr = new auto(std::make_pair(std::forward(step), std::forward(final))); + if(int result = sqlite3_create_function_v2( + _db.get(), name.c_str(), traits::arity - 1, SQLITE_UTF8, funcPtr, nullptr, + sql_function_binder::step::type>, + sql_function_binder::final::type>, + [](void* ptr){ + delete static_cast(ptr); + })) + errors::throw_sqlite_error(result, {}, sqlite3_errmsg(_db.get())); + } + + }; + + template + class binder { + private: + template < + typename Function, + std::size_t Index + > + using nth_argument_type = typename utility::function_traits< + Function + >::template argument; + + public: + // `Boundary` needs to be defaulted to `Count` so that the `run` function + // template is not implicitly instantiated on class template instantiation. + // Look up section 14.7.1 _Implicit instantiation_ of the ISO C++14 Standard + // and the [dicussion](https://github.com/aminroosta/sqlite_modern_cpp/issues/8) + // on Github. + + template< + typename Function, + typename... Values, + std::size_t Boundary = Count + > + static typename std::enable_if<(sizeof...(Values) < Boundary), void>::type run( + row_iterator::value_type& row, + Function&& function, + Values&&... values + ) { + typename std::decay>::type value; + row >> value; + run(row, function, std::forward(values)..., std::move(value)); + } + + template< + typename Function, + typename... Values, + std::size_t Boundary = Count + > + static typename std::enable_if<(sizeof...(Values) == Boundary), void>::type run( + row_iterator::value_type&, + Function&& function, + Values&&... values + ) { + function(std::move(values)...); + } + }; + + // Some ppl are lazy so we have a operator for proper prep. statemant handling. + void inline operator++(database_binder& db, int) { db.execute(); } + + template database_binder &operator<<(database_binder& db, index_binding_helper val) { + db._next_index(); --db._inx; + int result = bind_col_in_db(db._stmt.get(), val.index, std::forward(val.value)); + if(result != SQLITE_OK) + exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); + return db; + } + + template database_binder &operator<<(database_binder& db, index_binding_helper val) { + db._next_index(); --db._inx; + int index = sqlite3_bind_parameter_index(db._stmt.get(), val.index); + if(!index) + throw errors::unknown_binding("The given binding name is not valid for this statement", db.sql()); + int result = bind_col_in_db(db._stmt.get(), index, std::forward(val.value)); + if(result != SQLITE_OK) + exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); + return db; + } + + template database_binder &operator<<(database_binder& db, T&& val) { + int result = bind_col_in_db(db._stmt.get(), db._next_index(), std::forward(val)); + if(result != SQLITE_OK) + exceptions::throw_sqlite_error(result, db.sql(), sqlite3_errmsg(db._db.get())); + return db; + } + // Convert the rValue binder to a reference and call first op<<, its needed for the call that creates the binder (be carefull of recursion here!) + template database_binder operator << (database_binder&& db, const T& val) { db << val; return std::move(db); } + template database_binder operator << (database_binder&& db, index_binding_helper val) { db << index_binding_helper{val.index, std::forward(val.value)}; return std::move(db); } + + namespace sql_function_binder { + template + struct AggregateCtxt { + T obj; + bool constructed = true; + }; + + template< + typename ContextType, + std::size_t Count, + typename Functions + > + inline void step( + sqlite3_context* db, + int count, + sqlite3_value** vals + ) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + if(!ctxt) return; + try { + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + step(db, count, vals, ctxt->obj); + return; + } catch(const sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(const std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + if(ctxt && ctxt->constructed) + ctxt->~AggregateCtxt(); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) && sizeof...(Values) < Count), void>::type step( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + using arg_type = typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits< + typename Functions::first_type + >::template argument + >::type + >::type; + + step( + db, + count, + vals, + std::forward(values)..., + get_val_from_db(vals[sizeof...(Values) - 1], result_type())); + } + + template< + std::size_t Count, + typename Functions, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type step( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + static_cast(sqlite3_user_data(db))->first(std::forward(values)...); + } + + template< + typename ContextType, + typename Functions + > + inline void final(sqlite3_context* db) { + auto ctxt = static_cast*>(sqlite3_aggregate_context(db, sizeof(AggregateCtxt))); + try { + if(!ctxt) return; + if(!ctxt->constructed) new(ctxt) AggregateCtxt(); + store_result_in_db(db, + static_cast(sqlite3_user_data(db))->second(ctxt->obj)); + } catch(const sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(const std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + if(ctxt && ctxt->constructed) + ctxt->~AggregateCtxt(); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) < Count), void>::type scalar( + sqlite3_context* db, + int count, + sqlite3_value** vals, + Values&&... values + ) { + using arg_type = typename std::remove_cv< + typename std::remove_reference< + typename utility::function_traits::template argument + >::type + >::type; + + scalar( + db, + count, + vals, + std::forward(values)..., + get_val_from_db(vals[sizeof...(Values)], result_type())); + } + + template< + std::size_t Count, + typename Function, + typename... Values + > + inline typename std::enable_if<(sizeof...(Values) == Count), void>::type scalar( + sqlite3_context* db, + int, + sqlite3_value**, + Values&&... values + ) { + try { + store_result_in_db(db, + (*static_cast(sqlite3_user_data(db)))(std::forward(values)...)); + } catch(const sqlite_exception &e) { + sqlite3_result_error_code(db, e.get_code()); + sqlite3_result_error(db, e.what(), -1); + } catch(const std::exception &e) { + sqlite3_result_error(db, e.what(), -1); + } catch(...) { + sqlite3_result_error(db, "Unknown error", -1); + } + } + } } - diff --git a/src/sqlite/hdr/sqlite_modern_cpp/errors.h b/src/sqlite/hdr/sqlite_modern_cpp/errors.h index afdfd4096..aba4feb28 100644 --- a/src/sqlite/hdr/sqlite_modern_cpp/errors.h +++ b/src/sqlite/hdr/sqlite_modern_cpp/errors.h @@ -7,24 +7,25 @@ namespace sqlite { - class sqlite_exception: public std::runtime_error { - public: - sqlite_exception(const char* msg, str_ref sql, int code = -1): runtime_error(msg), code(code), sql(sql) {} - sqlite_exception(int code, str_ref sql): runtime_error(sqlite3_errstr(code)), code(code), sql(sql) {} - int get_code() const {return code & 0xFF;} - int get_extended_code() const {return code;} - std::string get_sql() const {return sql;} - private: - int code; - std::string sql; - }; + class sqlite_exception: public std::runtime_error { + public: + sqlite_exception(const char* msg, str_ref sql, int code = -1): runtime_error(msg), code(code), sql(sql) {} + sqlite_exception(int code, str_ref sql, const char *msg = nullptr): runtime_error(msg ? msg : sqlite3_errstr(code)), code(code), sql(sql) {} + int get_code() const {return code & 0xFF;} + int get_extended_code() const {return code;} + std::string get_sql() const {return sql;} + const char *errstr() const {return code == -1 ? "Unknown error" : sqlite3_errstr(code);} + private: + int code; + std::string sql; + }; - namespace errors { - //One more or less trivial derived error class for each SQLITE error. - //Note the following are not errors so have no classes: - //SQLITE_OK, SQLITE_NOTICE, SQLITE_WARNING, SQLITE_ROW, SQLITE_DONE - // - //Note these names are exact matches to the names of the SQLITE error codes. + namespace errors { + //One more or less trivial derived error class for each SQLITE error. + //Note the following are not errors so have no classes: + //SQLITE_OK, SQLITE_NOTICE, SQLITE_WARNING, SQLITE_ROW, SQLITE_DONE + // + //Note these names are exact matches to the names of the SQLITE error codes. #define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ class name: public sqlite_exception { using sqlite_exception::sqlite_exception; };\ derived @@ -34,15 +35,15 @@ namespace sqlite { #undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED #undef SQLITE_MODERN_CPP_ERROR_CODE - //Some additional errors are here for the C++ interface - class more_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; - class no_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; - class more_statements: public sqlite_exception { using sqlite_exception::sqlite_exception; }; // Prepared statements can only contain one statement - class invalid_utf16: public sqlite_exception { using sqlite_exception::sqlite_exception; }; - class unknown_binding: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + //Some additional errors are here for the C++ interface + class more_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class no_rows: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class more_statements: public sqlite_exception { using sqlite_exception::sqlite_exception; }; // Prepared statements can only contain one statement + class invalid_utf16: public sqlite_exception { using sqlite_exception::sqlite_exception; }; + class unknown_binding: public sqlite_exception { using sqlite_exception::sqlite_exception; }; - static void throw_sqlite_error(const int& error_code, str_ref sql = "") { - switch(error_code & 0xFF) { + static void throw_sqlite_error(const int& error_code, str_ref sql = "", const char *errmsg = nullptr) { + switch(error_code & 0xFF) { #define SQLITE_MODERN_CPP_ERROR_CODE(NAME,name,derived) \ case SQLITE_ ## NAME: switch(error_code) { \ derived \ @@ -51,19 +52,19 @@ namespace sqlite { } #if SQLITE_VERSION_NUMBER < 3010000 -#define SQLITE_IOERR_VNODE (SQLITE_IOERR | (27<<8)) + #define SQLITE_IOERR_VNODE (SQLITE_IOERR | (27<<8)) #define SQLITE_IOERR_AUTH (SQLITE_IOERR | (28<<8)) #define SQLITE_AUTH_USER (SQLITE_AUTH | (1<<8)) #endif #define SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED(BASE,SUB,base,sub) \ - case SQLITE_ ## BASE ## _ ## SUB: throw base ## _ ## sub(error_code, sql); + case SQLITE_ ## BASE ## _ ## SUB: throw base ## _ ## sub(error_code, sql, errmsg); #include "lists/error_codes.h" #undef SQLITE_MODERN_CPP_ERROR_CODE_EXTENDED #undef SQLITE_MODERN_CPP_ERROR_CODE - default: throw sqlite_exception(error_code, sql); - } - } - } - namespace exceptions = errors; -} + default: throw sqlite_exception(error_code, sql, errmsg); + } + } + } + namespace exceptions = errors; +} \ No newline at end of file diff --git a/src/sqlite/hdr/sqlite_modern_cpp/type_wrapper.h b/src/sqlite/hdr/sqlite_modern_cpp/type_wrapper.h index 44e3ad69f..9a8710dc2 100644 --- a/src/sqlite/hdr/sqlite_modern_cpp/type_wrapper.h +++ b/src/sqlite/hdr/sqlite_modern_cpp/type_wrapper.h @@ -12,7 +12,7 @@ #ifdef __has_include #if __cplusplus > 201402 && __has_include() #define MODERN_SQLITE_STD_OPTIONAL_SUPPORT -#elif __has_include() +#elif __has_include() && __apple_build_version__ < 11000000 #define MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT #endif #endif @@ -45,130 +45,130 @@ namespace sqlite #else namespace sqlite { - typedef const std::string& str_ref; - typedef const std::u16string& u16str_ref; + typedef const std::string& str_ref; + typedef const std::u16string& u16str_ref; } #endif #include #include "errors.h" namespace sqlite { - template - struct has_sqlite_type : std::false_type {}; - - template - using is_sqlite_value = std::integral_constant::value - || has_sqlite_type::value - || has_sqlite_type::value - || has_sqlite_type::value - || has_sqlite_type::value - >; + template + struct has_sqlite_type : std::false_type {}; - template - struct has_sqlite_type : has_sqlite_type {}; - template - struct has_sqlite_type : has_sqlite_type {}; - template - struct has_sqlite_type : has_sqlite_type {}; + template + using is_sqlite_value = std::integral_constant::value + || has_sqlite_type::value + || has_sqlite_type::value + || has_sqlite_type::value + || has_sqlite_type::value + >; - template - struct result_type { - using type = T; - constexpr result_type() = default; - template::value>> - constexpr result_type(result_type) { } - }; + template + struct has_sqlite_type : has_sqlite_type {}; + template + struct has_sqlite_type : has_sqlite_type {}; + template + struct has_sqlite_type : has_sqlite_type {}; - // int - template<> - struct has_sqlite_type : std::true_type {}; + template + struct result_type { + using type = T; + constexpr result_type() = default; + template::value>> + constexpr result_type(result_type) { } + }; - inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const int& val) { - return sqlite3_bind_int(stmt, inx, val); - } - inline void store_result_in_db(sqlite3_context* db, const int& val) { - sqlite3_result_int(db, val); - } - inline int get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : - sqlite3_column_int(stmt, inx); - } - inline int get_val_from_db(sqlite3_value *value, result_type) { - return sqlite3_value_type(value) == SQLITE_NULL ? 0 : - sqlite3_value_int(value); - } + // int + template<> + struct has_sqlite_type : std::true_type {}; - // sqlite_int64 - template<> - struct has_sqlite_type : std::true_type {}; + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const int& val) { + return sqlite3_bind_int(stmt, inx, val); + } + inline void store_result_in_db(sqlite3_context* db, const int& val) { + sqlite3_result_int(db, val); + } + inline int get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : + sqlite3_column_int(stmt, inx); + } + inline int get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? 0 : + sqlite3_value_int(value); + } - inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const sqlite_int64& val) { - return sqlite3_bind_int64(stmt, inx, val); - } - inline void store_result_in_db(sqlite3_context* db, const sqlite_int64& val) { - sqlite3_result_int64(db, val); - } - inline sqlite_int64 get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : - sqlite3_column_int64(stmt, inx); - } - inline sqlite3_int64 get_val_from_db(sqlite3_value *value, result_type) { - return sqlite3_value_type(value) == SQLITE_NULL ? 0 : - sqlite3_value_int64(value); - } + // sqlite_int64 + template<> + struct has_sqlite_type : std::true_type {}; - // float - template<> - struct has_sqlite_type : std::true_type {}; + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const sqlite_int64& val) { + return sqlite3_bind_int64(stmt, inx, val); + } + inline void store_result_in_db(sqlite3_context* db, const sqlite_int64& val) { + sqlite3_result_int64(db, val); + } + inline sqlite_int64 get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : + sqlite3_column_int64(stmt, inx); + } + inline sqlite3_int64 get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? 0 : + sqlite3_value_int64(value); + } - inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const float& val) { - return sqlite3_bind_double(stmt, inx, double(val)); - } - inline void store_result_in_db(sqlite3_context* db, const float& val) { - sqlite3_result_double(db, val); - } - inline float get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : - sqlite3_column_double(stmt, inx); - } - inline float get_val_from_db(sqlite3_value *value, result_type) { - return sqlite3_value_type(value) == SQLITE_NULL ? 0 : - sqlite3_value_double(value); - } + // float + template<> + struct has_sqlite_type : std::true_type {}; - // double - template<> - struct has_sqlite_type : std::true_type {}; + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const float& val) { + return sqlite3_bind_double(stmt, inx, double(val)); + } + inline void store_result_in_db(sqlite3_context* db, const float& val) { + sqlite3_result_double(db, val); + } + inline float get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : + sqlite3_column_double(stmt, inx); + } + inline float get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? 0 : + sqlite3_value_double(value); + } - inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const double& val) { - return sqlite3_bind_double(stmt, inx, val); - } - inline void store_result_in_db(sqlite3_context* db, const double& val) { - sqlite3_result_double(db, val); - } - inline double get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : - sqlite3_column_double(stmt, inx); - } - inline double get_val_from_db(sqlite3_value *value, result_type) { - return sqlite3_value_type(value) == SQLITE_NULL ? 0 : - sqlite3_value_double(value); - } + // double + template<> + struct has_sqlite_type : std::true_type {}; - /* for nullptr support */ - template<> - struct has_sqlite_type : std::true_type {}; + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const double& val) { + return sqlite3_bind_double(stmt, inx, val); + } + inline void store_result_in_db(sqlite3_context* db, const double& val) { + sqlite3_result_double(db, val); + } + inline double get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? 0 : + sqlite3_column_double(stmt, inx); + } + inline double get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? 0 : + sqlite3_value_double(value); + } - inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, std::nullptr_t) { - return sqlite3_bind_null(stmt, inx); - } - inline void store_result_in_db(sqlite3_context* db, std::nullptr_t) { - sqlite3_result_null(db); - } + /* for nullptr support */ + template<> + struct has_sqlite_type : std::true_type {}; + + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, std::nullptr_t) { + return sqlite3_bind_null(stmt, inx); + } + inline void store_result_in_db(sqlite3_context* db, std::nullptr_t) { + sqlite3_result_null(db); + } #ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT - template<> + template<> struct has_sqlite_type : std::true_type {}; inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, std::monostate) { @@ -185,184 +185,184 @@ namespace sqlite { } #endif - // str_ref - template<> - struct has_sqlite_type : std::true_type {}; - inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, str_ref val) { - return sqlite3_bind_text(stmt, inx, val.data(), val.length(), SQLITE_STATIC); - } + // str_ref + template<> + struct has_sqlite_type : std::true_type {}; + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, str_ref val) { + return sqlite3_bind_text(stmt, inx, val.data(), val.length(), SQLITE_TRANSIENT); + } - // Convert char* to string_view to trigger op<<(..., const str_ref ) - template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const char(&STR)[N]) { - return sqlite3_bind_text(stmt, inx, &STR[0], N-1, SQLITE_STATIC); - } + // Convert char* to string_view to trigger op<<(..., const str_ref ) + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const char(&STR)[N]) { + return sqlite3_bind_text(stmt, inx, &STR[0], N-1, SQLITE_TRANSIENT); + } - inline std::string get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? std::string() : - std::string(reinterpret_cast(sqlite3_column_text(stmt, inx)), sqlite3_column_bytes(stmt, inx)); - } - inline std::string get_val_from_db(sqlite3_value *value, result_type) { - return sqlite3_value_type(value) == SQLITE_NULL ? std::string() : - std::string(reinterpret_cast(sqlite3_value_text(value)), sqlite3_value_bytes(value)); - } + inline std::string get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? std::string() : + std::string(reinterpret_cast(sqlite3_column_text(stmt, inx)), sqlite3_column_bytes(stmt, inx)); + } + inline std::string get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? std::string() : + std::string(reinterpret_cast(sqlite3_value_text(value)), sqlite3_value_bytes(value)); + } - inline void store_result_in_db(sqlite3_context* db, str_ref val) { - sqlite3_result_text(db, val.data(), val.length(), SQLITE_TRANSIENT); - } - // u16str_ref - template<> - struct has_sqlite_type : std::true_type {}; - inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, u16str_ref val) { - return sqlite3_bind_text16(stmt, inx, val.data(), sizeof(char16_t) * val.length(), SQLITE_STATIC); - } + inline void store_result_in_db(sqlite3_context* db, str_ref val) { + sqlite3_result_text(db, val.data(), val.length(), SQLITE_TRANSIENT); + } + // u16str_ref + template<> + struct has_sqlite_type : std::true_type {}; + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, u16str_ref val) { + return sqlite3_bind_text16(stmt, inx, val.data(), sizeof(char16_t) * val.length(), SQLITE_TRANSIENT); + } - // Convert char* to string_view to trigger op<<(..., const str_ref ) - template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const char16_t(&STR)[N]) { - return sqlite3_bind_text16(stmt, inx, &STR[0], sizeof(char16_t) * (N-1), SQLITE_STATIC); - } + // Convert char* to string_view to trigger op<<(..., const str_ref ) + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const char16_t(&STR)[N]) { + return sqlite3_bind_text16(stmt, inx, &STR[0], sizeof(char16_t) * (N-1), SQLITE_TRANSIENT); + } - inline std::u16string get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? std::u16string() : - std::u16string(reinterpret_cast(sqlite3_column_text16(stmt, inx)), sqlite3_column_bytes16(stmt, inx)); - } - inline std::u16string get_val_from_db(sqlite3_value *value, result_type) { - return sqlite3_value_type(value) == SQLITE_NULL ? std::u16string() : - std::u16string(reinterpret_cast(sqlite3_value_text16(value)), sqlite3_value_bytes16(value)); - } + inline std::u16string get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return sqlite3_column_type(stmt, inx) == SQLITE_NULL ? std::u16string() : + std::u16string(reinterpret_cast(sqlite3_column_text16(stmt, inx)), sqlite3_column_bytes16(stmt, inx)); + } + inline std::u16string get_val_from_db(sqlite3_value *value, result_type) { + return sqlite3_value_type(value) == SQLITE_NULL ? std::u16string() : + std::u16string(reinterpret_cast(sqlite3_value_text16(value)), sqlite3_value_bytes16(value)); + } - inline void store_result_in_db(sqlite3_context* db, u16str_ref val) { - sqlite3_result_text16(db, val.data(), sizeof(char16_t) * val.length(), SQLITE_TRANSIENT); - } + inline void store_result_in_db(sqlite3_context* db, u16str_ref val) { + sqlite3_result_text16(db, val.data(), sizeof(char16_t) * val.length(), SQLITE_TRANSIENT); + } - // Other integer types - template - struct has_sqlite_type::value>::type> : std::true_type {}; + // Other integer types + template + struct has_sqlite_type::value>::type> : std::true_type {}; - template::value>::type> - inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const Integral& val) { - return bind_col_in_db(stmt, inx, static_cast(val)); - } - template::type>> - inline void store_result_in_db(sqlite3_context* db, const Integral& val) { - store_result_in_db(db, static_cast(val)); - } - template::value>::type> - inline Integral get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { - return get_col_from_db(stmt, inx, result_type()); - } - template::value>::type> - inline Integral get_val_from_db(sqlite3_value *value, result_type) { - return get_val_from_db(value, result_type()); - } + template::value>::type> + inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const Integral& val) { + return bind_col_in_db(stmt, inx, static_cast(val)); + } + template::type>> + inline void store_result_in_db(sqlite3_context* db, const Integral& val) { + store_result_in_db(db, static_cast(val)); + } + template::value>::type> + inline Integral get_col_from_db(sqlite3_stmt* stmt, int inx, result_type) { + return get_col_from_db(stmt, inx, result_type()); + } + template::value>::type> + inline Integral get_val_from_db(sqlite3_value *value, result_type) { + return get_val_from_db(value, result_type()); + } - // vector - template - struct has_sqlite_type, SQLITE_BLOB, void> : std::true_type {}; + // vector + template + struct has_sqlite_type, SQLITE_BLOB, void> : std::true_type {}; - template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const std::vector& vec) { - void const* buf = reinterpret_cast(vec.data()); - int bytes = vec.size() * sizeof(T); - return sqlite3_bind_blob(stmt, inx, buf, bytes, SQLITE_STATIC); - } - template inline void store_result_in_db(sqlite3_context* db, const std::vector& vec) { - void const* buf = reinterpret_cast(vec.data()); - int bytes = vec.size() * sizeof(T); - sqlite3_result_blob(db, buf, bytes, SQLITE_TRANSIENT); - } - template inline std::vector get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { - if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { - return {}; - } - int bytes = sqlite3_column_bytes(stmt, inx); - T const* buf = reinterpret_cast(sqlite3_column_blob(stmt, inx)); - return std::vector(buf, buf + bytes/sizeof(T)); - } - template inline std::vector get_val_from_db(sqlite3_value *value, result_type>) { - if(sqlite3_value_type(value) == SQLITE_NULL) { - return {}; - } - int bytes = sqlite3_value_bytes(value); - T const* buf = reinterpret_cast(sqlite3_value_blob(value)); - return std::vector(buf, buf + bytes/sizeof(T)); - } + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + return sqlite3_bind_blob(stmt, inx, buf, bytes, SQLITE_TRANSIENT); + } + template inline void store_result_in_db(sqlite3_context* db, const std::vector& vec) { + void const* buf = reinterpret_cast(vec.data()); + int bytes = vec.size() * sizeof(T); + sqlite3_result_blob(db, buf, bytes, SQLITE_TRANSIENT); + } + template inline std::vector get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { + if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { + return {}; + } + int bytes = sqlite3_column_bytes(stmt, inx); + T const* buf = reinterpret_cast(sqlite3_column_blob(stmt, inx)); + return std::vector(buf, buf + bytes/sizeof(T)); + } + template inline std::vector get_val_from_db(sqlite3_value *value, result_type>) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + return {}; + } + int bytes = sqlite3_value_bytes(value); + T const* buf = reinterpret_cast(sqlite3_value_blob(value)); + return std::vector(buf, buf + bytes/sizeof(T)); + } - /* for unique_ptr support */ - template - struct has_sqlite_type, Type, void> : has_sqlite_type {}; - template - struct has_sqlite_type, SQLITE_NULL, void> : std::true_type {}; + /* for unique_ptr support */ + template + struct has_sqlite_type, Type, void> : has_sqlite_type {}; + template + struct has_sqlite_type, SQLITE_NULL, void> : std::true_type {}; - template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const std::unique_ptr& val) { - return val ? bind_col_in_db(stmt, inx, *val) : bind_col_in_db(stmt, inx, nullptr); - } - template inline std::unique_ptr get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { - if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { - return nullptr; - } - return std::make_unique(get_col_from_db(stmt, inx, result_type())); - } - template inline std::unique_ptr get_val_from_db(sqlite3_value *value, result_type>) { - if(sqlite3_value_type(value) == SQLITE_NULL) { - return nullptr; - } - return std::make_unique(get_val_from_db(value, result_type())); - } + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const std::unique_ptr& val) { + return val ? bind_col_in_db(stmt, inx, *val) : bind_col_in_db(stmt, inx, nullptr); + } + template inline std::unique_ptr get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { + if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { + return nullptr; + } + return std::make_unique(get_col_from_db(stmt, inx, result_type())); + } + template inline std::unique_ptr get_val_from_db(sqlite3_value *value, result_type>) { + if(sqlite3_value_type(value) == SQLITE_NULL) { + return nullptr; + } + return std::make_unique(get_val_from_db(value, result_type())); + } - // std::optional support for NULL values + // std::optional support for NULL values #ifdef MODERN_SQLITE_STD_OPTIONAL_SUPPORT #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT - template - using optional = std::experimental::optional; + template + using optional = std::experimental::optional; #else - template + template using optional = std::optional; #endif - template - struct has_sqlite_type, Type, void> : has_sqlite_type {}; - template - struct has_sqlite_type, SQLITE_NULL, void> : std::true_type {}; + template + struct has_sqlite_type, Type, void> : has_sqlite_type {}; + template + struct has_sqlite_type, SQLITE_NULL, void> : std::true_type {}; - template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const optional& val) { - return val ? bind_col_in_db(stmt, inx, *val) : bind_col_in_db(stmt, inx, nullptr); - } - template inline void store_result_in_db(sqlite3_context* db, const optional& val) { - if(val) - store_result_in_db(db, *val); - else - sqlite3_result_null(db); - } + template inline int bind_col_in_db(sqlite3_stmt* stmt, int inx, const optional& val) { + return val ? bind_col_in_db(stmt, inx, *val) : bind_col_in_db(stmt, inx, nullptr); + } + template inline void store_result_in_db(sqlite3_context* db, const optional& val) { + if(val) + store_result_in_db(db, *val); + else + sqlite3_result_null(db); + } - template inline optional get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { - #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT - if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { - return std::experimental::nullopt; - } - return std::experimental::make_optional(get_col_from_db(stmt, inx, result_type())); - #else - if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { + template inline optional get_col_from_db(sqlite3_stmt* stmt, int inx, result_type>) { +#ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { + return std::experimental::nullopt; + } + return std::experimental::make_optional(get_col_from_db(stmt, inx, result_type())); +#else + if(sqlite3_column_type(stmt, inx) == SQLITE_NULL) { return std::nullopt; } return std::make_optional(get_col_from_db(stmt, inx, result_type())); - #endif - } - template inline optional get_val_from_db(sqlite3_value *value, result_type>) { - #ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT - if(sqlite3_value_type(value) == SQLITE_NULL) { - return std::experimental::nullopt; - } - return std::experimental::make_optional(get_val_from_db(value, result_type())); - #else - if(sqlite3_value_type(value) == SQLITE_NULL) { +#endif + } + template inline optional get_val_from_db(sqlite3_value *value, result_type>) { +#ifdef MODERN_SQLITE_EXPERIMENTAL_OPTIONAL_SUPPORT + if(sqlite3_value_type(value) == SQLITE_NULL) { + return std::experimental::nullopt; + } + return std::experimental::make_optional(get_val_from_db(value, result_type())); +#else + if(sqlite3_value_type(value) == SQLITE_NULL) { return std::nullopt; } return std::make_optional(get_val_from_db(value, result_type())); - #endif - } +#endif + } #endif #ifdef MODERN_SQLITE_STD_VARIANT_SUPPORT - namespace detail { + namespace detail { template struct tag_trait : U { using tag = T; }; } @@ -415,4 +415,4 @@ namespace sqlite { }); } #endif -} +} \ No newline at end of file diff --git a/src/test/claimtriecache_tests.cpp b/src/test/claimtriecache_tests.cpp index b346db5bc..e534e10c6 100644 --- a/src/test/claimtriecache_tests.cpp +++ b/src/test/claimtriecache_tests.cpp @@ -303,8 +303,7 @@ BOOST_AUTO_TEST_CASE(takeover_workaround_triggers) CClaimTrieCacheTest cache(&trie); insertUndoType icu, isu; claimQueueRowType ecu; supportQueueRowType esu; - std::vector> thu; - BOOST_CHECK(cache.incrementBlock(icu, ecu, isu, esu, thu)); + BOOST_CHECK(cache.incrementBlock(icu, ecu, isu, esu)); CClaimValue value; value.nHeight = 1; @@ -317,9 +316,9 @@ BOOST_AUTO_TEST_CASE(takeover_workaround_triggers) BOOST_CHECK(cache.insertClaimIntoTrie("cc", value)); BOOST_CHECK(cache.insertSupportIntoMap("aa", CSupportValue())); - BOOST_CHECK(cache.incrementBlock(icu, ecu, isu, esu, thu)); + BOOST_CHECK(cache.incrementBlock(icu, ecu, isu, esu)); BOOST_CHECK(cache.flush()); - BOOST_CHECK(cache.incrementBlock(icu, ecu, isu, esu, thu)); + BOOST_CHECK(cache.incrementBlock(icu, ecu, isu, esu)); BOOST_CHECK_EQUAL(0, cache.getTotalNamesInTrie()); CSupportValue temp; @@ -336,7 +335,7 @@ BOOST_AUTO_TEST_CASE(takeover_workaround_triggers) BOOST_CHECK(cache.insertClaimIntoTrie("bb", value)); BOOST_CHECK(cache.insertClaimIntoTrie("cc", value)); - BOOST_CHECK(cache.incrementBlock(icu, ecu, isu, esu, thu)); + BOOST_CHECK(cache.incrementBlock(icu, ecu, isu, esu)); BOOST_CHECK(cache.getInfoForName("aa", cv)); BOOST_CHECK_EQUAL(3, cv.nValidAtHeight); diff --git a/src/test/claimtrienormalization_tests.cpp b/src/test/claimtrienormalization_tests.cpp index 0c15d4782..1a93570be 100644 --- a/src/test/claimtrienormalization_tests.cpp +++ b/src/test/claimtrienormalization_tests.cpp @@ -227,8 +227,7 @@ BOOST_AUTO_TEST_CASE(claimtriecache_normalization) claimQueueRowType expireUndo; insertUndoType insertSupportUndo; supportQueueRowType expireSupportUndo; - std::vector > takeoverHeightUndo; - BOOST_CHECK(trieCache.incrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo, takeoverHeightUndo)); + BOOST_CHECK(trieCache.incrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo)); BOOST_CHECK(trieCache.shouldNormalize()); } @@ -314,14 +313,13 @@ BOOST_AUTO_TEST_CASE(normalization_removal_test) claimQueueRowType expireUndo; insertUndoType insertSupportUndo; supportQueueRowType expireSupportUndo; - std::vector > takeoverHeightUndo; - BOOST_CHECK(cache.incrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo, takeoverHeightUndo)); + BOOST_CHECK(cache.incrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo)); BOOST_CHECK(cache.getClaimsForName("ab").claimsNsupports.size() == 3U); BOOST_CHECK(cache.getClaimsForName("ab").claimsNsupports[0].supports.size() == 1U); BOOST_CHECK(cache.getClaimsForName("ab").claimsNsupports[1].supports.size() == 0U); BOOST_CHECK(cache.getClaimsForName("ab").claimsNsupports[2].supports.size() == 1U); BOOST_CHECK(cache.decrementBlock(insertUndo, expireUndo, insertSupportUndo, expireSupportUndo)); - BOOST_CHECK(cache.finalizeDecrement(takeoverHeightUndo)); + BOOST_CHECK(cache.finalizeDecrement()); std::string unused; BOOST_CHECK(cache.removeSupport(COutPoint(sx1.GetHash(), 0), unused, height)); BOOST_CHECK(cache.removeSupport(COutPoint(sx2.GetHash(), 0), unused, height)); diff --git a/src/undo.h b/src/undo.h index 3765df918..69a8bdb18 100644 --- a/src/undo.h +++ b/src/undo.h @@ -80,7 +80,6 @@ public: claimQueueRowType expireUndo; // any claims that expired insertUndoType insertSupportUndo; // any supports that went from the support queue to the support map supportQueueRowType expireSupportUndo; // any supports that expired - std::vector > takeoverHeightUndo; // for any name that was taken over, the previous time that name was taken over ADD_SERIALIZE_METHODS; @@ -91,7 +90,6 @@ public: READWRITE(expireUndo); READWRITE(insertSupportUndo); READWRITE(expireSupportUndo); - READWRITE(takeoverHeightUndo); } }; diff --git a/src/validation.cpp b/src/validation.cpp index c86fd1289..08134af02 100644 --- a/src/validation.cpp +++ b/src/validation.cpp @@ -1824,7 +1824,7 @@ DisconnectResult CChainState::DisconnectBlock(const CBlock& block, const CBlockI // move best block pointer to prevout block view.SetBestBlock(pindex->pprev->GetBlockHash()); - assert(trieCache.finalizeDecrement(blockUndo.takeoverHeightUndo)); + assert(trieCache.finalizeDecrement()); auto merkleHash = trieCache.getMerkleHash(); if (merkleHash != pindex->pprev->hashClaimTrie) { LogPrintf("Hash comparison failure at block %d\n", pindex->nHeight); @@ -2309,7 +2309,7 @@ bool CChainState::ConnectBlock(const CBlock& block, CValidationState& state, CBl } // TODO: if the "just check" flag is set, we should reduce the work done here. Incrementing blocks twice per mine is not efficient. - const auto incremented = trieCache.incrementBlock(blockundo.insertUndo, blockundo.expireUndo, blockundo.insertSupportUndo, blockundo.expireSupportUndo, blockundo.takeoverHeightUndo); + const auto incremented = trieCache.incrementBlock(blockundo.insertUndo, blockundo.expireUndo, blockundo.insertSupportUndo, blockundo.expireSupportUndo); assert(incremented); if (trieCache.getMerkleHash() != block.hashClaimTrie)