hat.event.backends.lmdb.refdb

  1import collections
  2import itertools
  3import typing
  4
  5import lmdb
  6
  7from hat import util
  8
  9from hat.event.backends.lmdb import common
 10from hat.event.backends.lmdb import environment
 11
 12
 13# depending on dict order for added refs
 14class ServerChanges(typing.NamedTuple):
 15    added: dict[common.EventId,
 16                tuple[common.Event,
 17                      set[common.EventRef]]]
 18    removed: dict[common.EventId,
 19                  set[common.EventRef]]
 20
 21
 22Changes: typing.TypeAlias = dict[common.ServerId, ServerChanges]
 23
 24
 25class RefDb:
 26
 27    def __init__(self,
 28                 env: environment.Environment,
 29                 max_results: int = 4096):
 30        self._env = env
 31        self._max_results = max_results
 32        self._last_event_ids = {}
 33        self._changes = collections.defaultdict(_create_server_changes)
 34
 35    def add(self,
 36            event: common.Event,
 37            refs: typing.Iterable[common.EventRef]):
 38        last_event_id = self._last_event_ids.get(event.id.server)
 39        if last_event_id and last_event_id > event.id:
 40            raise Exception('event older than last')
 41
 42        self._last_event_ids[event.id.server] = event.id
 43
 44        server_changes = self._changes[event.id.server]
 45        server_changes.added[event.id] = event, set(refs)
 46
 47    def remove(self,
 48               event_id: common.EventId,
 49               ref: common.EventRef):
 50        server_changes = self._changes[event_id.server]
 51        added = server_changes.added.get(event_id)
 52
 53        if added and ref in added[1]:
 54            added[1].remove(ref)
 55
 56            if not added[1]:
 57                server_changes.added.pop(event_id)
 58
 59            return
 60
 61        removed = server_changes.removed[event_id]
 62        removed.add(ref)
 63
 64    async def query(self,
 65                    params: common.QueryServerParams
 66                    ) -> common.QueryResult:
 67        if (params.last_event_id and
 68                params.last_event_id.server != params.server_id):
 69            raise ValueError('invalid server id')
 70
 71        max_results = (params.max_results
 72                       if params.max_results is not None and
 73                       params.max_results < self._max_results
 74                       else self._max_results)
 75
 76        changes = self._changes
 77        events = await self._env.execute(_ext_query_events, self._env,
 78                                         params.server_id, max_results + 1,
 79                                         params.last_event_id)
 80
 81        if not params.persisted and len(events) <= max_results:
 82            last_event_id = events[-1].id if events else params.last_event_id
 83            changes_max_result = max_results + 1 - len(events)
 84            changes_events = _query_changes(changes, params.server_id,
 85                                            changes_max_result, last_event_id)
 86
 87            events.extend(changes_events)
 88
 89        if len(events) > max_results:
 90            events = list(itertools.islice(events, max_results))
 91            more_follows = True
 92
 93        else:
 94            events = list(events)
 95            more_follows = False
 96
 97        return common.QueryResult(events=events,
 98                                  more_follows=more_follows)
 99
100    def create_changes(self) -> Changes:
101        self._changes, changes = (
102            collections.defaultdict(_create_server_changes), self._changes)
103        return changes
104
105    def ext_write(self,
106                  txn: lmdb.Transaction,
107                  changes: Changes):
108        db_def = common.db_defs[common.DbType.REF]
109
110        with self._env.ext_cursor(txn, common.DbType.REF) as cursor:
111            for server_changes in changes.values():
112                event_ids = {*server_changes.added.keys(),
113                             *server_changes.removed.keys()}
114
115                for event_id in event_ids:
116                    encoded_key = db_def.encode_key(event_id)
117                    encoded_value = cursor.pop(encoded_key)
118                    value = (db_def.decode_value(encoded_value)
119                             if encoded_value else set())
120
121                    added = server_changes.added.get(event_id)
122                    if added:
123                        value.update(added[1])
124
125                    removed = server_changes.removed.get(event_id)
126                    if removed:
127                        value.difference_update(removed)
128
129                    if not value:
130                        continue
131
132                    encoded_value = db_def.encode_value(value)
133                    cursor.put(encoded_key, encoded_value)
134
135    def ext_cleanup(self,
136                    txn: lmdb.Transaction,
137                    refs: typing.Iterable[tuple[common.EventId,
138                                                common.EventRef]]):
139        db_def = common.db_defs[common.DbType.REF]
140
141        with self._env.ext_cursor(txn, common.DbType.REF) as cursor:
142            for event_id, ref in refs:
143                encoded_key = db_def.encode_key(event_id)
144                encoded_value = cursor.pop(encoded_key)
145                if not encoded_value:
146                    continue
147
148                value = db_def.decode_value(encoded_value)
149                value.discard(ref)
150
151                if not value:
152                    continue
153
154                encoded_value = db_def.encode_value(value)
155                cursor.put(encoded_key, encoded_value)
156
157
def _query_changes(changes, server_id, max_results, last_event_id):
    """Return up to `max_results` staged events of the server `server_id`.

    If `last_event_id` is set, leading events whose id is not greater than
    it are skipped. Returns an empty list when nothing is staged for the
    server, otherwise a lazy iterator over the staged events in insertion
    order.
    """
    server_changes = changes.get(server_id)
    if not server_changes:
        return []

    staged = (event for event, _ in server_changes.added.values())

    if last_event_id:
        staged = itertools.dropwhile(
            lambda event: event.id <= last_event_id, staged)

    return itertools.islice(staged, max_results)
