refactor trending

This commit is contained in:
Jack Robison 2021-09-03 00:33:40 -04:00 committed by Victor Shyba
parent da75968078
commit 8f9e6a519d
6 changed files with 174 additions and 152 deletions

View file

@ -5,7 +5,7 @@ import struct
from bisect import bisect_right
from struct import pack, unpack
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Optional, List, Tuple, Set, DefaultDict, Dict
from typing import Optional, List, Tuple, Set, DefaultDict, Dict, NamedTuple
from prometheus_client import Gauge, Histogram
from collections import defaultdict
import array
@ -35,6 +35,13 @@ if typing.TYPE_CHECKING:
from lbry.wallet.server.leveldb import LevelDB
class TrendingNotification(NamedTuple):
height: int
added: bool
prev_amount: int
new_amount: int
class Prefetcher:
"""Prefetches blocks (in the forward direction only)."""
@ -245,6 +252,7 @@ class BlockProcessor:
self.removed_claims_to_send_es = set() # cumulative changes across blocks to send ES
self.touched_claims_to_send_es = set()
self.activation_info_to_send_es: DefaultDict[str, List[TrendingNotification]] = defaultdict(list)
self.removed_claim_hashes: Set[bytes] = set() # per block changes
self.touched_claim_hashes: Set[bytes] = set()
@ -316,16 +324,17 @@ class BlockProcessor:
"applying extended claim expiration fork on claims accepted by, %i", self.height
)
await self.run_in_thread(self.db.apply_expiration_extension_fork)
# TODO: we shouldnt wait on the search index updating before advancing to the next block
if not self.db.first_sync:
self.db.reload_blocking_filtering_streams()
await self.db.search_index.claim_consumer(self.claim_producer())
await self.db.search_index.apply_filters(self.db.blocked_streams, self.db.blocked_channels,
# TODO: we shouldnt wait on the search index updating before advancing to the next block
if not self.db.first_sync:
self.db.reload_blocking_filtering_streams()
await self.db.search_index.claim_consumer(self.claim_producer())
await self.db.search_index.apply_filters(self.db.blocked_streams, self.db.blocked_channels,
self.db.filtered_streams, self.db.filtered_channels)
await self.db.search_index.apply_update_and_decay_trending_score()
self.db.search_index.clear_caches()
self.touched_claims_to_send_es.clear()
self.removed_claims_to_send_es.clear()
await self.db.search_index.update_trending_score(self.activation_info_to_send_es)
self.db.search_index.clear_caches()
self.touched_claims_to_send_es.clear()
self.removed_claims_to_send_es.clear()
self.activation_info_to_send_es.clear()
# print("******************\n")
except:
self.logger.exception("advance blocks failed")
@ -369,6 +378,7 @@ class BlockProcessor:
self.db.search_index.clear_caches()
self.touched_claims_to_send_es.clear()
self.removed_claims_to_send_es.clear()
self.activation_info_to_send_es.clear()
await self.prefetcher.reset_height(self.height)
self.reorg_count_metric.inc()
except:
@ -518,11 +528,6 @@ class BlockProcessor:
self.claim_hash_to_txo[claim_hash] = (tx_num, nout)
self.db_op_stack.extend_ops(pending.get_add_claim_utxo_ops())
# add the spike for trending
self.db_op_stack.append_op(self.db.prefix_db.trending_spike.pack_spike(
height, claim_hash, tx_num, nout, txo.amount, half_life=self.env.trending_half_life
))
def _add_support(self, height: int, txo: 'Output', tx_num: int, nout: int):
supported_claim_hash = txo.claim_hash[::-1]
self.support_txos_by_claim[supported_claim_hash].append((tx_num, nout))
@ -532,11 +537,6 @@ class BlockProcessor:
supported_claim_hash, tx_num, nout, txo.amount
).get_add_support_utxo_ops())
# add the spike for trending
self.db_op_stack.append_op(self.db.prefix_db.trending_spike.pack_spike(
height, supported_claim_hash, tx_num, nout, txo.amount, half_life=self.env.trending_half_life
))
def _add_claim_or_support(self, height: int, tx_hash: bytes, tx_num: int, nout: int, txo: 'Output',
spent_claims: typing.Dict[bytes, Tuple[int, int, str]]):
if txo.script.is_claim_name or txo.script.is_update_claim:
@ -552,7 +552,6 @@ class BlockProcessor:
self.support_txos_by_claim[spent_support].remove((txin_num, txin.prev_idx))
supported_name = self._get_pending_claim_name(spent_support)
self.removed_support_txos_by_name_by_claim[supported_name][spent_support].append((txin_num, txin.prev_idx))
txin_height = height
else:
spent_support, support_amount = self.db.get_supported_claim_from_txo(txin_num, txin.prev_idx)
if not spent_support: # it is not a support
@ -562,7 +561,6 @@ class BlockProcessor:
self.removed_support_txos_by_name_by_claim[supported_name][spent_support].append(
(txin_num, txin.prev_idx))
activation = self.db.get_activation(txin_num, txin.prev_idx, is_support=True)
txin_height = bisect_right(self.db.tx_counts, self.db.transaction_num_mapping[txin.prev_hash])
if 0 < activation < self.height + 1:
self.removed_active_support_amount_by_claim[spent_support].append(support_amount)
if supported_name is not None and activation > 0:
@ -574,11 +572,6 @@ class BlockProcessor:
self.db_op_stack.extend_ops(StagedClaimtrieSupport(
spent_support, txin_num, txin.prev_idx, support_amount
).get_spend_support_txo_ops())
# add the spike for trending
self.db_op_stack.append_op(self.db.prefix_db.trending_spike.pack_spike(
height, spent_support, txin_num, txin.prev_idx, support_amount, subtract=True,
depth=height-txin_height-1, half_life=self.env.trending_half_life
))
def _spend_claim_txo(self, txin: TxInput, spent_claims: Dict[bytes, Tuple[int, int, str]]) -> bool:
txin_num = self.db.transaction_num_mapping[txin.prev_hash]
@ -1121,15 +1114,30 @@ class BlockProcessor:
self.touched_claim_hashes.add(controlling.claim_hash)
self.touched_claim_hashes.add(winning)
def _get_cumulative_update_ops(self):
def _add_claim_activation_change_notification(self, claim_id: str, height: int, added: bool, prev_amount: int,
new_amount: int):
self.activation_info_to_send_es[claim_id].append(TrendingNotification(height, added, prev_amount, new_amount))
def _get_cumulative_update_ops(self, height: int):
# gather cumulative removed/touched sets to update the search index
self.removed_claim_hashes.update(set(self.abandoned_claims.keys()))
self.touched_claim_hashes.difference_update(self.removed_claim_hashes)
self.touched_claim_hashes.update(
set(self.activated_support_amount_by_claim.keys()).union(
set(claim_hash for (_, claim_hash) in self.activated_claim_amount_by_name_and_hash.keys())
).union(self.signatures_changed).union(
set(
map(lambda item: item[1], self.activated_claim_amount_by_name_and_hash.keys())
).union(
set(self.claim_hash_to_txo.keys())
).union(
self.removed_active_support_amount_by_claim.keys()
).union(
self.signatures_changed
).union(
set(self.removed_active_support_amount_by_claim.keys())
).difference(self.removed_claim_hashes)
).union(
set(self.activated_support_amount_by_claim.keys())
).difference(
self.removed_claim_hashes
)
)
# use the cumulative changes to update bid ordered resolve
@ -1145,6 +1153,8 @@ class BlockProcessor:
amt.position, removed
))
for touched in self.touched_claim_hashes:
prev_effective_amount = 0
if touched in self.claim_hash_to_txo:
pending = self.txo_to_claim[self.claim_hash_to_txo[touched]]
name, tx_num, position = pending.normalized_name, pending.tx_num, pending.position
@ -1152,6 +1162,7 @@ class BlockProcessor:
if claim_from_db:
claim_amount_info = self.db.get_url_effective_amount(name, touched)
if claim_amount_info:
prev_effective_amount = claim_amount_info.effective_amount
self.db_op_stack.extend_ops(get_remove_effective_amount_ops(
name, claim_amount_info.effective_amount, claim_amount_info.tx_num,
claim_amount_info.position, touched
@ -1163,12 +1174,33 @@ class BlockProcessor:
name, tx_num, position = v.normalized_name, v.tx_num, v.position
amt = self.db.get_url_effective_amount(name, touched)
if amt:
self.db_op_stack.extend_ops(get_remove_effective_amount_ops(
name, amt.effective_amount, amt.tx_num, amt.position, touched
))
prev_effective_amount = amt.effective_amount
self.db_op_stack.extend_ops(
get_remove_effective_amount_ops(
name, amt.effective_amount, amt.tx_num, amt.position, touched
)
)
if (name, touched) in self.activated_claim_amount_by_name_and_hash:
self._add_claim_activation_change_notification(
touched.hex(), height, True, prev_effective_amount,
self.activated_claim_amount_by_name_and_hash[(name, touched)]
)
if touched in self.activated_support_amount_by_claim:
for support_amount in self.activated_support_amount_by_claim[touched]:
self._add_claim_activation_change_notification(
touched.hex(), height, True, prev_effective_amount, support_amount
)
if touched in self.removed_active_support_amount_by_claim:
for support_amount in self.removed_active_support_amount_by_claim[touched]:
self._add_claim_activation_change_notification(
touched.hex(), height, False, prev_effective_amount, support_amount
)
new_effective_amount = self._get_pending_effective_amount(name, touched)
self.db_op_stack.extend_ops(
get_add_effective_amount_ops(name, self._get_pending_effective_amount(name, touched),
tx_num, position, touched)
get_add_effective_amount_ops(
name, new_effective_amount, tx_num, position, touched
)
)
self.touched_claim_hashes.update(
@ -1254,7 +1286,7 @@ class BlockProcessor:
self._get_takeover_ops(height)
# update effective amount and update sets of touched and deleted claims
self._get_cumulative_update_ops()
self._get_cumulative_update_ops(height)
self.db_op_stack.append_op(RevertablePut(*Prefixes.tx_count.pack_item(height, tx_count)))
@ -1441,7 +1473,6 @@ class BlockProcessor:
self.height = self.db.db_height
self.tip = self.db.db_tip
self.tx_count = self.db.db_tx_count
self.status_server.set_height(self.db.fs_height, self.db.db_tip)
await asyncio.wait([
self.prefetcher.main_loop(self.height),

View file

@ -31,7 +31,6 @@ INDEX_DEFAULT_SETTINGS = {
"claim_type": {"type": "byte"},
"censor_type": {"type": "byte"},
"trending_score": {"type": "float"},
"trending_score_change": {"type": "float"},
"release_time": {"type": "long"}
}
}

View file

@ -158,46 +158,74 @@ class SearchIndex:
}
return update
async def apply_update_and_decay_trending_score(self):
async def update_trending_score(self, params):
update_trending_score_script = """
if (ctx._source.trending_score == null) {
ctx._source.trending_score = ctx._source.trending_score_change;
} else {
ctx._source.trending_score += ctx._source.trending_score_change;
double softenLBC(double lbc) { Math.pow(lbc, 1.0f / 3.0f) }
double inflateUnits(int height) { Math.pow(2.0, height / 400.0f) }
double spikePower(double newAmount) {
if (newAmount < 50.0) {
0.5
} else if (newAmount < 85.0) {
newAmount / 100.0
} else {
0.85
}
}
double spikeMass(double oldAmount, double newAmount) {
double softenedChange = softenLBC(Math.abs(newAmount - oldAmount));
double changeInSoftened = Math.abs(softenLBC(newAmount) - softenLBC(oldAmount));
if (oldAmount > newAmount) {
-1.0 * Math.pow(changeInSoftened, spikePower(newAmount)) * Math.pow(softenedChange, 1.0 - spikePower(newAmount))
} else {
Math.pow(changeInSoftened, spikePower(newAmount)) * Math.pow(softenedChange, 1.0 - spikePower(newAmount))
}
}
for (i in params.src.changes) {
if (i.added) {
if (ctx._source.trending_score == null) {
ctx._source.trending_score = spikeMass(i.prev_amount, i.prev_amount + i.new_amount);
} else {
ctx._source.trending_score += spikeMass(i.prev_amount, i.prev_amount + i.new_amount);
}
} else {
if (ctx._source.trending_score == null) {
ctx._source.trending_score = spikeMass(i.prev_amount, i.prev_amount - i.new_amount);
} else {
ctx._source.trending_score += spikeMass(i.prev_amount, i.prev_amount - i.new_amount);
}
}
}
ctx._source.trending_score_change = 0.0;
"""
start = time.perf_counter()
start = time.perf_counter()
await self.sync_client.update_by_query(
self.index, body={
'query': {
'bool': {'must_not': [{'match': {'trending_score_change': 0.0}}]}
},
'script': {'source': update_trending_score_script, 'lang': 'painless'}
}, slices=4, conflicts='proceed'
)
self.logger.info("updated trending scores in %ims", int((time.perf_counter() - start) * 1000))
whale_decay_factor = 2.0 ** ((-1 / self._trending_whale_half_life) + 1)
decay_factor = 2.0 ** ((-1 / self._trending_half_life) + 1)
decay_script = """
if (ctx._source.trending_score == null) { ctx._source.trending_score = 0.0; }
if ((-0.1 <= ctx._source.trending_score) && (ctx._source.trending_score <= 0.1)) {
ctx._source.trending_score = 0.0;
} else if (ctx._source.effective_amount >= %s) {
ctx._source.trending_score *= %s;
} else {
ctx._source.trending_score *= %s;
}
""" % (self._trending_whale_threshold, whale_decay_factor, decay_factor)
start = time.perf_counter()
await self.sync_client.update_by_query(
self.index, body={
'query': {'bool': {'must_not': [{'match': {'trending_score': 0.0}}]}},
'script': {'source': decay_script, 'lang': 'painless'}
}, slices=4, conflicts='proceed'
)
self.logger.info("decayed trending scores in %ims", int((time.perf_counter() - start) * 1000))
def producer():
for claim_id, claim_updates in params.items():
yield {
'_id': claim_id,
'_index': self.index,
'_op_type': 'update',
'script': {
'lang': 'painless',
'source': update_trending_score_script,
'params': {'src': {
'changes': [
{
'height': p.height,
'added': p.added,
'prev_amount': p.prev_amount,
'new_amount': p.new_amount,
} for p in claim_updates
]
}}
},
}
if not params:
return
async for ok, item in async_streaming_bulk(self.sync_client, producer(), raise_on_error=False):
if not ok:
self.logger.warning("updating trending failed for an item: %s", item)
await self.sync_client.indices.refresh(self.index)
self.logger.warning("updated trending scores in %ims", int((time.perf_counter() - start) * 1000))
async def apply_filters(self, blocked_streams, blocked_channels, filtered_streams, filtered_channels):
if filtered_streams:

View file

@ -463,21 +463,6 @@ class TouchedOrDeletedClaimValue(typing.NamedTuple):
f"deleted_claims={','.join(map(lambda x: x.hex(), self.deleted_claims))})"
class TrendingSpikeKey(typing.NamedTuple):
height: int
claim_hash: bytes
tx_num: int
position: int
def __str__(self):
return f"{self.__class__.__name__}(height={self.height}, claim_hash={self.claim_hash.hex()}, " \
f"tx_num={self.tx_num}, position={self.position})"
class TrendingSpikeValue(typing.NamedTuple):
mass: float
class ActiveAmountPrefixRow(PrefixRow):
prefix = DB_PREFIXES.active_amount.value
key_struct = struct.Struct(b'>20sBLLH')
@ -1350,49 +1335,6 @@ class TouchedOrDeletedPrefixRow(PrefixRow):
return cls.pack_key(height), cls.pack_value(touched, deleted)
class TrendingSpikePrefixRow(PrefixRow):
prefix = DB_PREFIXES.trending_spike.value
key_struct = struct.Struct(b'>L20sLH')
value_struct = struct.Struct(b'>f')
key_part_lambdas = [
lambda: b'',
struct.Struct(b'>L').pack,
struct.Struct(b'>L20s').pack,
struct.Struct(b'>L20sL').pack,
struct.Struct(b'>L20sLH').pack
]
@classmethod
def pack_spike(cls, height: int, claim_hash: bytes, tx_num: int, position: int, amount: int, half_life: int,
depth: int = 0, subtract: bool = False) -> RevertablePut:
softened_change = (((amount * 1E-8) + 1E-8) ** (1 / 4))
spike_mass = softened_change * ((2.0 ** (-1 / half_life)) ** depth)
if subtract:
spike_mass = -spike_mass
return RevertablePut(*cls.pack_item(height, claim_hash, tx_num, position, spike_mass))
@classmethod
def pack_key(cls, height: int, claim_hash: bytes, tx_num: int, position: int):
return super().pack_key(height, claim_hash, tx_num, position)
@classmethod
def unpack_key(cls, key: bytes) -> TrendingSpikeKey:
return TrendingSpikeKey(*super().unpack_key(key))
@classmethod
def pack_value(cls, mass: float) -> bytes:
return super().pack_value(mass)
@classmethod
def unpack_value(cls, data: bytes) -> TrendingSpikeValue:
return TrendingSpikeValue(*cls.value_struct.unpack(data))
@classmethod
def pack_item(cls, height: int, claim_hash: bytes, tx_num: int, position: int, mass: float):
return cls.pack_key(height, claim_hash, tx_num, position), cls.pack_value(mass)
class Prefixes:
claim_to_support = ClaimToSupportPrefixRow
support_to_claim = SupportToClaimPrefixRow
@ -1427,7 +1369,6 @@ class Prefixes:
tx = TXPrefixRow
header = BlockHeaderPrefixRow
touched_or_deleted = TouchedOrDeletedPrefixRow
trending_spike = TrendingSpikePrefixRow
class PrefixDB:
@ -1461,7 +1402,6 @@ class PrefixDB:
self.tx = TXPrefixRow(db, op_stack)
self.header = BlockHeaderPrefixRow(db, op_stack)
self.touched_or_deleted = TouchedOrDeletedPrefixRow(db, op_stack)
self.trending_spike = TrendingSpikePrefixRow(db, op_stack)
def commit(self):
try:

View file

@ -18,7 +18,7 @@ import attr
import zlib
import base64
import plyvel
from typing import Optional, Iterable, Tuple, DefaultDict, Set, Dict, List
from typing import Optional, Iterable, Tuple, DefaultDict, Set, Dict, List, TYPE_CHECKING
from functools import partial
from asyncio import sleep
from bisect import bisect_right
@ -44,6 +44,9 @@ from lbry.wallet.ledger import Ledger, RegTestLedger, TestNetLedger
from lbry.wallet.server.db.elasticsearch import SearchIndex
if TYPE_CHECKING:
from lbry.wallet.server.db.prefixes import EffectiveAmountKey
class UTXO(typing.NamedTuple):
tx_num: int
@ -187,12 +190,6 @@ class LevelDB:
cnt += 1
return cnt
def get_trending_spike_sum(self, height: int, claim_hash: bytes) -> float:
spikes = 0.0
for k, v in self.prefix_db.trending_spike.iterate(prefix=(height, claim_hash)):
spikes += v.mass
return spikes
def get_activation(self, tx_num, position, is_support=False) -> int:
activation = self.db.get(
Prefixes.activated.pack_key(
@ -409,9 +406,10 @@ class LevelDB:
def _fs_get_claim_by_hash(self, claim_hash):
claim = self.claim_to_txo.get(claim_hash)
if claim:
activation = self.get_activation(claim.tx_num, claim.position)
return self._prepare_resolve_result(
claim.tx_num, claim.position, claim_hash, claim.name, claim.root_tx_num, claim.root_position,
self.get_activation(claim.tx_num, claim.position), claim.channel_signature_is_valid
activation, claim.channel_signature_is_valid
)
async def fs_getclaimbyid(self, claim_id):
@ -457,7 +455,7 @@ class LevelDB:
return support_only
return support_amount + self._get_active_amount(claim_hash, ACTIVATED_CLAIM_TXO_TYPE, self.db_height + 1)
def get_url_effective_amount(self, name: str, claim_hash: bytes):
def get_url_effective_amount(self, name: str, claim_hash: bytes) -> Optional['EffectiveAmountKey']:
for k, v in self.prefix_db.effective_amount.iterate(prefix=(name,)):
if v.claim_hash == claim_hash:
return k
@ -708,8 +706,7 @@ class LevelDB:
'languages': languages,
'censor_type': Censor.RESOLVE if blocked_hash else Censor.SEARCH if filtered_hash else Censor.NOT_CENSORED,
'censoring_channel_id': (blocked_hash or filtered_hash or b'').hex() or None,
'claims_in_channel': None if not metadata.is_channel else self.get_claims_in_channel_count(claim_hash),
'trending_score_change': self.get_trending_spike_sum(self.db_height, claim_hash)
'claims_in_channel': None if not metadata.is_channel else self.get_claims_in_channel_count(claim_hash)
}
if metadata.is_repost and reposted_duration is not None:
@ -946,11 +943,6 @@ class LevelDB:
stop=Prefixes.touched_or_deleted.pack_key(min_height), include_value=False
)
)
delete_undo_keys.extend(
self.db.iterator(
prefix=Prefixes.trending_spike.pack_partial_key(min_height), include_value=False
)
)
with self.db.write_batch(transaction=True) as batch:
batch_put = batch.put

View file

@ -37,6 +37,15 @@ class BaseResolveTestCase(CommandTestCase):
claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search(name=name)
self.assertListEqual([], claim_from_es[0])
async def assertNoClaim(self, claim_id: str):
self.assertDictEqual(
{}, json.loads(await self.blockchain._cli_cmnd('getclaimbyid', claim_id))
)
claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search(claim_id=claim_id)
self.assertListEqual([], claim_from_es[0])
claim = await self.conductor.spv_node.server.bp.db.fs_getclaimbyid(claim_id)
self.assertIsNone(claim)
async def assertMatchWinningClaim(self, name):
expected = json.loads(await self.blockchain._cli_cmnd('getvalueforname', name))
stream, channel = await self.conductor.spv_node.server.bp.db.fs_resolve(name)
@ -61,6 +70,11 @@ class BaseResolveTestCase(CommandTestCase):
if not expected:
self.assertIsNone(claim)
return
claim_from_es = await self.conductor.spv_node.server.bp.db.search_index.search(
claim_id=claim.claim_hash.hex()
)
self.assertEqual(len(claim_from_es[0]), 1)
self.assertEqual(claim_from_es[0][0]['claim_hash'][::-1].hex(), claim.claim_hash.hex())
self.assertEqual(expected['claimId'], claim.claim_hash.hex())
self.assertEqual(expected['validAtHeight'], claim.activation_height)
self.assertEqual(expected['lastTakeoverHeight'], claim.last_takeover_height)
@ -945,6 +959,24 @@ class ResolveClaimTakeovers(BaseResolveTestCase):
await self.generate(1)
await self.assertNoClaimForName(name)
async def _test_add_non_winning_already_claimed(self):
name = 'derp'
# initially claim the name
first_claim_id = (await self.stream_create(name, '0.1'))['outputs'][0]['claim_id']
self.assertEqual(first_claim_id, (await self.assertMatchWinningClaim(name)).claim_hash.hex())
await self.generate(32)
second_claim_id = (await self.stream_create(name, '0.01', allow_duplicate_name=True))['outputs'][0]['claim_id']
await self.assertNoClaim(second_claim_id)
self.assertEqual(
len((await self.conductor.spv_node.server.bp.db.search_index.search(claim_name=name))[0]), 1
)
await self.generate(1)
await self.assertMatchClaim(second_claim_id)
self.assertEqual(
len((await self.conductor.spv_node.server.bp.db.search_index.search(claim_name=name))[0]), 2
)
class ResolveAfterReorg(BaseResolveTestCase):
async def reorg(self, start):