Define json functions for old SQLite versions

This commit is contained in:
Šarūnas Nejus 2024-06-16 17:56:44 +01:00
parent 1f93207823
commit 11d584aa83
No known key found for this signature in database
GPG key ID: DD28F6704DBE3435

View file

@ -26,7 +26,7 @@ import threading
import time
from abc import ABC
from collections import defaultdict
from sqlite3 import Connection
from sqlite3 import Connection, sqlite_version
from types import TracebackType
from typing import (
Any,
@ -50,6 +50,7 @@ from typing import (
cast,
)
from packaging.version import Version
from rich import print
from rich_tables.generic import flexitable
from unidecode import unidecode
@ -1115,9 +1116,71 @@ class Database:
return bytestring
def json_patch(first: str, second: str) -> str:
"""Implementation of the 'json_patch' SQL function.
This function merges two JSON strings together.
"""
first_dict = json.loads(first)
second_dict = json.loads(second)
first_dict.update(second_dict)
return json.dumps(first_dict)
def json_extract(json_str: str, key: str) -> Optional[str]:
"""Simple implementation of the 'json_extract' SQLite function.
The original implementation in SQLite allows traversing objects of
any depth. Here, we only ever deal with a flat dictionary, thus
we can simplify the implementation to a single 'get' call.
"""
if json_str:
return json.loads(json_str).get(key.replace("$.", ""))
return None
class JSONGroupObject:
"""Implementation of the 'json_group_object' SQLite aggregate.
An aggregate function which accepts two values (key, val) and
groups all {key: val} pairs into a single object.
It is found in the json1 extension which is included in SQLite
by default since version 3.38.0 (2022-02-22). To ensure support
for older SQLite versions, we add our implementation.
Notably, it does not exist on Windows in Python 3.8.
Consider the following table
id key val
1 plays "10"
1 skips "20"
2 city "London"
SELECT id, group_to_json(key, val) GROUP BY id
1, '{"plays": "10", "skips": "20"}'
2, '{"city": "London"}'
"""
def __init__(self):
self.flex = {}
def step(self, field, value):
if field:
self.flex[field] = value
def finalize(self):
return json.dumps(self.flex)
conn.create_function("regexp", 2, regexp)
conn.create_function("unidecode", 1, unidecode)
conn.create_function("bytelower", 1, bytelower)
if Version(sqlite_version) < Version("3.38.0"):
# create 'json_group_object' for older SQLite versions that do
# not include the json1 extension by default
conn.create_aggregate("json_group_object", 2, JSONGroupObject)
conn.create_function("json_patch", 2, json_patch)
conn.create_function("json_extract", 2, json_extract)
def _close(self):
"""Close the all connections to the underlying SQLite database