-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy path2_ingest_data.py
More file actions
86 lines (73 loc) · 3.01 KB
/
2_ingest_data.py
File metadata and controls
86 lines (73 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import asyncio
import json
import os
from typing import Optional

from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.kg_writer import KGWriter, KGWriterModel
from neo4j_graphrag.experimental.components.types import (
    Neo4jGraph,
    Neo4jNode,
    Neo4jRelationship,
)
from pydantic import validate_call
class Neo4jCreateWriter(KGWriter):
    """KGWriter that rebuilds the target database from a ``Neo4jGraph``.

    Nodes are written with ``MERGE`` on ``(label, id)``. Relationships are
    written with ``CREATE`` instead of ``MERGE``, so several relationships of
    the same type between the same two nodes (e.g. one per episode) are all
    kept rather than being collapsed into one.

    WARNING: every call to :meth:`run` first deletes ALL nodes and
    relationships in the target database.
    """

    def __init__(self, driver, neo4j_database=None):
        """Store the Neo4j driver and the target database name.

        Args:
            driver: A ``neo4j`` driver, e.g. from ``GraphDatabase.driver``.
            neo4j_database: Database name; ``None`` uses the server default.
        """
        self.driver = driver
        self.neo4j_database = neo4j_database

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a backtick-quoted Cypher identifier.

        Labels and relationship types cannot be passed as query parameters,
        so they must be spliced into the query text. Quoting them (and
        doubling any embedded backtick) keeps names containing spaces,
        hyphens or other special characters from breaking the query or
        injecting arbitrary Cypher.

        Raises:
            ValueError: if *name* is empty.
        """
        if not name:
            raise ValueError("Cypher label/relationship type must be non-empty")
        return "`" + name.replace("`", "``") + "`"

    def _wipe_database(self) -> None:
        """Delete every node, together with its relationships, in the target database."""
        self.driver.execute_query(
            "MATCH (n) DETACH DELETE n",
            database_=self.neo4j_database,
        )

    @validate_call
    async def run(self, graph: Neo4jGraph) -> KGWriterModel:
        """Wipe the database, then write all nodes and relationships of *graph*.

        Args:
            graph: The graph to write.

        Returns:
            ``KGWriterModel(status="SUCCESS")`` with metadata:

            - ``node_count``: the number of input nodes.
            - ``relationship_count``: the number of input relationships.
            - ``relationships_created``: how many relationships were actually
              created. This is lower than ``relationship_count`` when an
              endpoint id is missing. It can be higher when several nodes
              share an id.

            On any exception, returns ``KGWriterModel(status="FAILURE")``
            with the error text in ``metadata["error"]``. The wipe and the
            writes are not one transaction, so a failure part-way leaves the
            database wiped and only partially written.
        """
        try:
            self._wipe_database()
            relationships_created = 0
            # NOTE: this uses the synchronous driver, so the event loop is
            # blocked for the whole write.
            with self.driver.session(database=self.neo4j_database) as session:
                # 1. Nodes: upsert by (label, id) and merge in their properties.
                for node in graph.nodes:
                    label = self._quote_identifier(node.label)
                    session.run(
                        f"""
                        MERGE (n:{label} {{id: $id}})
                        SET n += $props
                        """,
                        {"id": node.id, "props": node.properties or {}},
                    )
                # 2. Relationships: always CREATE, so parallel edges survive.
                # NOTE(review): the endpoints are matched by id without a label.
                # That cannot use a label-scoped index (a full scan per
                # relationship). It also matches every node with that id
                # whatever its label. Consider a shared label plus an index
                # for large graphs.
                for rel in graph.relationships:
                    rel_type = self._quote_identifier(rel.type)
                    summary = session.run(
                        f"""
                        MATCH (a {{id: $start_id}}), (b {{id: $end_id}})
                        CREATE (a)-[r:{rel_type} $props]->(b)
                        """,
                        {
                            "start_id": rel.start_node_id,
                            "end_id": rel.end_node_id,
                            "props": rel.properties or {},
                        },
                    ).consume()
                    relationships_created += summary.counters.relationships_created
            return KGWriterModel(
                status="SUCCESS",
                metadata={
                    "node_count": len(graph.nodes),
                    "relationship_count": len(graph.relationships),
                    "relationships_created": relationships_created,
                },
            )
        except Exception as e:
            # Reported to the caller rather than raised.
            return KGWriterModel(status="FAILURE", metadata={"error": str(e)})
async def write_to_neo4j(
    graph: Neo4jGraph,
    *,
    uri: Optional[str] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    database: Optional[str] = None,
) -> KGWriterModel:
    """Replace the contents of a Neo4j database with *graph* and print the result.

    Each connection setting is resolved in this order:

    1. the explicit argument;
    2. the environment variable (``NEO4J_URI``, ``NEO4J_USER`` or
       ``NEO4J_PASSWORD``);
    3. the local development default.

    Args:
        graph: The graph to write. The target database is wiped first.
        uri: Bolt/neo4j URI; defaults to ``neo4j://127.0.0.1:7687``.
        user: User name; defaults to ``neo4j``.
        password: Password; defaults to the local development password.
        database: Database name; ``None`` uses the server default.

    Returns:
        The ``KGWriterModel`` produced by the writer, which is also printed.
    """
    if uri is None:
        uri = os.environ.get("NEO4J_URI", "neo4j://127.0.0.1:7687")
    if user is None:
        user = os.environ.get("NEO4J_USER", "neo4j")
    if password is None:
        password = os.environ.get("NEO4J_PASSWORD", "12345678")

    # The context manager closes the driver and its connection pool even if
    # the write raises.
    with GraphDatabase.driver(uri, auth=(user, password)) as driver:
        writer = Neo4jCreateWriter(driver, neo4j_database=database)
        result = await writer.run(graph)
    print(result)
    return result
if __name__ == "__main__":
    # Load the knowledge graph exported by the previous pipeline step.
    # The path is resolved relative to the current working directory.
    source_path = "output/지식그래프_최종.json"
    with open(source_path, "r", encoding="utf-8") as fp:
        payload = json.load(fp)

    # "nodes" is mandatory (a KeyError aborts before anything is written).
    # "relationships" is optional.
    graph_nodes = [Neo4jNode(**raw) for raw in payload["nodes"]]
    graph_edges = [
        Neo4jRelationship(**raw) for raw in payload.get("relationships", [])
    ]
    kg = Neo4jGraph(nodes=graph_nodes, relationships=graph_edges)

    asyncio.run(write_to_neo4j(kg))