170
171
def _ext_query_events(env, server_id, max_results, last_event_id):
    """Read events of the server `server_id` through the REF database.

    Keys are scanned in ascending order, starting right after
    `last_event_id` (or at the first possible key of the server when it is
    not set) and stopping before the first key of the next server. For each
    entry, the first stored reference is resolved to its event.

    Args:
        env: environment providing `ext_begin` and `ext_cursor`
        server_id: server whose events are read
        max_results: maximum number of events to collect
        last_event_id: id of the last already known event or ``None``

    Returns:
        `collections.deque` with the collected events in key order.
    """
    db_def = common.db_defs[common.DbType.REF]

    start_key = (last_event_id if last_event_id
                 else common.EventId(server=server_id,
                                     session=0,
                                     instance=0))
    stop_key = common.EventId(server=server_id + 1,
                              session=0,
                              instance=0)

    encoded_start_key = db_def.encode_key(start_key)
    encoded_stop_key = db_def.encode_key(stop_key)

    events = collections.deque()

    with env.ext_begin() as txn:
        with env.ext_cursor(txn, common.DbType.REF) as cursor:
            available = cursor.set_range(encoded_start_key)

            # the start key itself is excluded from the result
            if available and bytes(cursor.key()) == encoded_start_key:
                available = cursor.next()

            while (available and
                    len(events) < max_results and
                    bytes(cursor.key()) < encoded_stop_key):
                value = db_def.decode_value(cursor.value())
                ref = util.first(value)

                # Entries without a usable reference are skipped. The
                # cursor is advanced below on every iteration - jumping
                # back to the loop condition without advancing would spin
                # forever on the same key.
                if ref:
                    event = _ext_get_event(env, txn, ref)
                    if event:
                        # TODO decode key and check event_id == event.id
                        events.append(event)

                available = cursor.next()

    return events
210
211
def _ext_get_event(env, txn, ref):
    """Resolve the reference `ref` to its event inside transaction `txn`.

    Returns the decoded event, or ``None`` when the referenced database
    holds no value under the reference key.

    Raises:
        ValueError: if `ref` is neither a latest nor a timeseries reference
    """
    ref_databases = (
        (common.LatestEventRef, common.DbType.LATEST_DATA),
        (common.TimeseriesEventRef, common.DbType.TIMESERIES_DATA))

    for ref_cls, db_type in ref_databases:
        if isinstance(ref, ref_cls):
            break

    else:
        raise ValueError('unsupported event reference type')

    db_def = common.db_defs[db_type]
    key = db_def.encode_key(ref.key)

    with env.ext_cursor(txn, db_type) as cursor:
        raw = cursor.get(key)
        return db_def.decode_value(raw) if raw else None
231
232
def _create_server_changes():
    """Create an empty `ServerChanges` (default factory for pending changes).

    `removed` is a `defaultdict` so that a reference set is created on the
    first removal staged for an event id.
    """
    return ServerChanges(added={},
                         removed=collections.defaultdict(set))
class ServerChanges(typing.NamedTuple):
15class ServerChanges(typing.NamedTuple):
16    added: dict[common.EventId,
17                tuple[common.Event,
18                      set[common.EventRef]]]
19    removed: dict[common.EventId,
20                  set[common.EventRef]]

ServerChanges(added, removed)

