insert_triples.py
# kg-llm-interface
# Copyright 2023 - Swiss Data Science Center (SDSC)
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
# Eidgenössische Technische Hochschule Zürich (ETHZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This flow populates a SPARQL endpoint from RDF data in a file."""
import os
from pathlib import Path
from typing import Optional

import typer
from dotenv import load_dotenv
from prefect import flow, get_run_logger, task
from SPARQLWrapper import SPARQLWrapper
from typing_extensions import Annotated

from aikg.config import SparqlConfig
from aikg.config.common import parse_yaml_config


@task
def setup_sparql_endpoint(
    endpoint: str, user: Optional[str] = None, password: Optional[str] = None
) -> SPARQLWrapper:
    """Connect to the SPARQL endpoint and set up credentials.

    Parameters
    ----------
    endpoint:
        URL of the SPARQL endpoint.
    user:
        Username to use for authentication.
    password:
        Password to use for authentication.
    """
    # Updates go to the <endpoint>/statements URL (RDF4J / GraphDB convention).
    sparql = SPARQLWrapper(endpoint, updateEndpoint=endpoint + "/statements")
    if user and password:
        sparql.setCredentials(user, password)
    return sparql
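
# Outside of a flow run (e.g. for quick manual testing), the wrapped function
# can be reached through Prefect's `.fn` attribute; the URL is illustrative:
#
#   sparql = setup_sparql_endpoint.fn("http://localhost:7200/repositories/kg")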


@task
def insert_triples(
    rdf_file: Path,
    endpoint: SPARQLWrapper,
    graph: Optional[str] = None,
    chunk_size: int = 1000,
):
    """Insert triples from a source file into the SPARQL endpoint.

    Parameters
    ----------
    rdf_file:
        Path to the RDF file to load into the SPARQL endpoint.
    endpoint:
        SPARQL endpoint to load RDF data into.
    graph:
        URI of the named graph to load RDF data into.
        If set to None, the default graph is used.
    chunk_size:
        Number of triples per INSERT operation.
    """
    from rdflib import Dataset
    from rdflib.util import guess_format

    format = guess_format(str(rdf_file))
    if format not in ["nt", "nquads"]:
        raise ValueError("Unsupported RDF format, must be ntriples or nquads.")
    cur = 0
    tot = os.path.getsize(rdf_file)
    with open(rdf_file, "r", encoding="utf-8") as source:
        # Run INSERT DATA queries by chunks of triples. In ntriples / nquads,
        # each non-empty line holds one statement, so chunk_size lines are read
        # per query.
        while True:
            data = "".join([source.readline() for _ in range(chunk_size)])
            if data == "":
                break
            ds = Dataset()
            ds.parse(data=data, format=format)
            # Build the update: prefix declarations, then the chunk's triples,
            # optionally wrapped in a GRAPH block.
            query = "\n".join(
                [f"PREFIX {prefix}: {ns.n3()}" for prefix, ns in ds.namespaces()]
            )
            query += "\nINSERT DATA {"
            if graph:
                query += f"\n\tGRAPH <{graph}> {{"
            query += " .\n".join(
                [f"\t\t{s.n3()} {p.n3()} {o.n3()}" for (s, p, o, _) in ds.quads()]
            )
            if graph:
                query += "\n\t}"
            query += " . \n\n}\n"
            endpoint.setQuery(query)
            endpoint.queryType = "INSERT"
            endpoint.method = "POST"
            endpoint.setReturnFormat("json")
            endpoint.query()
            # Report progress as the fraction of source bytes consumed so far.
            cur += len(data.encode("utf-8"))
            print(f"inserted triples: {round(100 * cur / tot, 2)}%")


@flow
def sparql_insert_flow(
    rdf_file: Path,
    sparql_cfg: SparqlConfig = SparqlConfig(),
    graph: Optional[str] = None,
):
    """Workflow to connect to a SPARQL endpoint and send INSERT
    queries to load triples from a local file.

    Parameters
    ----------
    rdf_file:
        Path to the source RDF file.
    sparql_cfg:
        Configuration for the target SPARQL endpoint.
    graph:
        URI of the named graph to load RDF data into.
        If set to None, the default graph is used.
    """
    load_dotenv()
    logger = get_run_logger()
    sparql = setup_sparql_endpoint(
        sparql_cfg.endpoint, sparql_cfg.user, sparql_cfg.password
    )
    logger.info("SPARQL endpoint connected")
    insert_triples(rdf_file, sparql, graph)
    logger.info("All triples inserted")


def cli(
    rdf_file: Annotated[
        Path,
        typer.Argument(
            help="RDF file to load into the SPARQL endpoint, in n-triples or n-quads format.",
            exists=True,
            file_okay=True,
            dir_okay=False,
        ),
    ],
    sparql_cfg_path: Annotated[
        Optional[Path],
        typer.Option(help="YAML file with the SPARQL endpoint configuration."),
    ] = None,
    graph: Annotated[
        Optional[str],
        typer.Option(
            help="URI of the named graph to load RDF data into. If not set, the default graph is used.",
        ),
    ] = None,
):
    """Command line wrapper to insert triples into a SPARQL endpoint."""
    sparql_cfg = (
        parse_yaml_config(sparql_cfg_path, SparqlConfig)
        if sparql_cfg_path
        else SparqlConfig()
    )
    sparql_insert_flow(rdf_file, sparql_cfg, graph)
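
# An illustrative YAML file for --sparql-cfg-path (keys inferred from the
# SparqlConfig attributes used above; the exact schema is defined by
# aikg.config.SparqlConfig):
#
#   endpoint: http://localhost:7200/repositories/kg
#   user: admin
#   password: secret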


if __name__ == "__main__":
    typer.run(cli)