Skip to content

Commit 3f80a02

Browse files
asg017simonw
andauthored
prepare_connection plugin hook
Closes: - #574 Refs #567 --------- Co-authored-by: Simon Willison <swillison@gmail.com>
1 parent 091c63c commit 3f80a02

5 files changed

Lines changed: 79 additions & 3 deletions

File tree

docs/plugins.rst

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ See the `LLM plugin documentation <https://llm.datasette.io/en/stable/plugins/tu
8686
Plugin hooks
8787
------------
8888
89-
Plugin hooks allow ``sqlite-utils`` to be customized. There is currently one hook.
89+
Plugin hooks allow ``sqlite-utils`` to be customized.
9090
9191
.. _plugins_hooks_register_commands:
9292
@@ -109,3 +109,29 @@ Example implementation:
109109
"Say hello world"
110110
click.echo("Hello world!")
111111
112+
.. _plugins_hooks_prepare_connection:
113+
114+
prepare_connection(conn)
115+
~~~~~~~~~~~~~~~~~~~~~~~~
116+
117+
This hook is called when a new SQLite database connection is created. You can
118+
use it to `register custom SQL functions <https://docs.python.org/2/library/sqlite3.html#sqlite3.Connection.create_function>`_,
119+
aggregates and collations. For example:
120+
121+
.. code-block:: python
122+
123+
import click
124+
import sqlite_utils
125+
126+
@sqlite_utils.hookimpl
127+
def prepare_connection(conn):
128+
conn.create_function(
129+
"hello", 1, lambda name: f"Hello, {name}!"
130+
)
131+
132+
This registers a SQL function called ``hello`` which takes a single
133+
argument and can be called like this:
134+
135+
.. code-block:: sql
136+
137+
select hello("world"); -- "Hello, world!"

sqlite_utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from .db import Database
21
from .utils import suggest_column_types
32
from .hookspecs import hookimpl
43
from .hookspecs import hookspec
4+
from .db import Database
55

66
__all__ = ["Database", "suggest_column_types", "hookimpl", "hookspec"]

sqlite_utils/db.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
Tuple,
3838
)
3939
import uuid
40+
from sqlite_utils.plugins import pm
4041

4142
try:
4243
from sqlite_dump import iterdump
@@ -342,6 +343,8 @@ def __init__(
342343
self._registered_functions: set = set()
343344
self.use_counts_table = use_counts_table
344345

346+
pm.hook.prepare_connection(conn=self.conn)
347+
345348
def close(self):
346349
"Close the SQLite connection, and the underlying database file"
347350
self.conn.close()

sqlite_utils/hookspecs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,8 @@
88
@hookspec
99
def register_commands(cli):
1010
"""Register additional CLI commands, e.g. 'sqlite-utils mycommand ...'"""
11+
12+
13+
@hookspec
14+
def prepare_connection(conn):
15+
"""Modify SQLite connection in some way e.g. register custom SQL functions"""

tests/test_plugins.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from click.testing import CliRunner
22
import click
33
import importlib
4-
from sqlite_utils import cli, hookimpl, plugins
4+
from sqlite_utils import cli, Database, hookimpl, plugins
55

66

77
def test_register_commands():
@@ -35,3 +35,45 @@ def hello_world():
3535
plugins.pm.unregister(name="HelloWorldPlugin")
3636
importlib.reload(cli)
3737
assert plugins.get_plugins() == []
38+
39+
40+
def test_prepare_connection():
41+
importlib.reload(cli)
42+
assert plugins.get_plugins() == []
43+
44+
class HelloFunctionPlugin:
45+
__name__ = "HelloFunctionPlugin"
46+
47+
@hookimpl
48+
def prepare_connection(self, conn):
49+
conn.create_function("hello", 1, lambda name: f"Hello, {name}!")
50+
51+
db = Database(memory=True)
52+
functions = db.execute(
53+
"select distinct name from pragma_function_list order by 1"
54+
).fetchall()
55+
assert "hello" not in functions
56+
57+
try:
58+
plugins.pm.register(HelloFunctionPlugin(), name="HelloFunctionPlugin")
59+
60+
assert plugins.get_plugins() == [
61+
{"name": "HelloFunctionPlugin", "hooks": ["prepare_connection"]}
62+
]
63+
64+
db = Database(memory=True)
65+
66+
functions = [
67+
row[0]
68+
for row in db.execute(
69+
"select distinct name from pragma_function_list order by 1"
70+
).fetchall()
71+
]
72+
assert "hello" in functions
73+
74+
result = db.execute('select hello("world")').fetchone()[0]
75+
assert result == "Hello, world!"
76+
77+
finally:
78+
plugins.pm.unregister(name="HelloFunctionPlugin")
79+
assert plugins.get_plugins() == []

0 commit comments

Comments
 (0)