Changes: TypeAlias = dict[int, ServerChanges]
class RefDb:
 26class RefDb:
 27
 28    def __init__(self,
 29                 env: environment.Environment,
 30                 max_results: int = 4096):
 31        self._env = env
 32        self._max_results = max_results
 33        self._last_event_ids = {}
 34        self._changes = collections.defaultdict(_create_server_changes)
 35
 36    def add(self,
 37            event: common.Event,
 38            refs: typing.Iterable[common.EventRef]):
 39        last_event_id = self._last_event_ids.get(event.id.server)
 40        if last_event_id and last_event_id > event.id:
 41            raise Exception('event older than last')
 42
 43        self._last_event_ids[event.id.server] = event.id
 44
 45        server_changes = self._changes[event.id.server]
 46        server_changes.added[event.id] = event, set(refs)
 47
 48    def remove(self,
 49               event_id: common.EventId,
 50               ref: common.EventRef):
 51        server_changes = self._changes[event_id.server]
 52        added = server_changes.added.get(event_id)
 53
 54        if added and ref in added[1]:
 55            added[1].remove(ref)
 56
 57            if not added[1]:
 58                server_changes.added.pop(event_id)
 59
 60            return
 61
 62        removed = server_changes.removed[event_id]
 63        removed.add(ref)
 64
 65    async def query(self,
 66                    params: common.QueryServerParams
 67                    ) -> common.QueryResult:
 68        if (params.last_event_id and
 69                params.last_event_id.server != params.server_id):
 70            raise ValueError('invalid server id')
 71
 72        max_results = (params.max_results
 73                       if params.max_results is not None and
 74                       params.max_results < self._max_results
 75                       else self._max_results)
 76
 77        changes = self._changes
 78        events = await self._env.execute(_ext_query_events, self._env,
 79                                         params.server_id, max_results + 1,
 80                                         params.last_event_id)
 81
 82        if not params.persisted and len(events) <= max_results:
 83            last_event_id = events[-1].id if events else params.last_event_id
 84            changes_max_result = max_results + 1 - len(events)
 85            changes_events = _query_changes(changes, params.server_id,
 86                                            changes_max_result, last_event_id)
 87
 88            events.extend(changes_events)
 89
 90        if len(events) > max_results:
 91            events = list(itertools.islice(events, max_results))
 92            more_follows = True
 93
 94        else:
 95            events = list(events)
 96            more_follows = False
 97
 98        return common.QueryResult(events=events,
 99                                  more_follows=more_follows)
