Define json functions for old SQLite versions

This commit is contained in:
Šarūnas Nejus 2024-06-16 17:56:44 +01:00
parent 856585392d
commit 8b4b5bb1f5
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435

View file

@ -27,10 +27,11 @@ import time
from abc import ABC
from collections import defaultdict
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
from sqlite3 import Connection
from sqlite3 import Connection, sqlite_version
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic, TypeVar, cast
from mediafile import MediaFile
from packaging.version import Version
from rich import print
from rich_tables.generic import flexitable
from unidecode import unidecode
@ -1128,9 +1129,71 @@ class Database:
return bytestring
def json_patch(first: str, second: str) -> str:
"""Implementation of the 'json_patch' SQL function.
This function merges two JSON strings together.
"""
first_dict = json.loads(first)
second_dict = json.loads(second)
first_dict.update(second_dict)
return json.dumps(first_dict)
def json_extract(json_str: str, key: str) -> str | None:
"""Simple implementation of the 'json_extract' SQLite function.
The original implementation in SQLite allows traversing objects of
any depth. Here, we only ever deal with a flat dictionary, thus
we can simplify the implementation to a single 'get' call.
"""
if json_str:
return json.loads(json_str).get(key.replace("$.", ""))
return None
class JSONGroupObject:
"""Implementation of the 'json_group_object' SQLite aggregate.
An aggregate function which accepts two values (key, val) and
groups all {key: val} pairs into a single object.
It is found in the json1 extension which is included in SQLite
by default since version 3.38.0 (2022-02-22). To ensure support
for older SQLite versions, we add our implementation.
Notably, it does not exist on Windows in Python 3.8.
Consider the following table
id key val
1 plays "10"
1 skips "20"
2 city "London"
SELECT id, group_to_json(key, val) GROUP BY id
1, '{"plays": "10", "skips": "20"}'
2, '{"city": "London"}'
"""
def __init__(self) -> None:
self.flex: dict[str, SQLiteType] = {}
def step(self, field: str, value: SQLiteType) -> None:
if field:
self.flex[field] = value
def finalize(self) -> str:
return json.dumps(self.flex)
conn.create_function("regexp", 2, regexp)
conn.create_function("unidecode", 1, unidecode)
conn.create_function("bytelower", 1, bytelower)
if Version(sqlite_version) < Version("3.38.0"):
# create 'json_group_object' for older SQLite versions that do
# not include the json1 extension by default
conn.create_aggregate("json_group_object", 2, JSONGroupObject)
conn.create_function("json_patch", 2, json_patch)
conn.create_function("json_extract", 2, json_extract)
def _close(self):
"""Close the all connections to the underlying SQLite database