100
101    def create_changes(self) -> Changes:
102        self._changes, changes = (
103            collections.defaultdict(_create_server_changes), self._changes)
104        return changes
105
106    def ext_write(self,
107                  txn: lmdb.Transaction,
108                  changes: Changes):
109        db_def = common.db_defs[common.DbType.REF]
110
111        with self._env.ext_cursor(txn, common.DbType.REF) as cursor:
112            for server_changes in changes.values():
113                event_ids = {*server_changes.added.keys(),
114                             *server_changes.removed.keys()}
115
116                for event_id in event_ids:
117                    encoded_key = db_def.encode_key(event_id)
118                    encoded_value = cursor.pop(encoded_key)
119                    value = (db_def.decode_value(encoded_value)
120                             if encoded_value else set())
121
122                    added = server_changes.added.get(event_id)
123                    if added:
124                        value.update(added[1])
125
126                    removed = server_changes.removed.get(event_id)
127                    if removed:
128                        value.difference_update(removed)
129
130                    if not value:
131                        continue
132
133                    encoded_value = db_def.encode_value(value)
134                    cursor.put(encoded_key, encoded_value)
135
136    def ext_cleanup(self,
137                    txn: lmdb.Transaction,
138                    refs: typing.Iterable[tuple[common.EventId,
139                                                common.EventRef]]):
140        db_def = common.db_defs[common.DbType.REF]
141
142        with self._env.ext_cursor(txn, common.DbType.REF) as cursor:
143            for event_id, ref in refs:
144                encoded_key = db_def.encode_key(event_id)
145                encoded_value = cursor.pop(encoded_key)
146                if not encoded_value:
147                    continue
148
149                value = db_def.decode_value(encoded_value)
150                value.discard(ref)
151
152                if not value:
153                    continue
154
155                encoded_value = db_def.encode_value(value)
156                cursor.put(encoded_key, encoded_value)
RefDb( env: hat.event.backends.lmdb.environment.Environment, max_results: int = 4096)
28    def __init__(self,
29                 env: environment.Environment,
30                 max_results: int = 4096):
31        self._env = env
32        self._max_results = max_results
33        self._last_event_ids = {}
34        self._changes = collections.defaultdict(_create_server_changes)
36    def add(self,
37            event: common.Event,
38            refs: typing.Iterable[common.EventRef]):
39        last_event_id = self._last_event_ids.get(event.id.server)
40        if last_event_id and last_event_id > event.id:
41            raise Exception('event older than last')
42
43        self._last_event_ids[event.id.server] = event.id
44
45        server_changes = self._changes[event.id.server]
46        server_changes.added[event.id] = event, set(refs)
48    def remove(self,
49               event_id: common.EventId,
50               ref: common.EventRef):
51        server_changes = self._changes[event_id.server]
52        added = server_changes.added.get(event_id)
53
54        if added and ref in added[1]:
55            added[1].remove(ref)
56
57            if not added[1]:
58                server_changes.added.pop(event_id)
59
60            return
61
62        removed = server_changes.removed[event_id]
63        removed.add(ref)
async def query( self, params: hat.event.common.QueryServerParams) -> hat.event.common.QueryResult:
65    async def query(self,
66                    params: common.QueryServerParams
67                    ) -> common.QueryResult:
68        if (params.last_event_id and
69                params.last_event_id.server != params.server_id):
70            raise ValueError('invalid server id')
71
72        max_results = (params.max_results
73                       if params.max_results is not None and
74                       params.max_results < self._max_results
75                       else self._max_results)
76
77        changes = self._changes
78        events = await self._env.execute(_ext_query_events, self._env,
79                                         params.server_id, max_results + 1,
80                                         params.last_event_id)
81
82        if not params.persisted and len(events) <= max_results:
83            last_event_id = events[-1].id if events else params.last_event_id
84            changes_max_result = max_results + 1 - len(events)
85            changes_events = _query_changes(changes, params.server_id,
86                                            changes_max_result, last_event_id)
87
88            events.extend(changes_events)
89
90        if len(events) > max_results:
91            events = list(itertools.islice(events, max_results))
92            more_follows = True
93
94        else:
95            events = list(events)
96            more_follows = False
97
98        return common.QueryResult(events=events,
99                                  more_follows=more_follows)
def create_changes(self) -> dict[int, ServerChanges]:
101    def create_changes(self) -> Changes:
102        self._changes, changes = (
103            collections.defaultdict(_create_server_changes), self._changes)
104        return changes
def ext_write( self, txn: Transaction, changes: dict[int, ServerChanges]):
106    def ext_write(self,
107                  txn: lmdb.Transaction,
108                  changes: Changes):
109        db_def = common.db_defs[common.DbType.REF]
110
111        with self._env.ext_cursor(txn, common.DbType.REF) as cursor:
112            for server_changes in changes.values():
113                event_ids = {*server_changes.added.keys(),
114                             *server_changes.removed.keys()}
115
116                for event_id in event_ids:
117                    encoded_key = db_def.encode_key(event_id)
118                    encoded_value = cursor.pop(encoded_key)
119                    value = (db_def.decode_value(encoded_value)
120                             if encoded_value else set())
121
122                    added = server_changes.added.get(event_id)
123                    if added:
124                        value.update(added[1])
125
126                    removed = server_changes.removed.get(event_id)
127                    if removed:
128                        value.difference_update(removed)
129
130                    if not value:
131                        continue
132
133                    encoded_value = db_def.encode_value(value)
134                    cursor.put(encoded_key, encoded_value)
def ext_cleanup( self, txn: Transaction, refs: Iterable[tuple[hat.event.common.EventId, hat.event.backends.lmdb.common.LatestEventRef | hat.event.backends.lmdb.common.TimeseriesEventRef]]):
136    def ext_cleanup(self,
137                    txn: lmdb.Transaction,
138                    refs: typing.Iterable[tuple[common.EventId,
139                                                common.EventRef]]):
140        db_def = common.db_defs[common.DbType.REF]
141
142        with self._env.ext_cursor(txn, common.DbType.REF) as cursor:
143            for event_id, ref in refs:
144                encoded_key = db_def.encode_key(event_id)
145                encoded_value = cursor.pop(encoded_key)
146                if not encoded_value:
147                    continue
148
149                value = db_def.decode_value(encoded_value)
150                value.discard(ref)
151
152                if not value:
153                    continue
154
155                encoded_value = db_def.encode_value(value)
156                cursor.put(encoded_key, encoded_value)