Unverified commit edce3cbb, authored by Jin Hyuk Chang, committed via GitHub

Neo4j Publisher to support desired state of relation (#69)

* [AMD-120] Add relation pre-processor in Neo4jPublisher

* Update

* Added DeleteRelationPreprocessor

* Added DeleteRelationPreprocessor

* Update

* Update
parent 014690ea
import abc
import logging
import six
import textwrap
# Module-level logger, named after this module so records are attributable.
LOGGER = logging.getLogger(__name__)
@six.add_metaclass(abc.ABCMeta)
class RelationPreprocessor(object):
    """
    A pre-processor for relations. Prior to publishing Neo4j relations, RelationPreprocessor will be used for
    pre-processing.

    Neo4j Publisher will iterate through the relation file and call preprocess_cypher to perform any pre-processing
    requested. For example, if you need the current job's relation data to be the desired state, you can add a delete
    statement in preprocess_cypher_impl. With preprocess_cypher defined, and with a long transaction size, the Neo4j
    publisher will atomically apply the desired state.
    """

    def preprocess_cypher(self,
                          start_label,
                          end_label,
                          start_key,
                          end_key,
                          relation,
                          reverse_relation):
        # type: (str, str, str, str, str, str) -> Optional[Tuple[str, Dict[str, str]]]
        """
        Provides a Cypher statement that will be executed before publishing relations.

        :param start_label: label of the relation's start node
        :param end_label: label of the relation's end node
        :param start_key: key of the relation's start node
        :param end_key: key of the relation's end node
        :param relation: relation type (start -> end)
        :param reverse_relation: relation type (end -> start)
        :return: a (Cypher statement, parameters) tuple from preprocess_cypher_impl, or None when this record is
                 rejected by self.filter
        """
        if self.filter(start_label=start_label,
                       end_label=end_label,
                       start_key=start_key,
                       end_key=end_key,
                       relation=relation,
                       reverse_relation=reverse_relation):
            return self.preprocess_cypher_impl(start_label=start_label,
                                               end_label=end_label,
                                               start_key=start_key,
                                               end_key=end_key,
                                               relation=relation,
                                               reverse_relation=reverse_relation)
        # Explicitly signal "no pre-processing for this record".
        return None

    @abc.abstractmethod
    def preprocess_cypher_impl(self,
                               start_label,
                               end_label,
                               start_key,
                               end_key,
                               relation,
                               reverse_relation):
        # type: (str, str, str, str, str, str) -> Tuple[str, Dict[str, str]]
        """
        Provides a Cypher statement that will be executed before publishing relations.

        :param start_label: label of the relation's start node
        :param end_label: label of the relation's end node
        :param start_key: key of the relation's start node
        :param end_key: key of the relation's end node
        :param relation: relation type (start -> end)
        :param reverse_relation: relation type (end -> start)
        :return: a (Cypher statement, parameters) tuple
        """
        pass

    def filter(self,
               start_label,
               end_label,
               start_key,
               end_key,
               relation,
               reverse_relation):
        # type: (str, str, str, str, str, str) -> bool
        """
        A method that filters pre-processing at the record level. Returns True if the record needs preprocessing,
        otherwise False. Default implementation accepts every record.

        :param start_label: label of the relation's start node
        :param end_label: label of the relation's end node
        :param start_key: key of the relation's start node
        :param end_key: key of the relation's end node
        :param relation: relation type (start -> end)
        :param reverse_relation: relation type (end -> start)
        :return: bool. True if it needs preprocessing, otherwise False.
        """
        # BUG FIX: this body was a bare `True` expression, so the method implicitly returned None (falsy) and
        # preprocess_cypher never invoked preprocess_cypher_impl for subclasses relying on the default filter.
        return True
class NoopRelationPreprocessor(RelationPreprocessor):
    """
    A do-nothing pre-processor. It switches pre-processing off globally via is_perform_preprocess, so the
    publisher never invokes preprocess_cypher_impl.
    """

    def preprocess_cypher_impl(self,
                               start_label,
                               end_label,
                               start_key,
                               end_key,
                               relation,
                               reverse_relation):
        # type: (str, str, str, str, str, str) -> Tuple[str, Dict[str, str]]
        # Unreachable in practice: is_perform_preprocess() returns False.
        return None

    def is_perform_preprocess(self):
        # type: () -> bool
        # Global off-switch for pre-processing.
        return False
class DeleteRelationPreprocessor(RelationPreprocessor):
    """
    A relation pre-processor that deletes relationships before Neo4jPublisher publishes relations.

    Example use case: take an external privacy service pushing personal identifiable information (PII) tags into
    Amundsen. The first push is fine, but subsequent updates become a challenge because the external service does
    not know the current PII state in Amundsen. The easy solution is for the external service to know only the
    desired state (certain columns should have certain PII tags) and push that.

    The challenge is then how Amundsen applies the desired state, and this is where DeleteRelationPreprocessor comes
    into the picture: it deletes certain relations in the job, letting Neo4jPublisher update to the desired state.
    Should there be a small window (between delete and update) where Amundsen data is incomplete, you can increase
    Neo4jPublisher's transaction size to make it atomic. Note that you should not set the transaction size too big,
    as Neo4j uses memory to store the transaction; this use case suits small batch jobs.
    """

    # Deletes the (at most two, one per direction) relationships between the two keyed nodes.
    # Name kept for backward compatibility even though the statement is a DELETE, not a MERGE.
    RELATION_MERGE_TEMPLATE = textwrap.dedent("""
        MATCH (n1:{start_label} {{key: $start_key }})-[r]-(n2:{end_label} {{key: $end_key }})
        {where_clause}
        WITH r LIMIT 2
        DELETE r
        RETURN count(*) as count;
    """)

    def __init__(self, label_tuples=None, where_clause=''):
        # type: (List[Tuple[str, str]], str) -> None
        """
        :param label_tuples: (start_label, end_label) pairs to pre-process. Matching is direction-insensitive: each
                             pair is also registered reversed. When None/empty, every relation is pre-processed.
        :param where_clause: optional extra WHERE clause injected into the DELETE statement.
        """
        super(DeleteRelationPreprocessor, self).__init__()
        self._label_tuples = set(label_tuples) if label_tuples else set()
        # Relations can be recorded in either direction, so register the reversed pairs as well.
        reversed_label_tuples = [(t2, t1) for t1, t2 in self._label_tuples]
        self._label_tuples.update(reversed_label_tuples)
        self._where_clause = where_clause

    def preprocess_cypher_impl(self,
                               start_label,
                               end_label,
                               start_key,
                               end_key,
                               relation,
                               reverse_relation):
        # type: (str, str, str, str, str, str) -> Tuple[str, Dict[str, str]]
        """
        Provides a DELETE relation Cypher query for the specific relation.

        :param start_label: label of the relation's start node (required)
        :param end_label: label of the relation's end node (required)
        :param start_key: key of the relation's start node (required)
        :param end_key: key of the relation's end node (required)
        :param relation: relation type (unused by the delete statement)
        :param reverse_relation: reverse relation type (unused by the delete statement)
        :return: a (Cypher statement, parameters) tuple
        :raises ValueError: when any of the labels/keys is missing
        """
        # BUG FIX: the original used `or`, which only raised when ALL four values were missing; the error message
        # ("all labels and keys are required") makes clear every value must be present.
        if not (start_label and end_label and start_key and end_key):
            raise ValueError('all labels and keys are required: {}'.format(locals()))

        params = {'start_key': start_key, 'end_key': end_key}
        return DeleteRelationPreprocessor.RELATION_MERGE_TEMPLATE.format(start_label=start_label,
                                                                         end_label=end_label,
                                                                         where_clause=self._where_clause), params

    def is_perform_preprocess(self):
        # type: () -> bool
        # Pre-processing is always enabled for this pre-processor; record-level control lives in filter().
        return True

    def filter(self,
               start_label,
               end_label,
               start_key,
               end_key,
               relation,
               reverse_relation):
        # type: (str, str, str, str, str, str) -> bool
        """
        If the pair of labels is one the client requested via label_tuples, filter returns True, meaning the record
        needs to be pre-processed. With no label_tuples configured, every record is accepted.

        :param start_label: label of the relation's start node
        :param end_label: label of the relation's end node
        :param start_key: key of the relation's start node
        :param end_key: key of the relation's end node
        :param relation: relation type (unused)
        :param reverse_relation: reverse relation type (unused)
        :return: bool. True if it needs preprocessing, otherwise False.
        """
        if self._label_tuples and (start_label, end_label) not in self._label_tuples:
            return False
        return True
...@@ -12,11 +12,18 @@ from databuilder.transformer.base_transformer \ ...@@ -12,11 +12,18 @@ from databuilder.transformer.base_transformer \
from databuilder.utils.closer import Closer from databuilder.utils.closer import Closer
LOGGER = logging.getLogger(__name__)
class DefaultTask(Task): class DefaultTask(Task):
""" """
A default task expecting to extract, transform and load. A default task expecting to extract, transform and load.
""" """
# Determines the frequency of the log on task progress
PROGRESS_REPORT_FREQUENCY = 'progress_report_frequency'
def __init__(self, def __init__(self,
extractor, extractor,
loader, loader,
...@@ -33,6 +40,9 @@ class DefaultTask(Task): ...@@ -33,6 +40,9 @@ class DefaultTask(Task):
def init(self, conf): def init(self, conf):
# type: (ConfigTree) -> None # type: (ConfigTree) -> None
self._progress_report_frequency = \
conf.get_int('{}.{}'.format(self.get_scope(), DefaultTask.PROGRESS_REPORT_FREQUENCY), 500)
self.extractor.init(Scoped.get_scoped_conf(conf, self.extractor.get_scope())) self.extractor.init(Scoped.get_scoped_conf(conf, self.extractor.get_scope()))
self.transformer.init(Scoped.get_scoped_conf(conf, self.transformer.get_scope())) self.transformer.init(Scoped.get_scoped_conf(conf, self.transformer.get_scope()))
self.loader.init(Scoped.get_scoped_conf(conf, self.loader.get_scope())) self.loader.init(Scoped.get_scoped_conf(conf, self.loader.get_scope()))
...@@ -43,15 +53,19 @@ class DefaultTask(Task): ...@@ -43,15 +53,19 @@ class DefaultTask(Task):
Runs a task Runs a task
:return: :return:
""" """
logging.info('Running a task') LOGGER.info('Running a task')
try: try:
record = self.extractor.extract() record = self.extractor.extract()
count = 1
while record: while record:
record = self.transformer.transform(record) record = self.transformer.transform(record)
if not record: if not record:
continue continue
self.loader.load(record) self.loader.load(record)
record = self.extractor.extract() record = self.extractor.extract()
count += 1
if count > 0 and count % self._progress_report_frequency == 0:
LOGGER.info('Extracted {} records so far'.format(count))
finally: finally:
self._closer.close() self._closer.close()
...@@ -16,7 +16,7 @@ class TestPublish(unittest.TestCase): ...@@ -16,7 +16,7 @@ class TestPublish(unittest.TestCase):
def setUp(self): def setUp(self):
# type: () -> None # type: () -> None
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
self._resource_path = '{}/../resources/csv_publisher'\ self._resource_path = '{}/../resources/csv_publisher' \
.format(os.path.join(os.path.dirname(__file__))) .format(os.path.join(os.path.dirname(__file__)))
def test_publisher(self): def test_publisher(self):
...@@ -36,12 +36,9 @@ class TestPublish(unittest.TestCase): ...@@ -36,12 +36,9 @@ class TestPublish(unittest.TestCase):
publisher = Neo4jCsvPublisher() publisher = Neo4jCsvPublisher()
conf = ConfigFactory.from_dict( conf = ConfigFactory.from_dict(
{neo4j_csv_publisher.NEO4J_END_POINT_KEY: {neo4j_csv_publisher.NEO4J_END_POINT_KEY: 'dummy://999.999.999.999:7687/',
'dummy://999.999.999.999:7687/', neo4j_csv_publisher.NODE_FILES_DIR: '{}/nodes'.format(self._resource_path),
neo4j_csv_publisher.NODE_FILES_DIR: neo4j_csv_publisher.RELATION_FILES_DIR: '{}/relations'.format(self._resource_path),
'{}/nodes'.format(self._resource_path),
neo4j_csv_publisher.RELATION_FILES_DIR:
'{}/relations'.format(self._resource_path),
neo4j_csv_publisher.NEO4J_USER: 'neo4j_user', neo4j_csv_publisher.NEO4J_USER: 'neo4j_user',
neo4j_csv_publisher.NEO4J_PASSWORD: 'neo4j_password', neo4j_csv_publisher.NEO4J_PASSWORD: 'neo4j_password',
neo4j_csv_publisher.JOB_PUBLISH_TAG: '{}'.format(uuid.uuid4())} neo4j_csv_publisher.JOB_PUBLISH_TAG: '{}'.format(uuid.uuid4())}
...@@ -52,7 +49,44 @@ class TestPublish(unittest.TestCase): ...@@ -52,7 +49,44 @@ class TestPublish(unittest.TestCase):
self.assertEqual(mock_run.call_count, 6) self.assertEqual(mock_run.call_count, 6)
# 2 node files, 1 relation file # 2 node files, 1 relation file
self.assertEqual(mock_commit.call_count, 3) self.assertEqual(mock_commit.call_count, 1)
def test_preprocessor(self):
# type: () -> None
with patch.object(GraphDatabase, 'driver') as mock_driver:
mock_session = MagicMock()
mock_driver.return_value.session.return_value = mock_session
mock_transaction = MagicMock()
mock_session.begin_transaction.return_value = mock_transaction
mock_run = MagicMock()
mock_transaction.run = mock_run
mock_commit = MagicMock()
mock_transaction.commit = mock_commit
mock_preprocessor = MagicMock()
mock_preprocessor.is_perform_preprocess.return_value = MagicMock(return_value=True)
mock_preprocessor.preprocess_cypher.return_value = ('MATCH (f:Foo) RETURN f', {})
publisher = Neo4jCsvPublisher()
conf = ConfigFactory.from_dict(
{neo4j_csv_publisher.NEO4J_END_POINT_KEY: 'dummy://999.999.999.999:7687/',
neo4j_csv_publisher.NODE_FILES_DIR: '{}/nodes'.format(self._resource_path),
neo4j_csv_publisher.RELATION_FILES_DIR: '{}/relations'.format(self._resource_path),
neo4j_csv_publisher.RELATION_PREPROCESSOR: mock_preprocessor,
neo4j_csv_publisher.NEO4J_USER: 'neo4j_user',
neo4j_csv_publisher.NEO4J_PASSWORD: 'neo4j_password',
neo4j_csv_publisher.JOB_PUBLISH_TAG: '{}'.format(uuid.uuid4())}
)
publisher.init(conf)
publisher.publish()
self.assertEqual(mock_run.call_count, 8)
# 2 node files, 1 relation file
self.assertEqual(mock_commit.call_count, 1)
if __name__ == '__main__': if __name__ == '__main__':
......
import textwrap
import unittest
import uuid
from databuilder.publisher.neo4j_preprocessor import NoopRelationPreprocessor, DeleteRelationPreprocessor
class TestNeo4jPreprocessor(unittest.TestCase):

    def testNoopRelationPreprocessor(self):
        # type: () -> None
        noop = NoopRelationPreprocessor()
        self.assertFalse(noop.is_perform_preprocess())

    def testDeleteRelationPreprocessor(self):
        # type: () -> None
        processor = DeleteRelationPreprocessor()

        self.assertTrue(processor.is_perform_preprocess())

        # Smoke-test filter with fixed labels/keys.
        processor.filter(start_label='foo_label',
                         end_label='bar_label',
                         start_key='foo_key',
                         end_key='bar_key',
                         relation='foo_relation',
                         reverse_relation='bar_relation')

        # With no label_tuples configured, any relation passes the filter.
        self.assertTrue(processor.filter(start_label=str(uuid.uuid4()),
                                         end_label=str(uuid.uuid4()),
                                         start_key=str(uuid.uuid4()),
                                         end_key=str(uuid.uuid4()),
                                         relation=str(uuid.uuid4()),
                                         reverse_relation=str(uuid.uuid4())))

        actual = processor.preprocess_cypher(start_label='foo_label',
                                             end_label='bar_label',
                                             start_key='foo_key',
                                             end_key='bar_key',
                                             relation='foo_relation',
                                             reverse_relation='bar_relation')

        expected_statement = textwrap.dedent("""
        MATCH (n1:foo_label {key: $start_key })-[r]-(n2:bar_label {key: $end_key })

        WITH r LIMIT 2
        DELETE r
        RETURN count(*) as count;
        """)
        expected = (expected_statement, {'start_key': 'foo_key', 'end_key': 'bar_key'})
        self.assertEqual(expected, actual)

    def testDeleteRelationPreprocessorFilter(self):
        # type: () -> None
        processor = DeleteRelationPreprocessor(label_tuples=[('foo', 'bar')])

        # Both orientations of a registered label pair are accepted.
        for first, second in (('foo', 'bar'), ('bar', 'foo')):
            self.assertTrue(processor.filter(start_label=first,
                                             end_label=second,
                                             start_key=str(uuid.uuid4()),
                                             end_key=str(uuid.uuid4()),
                                             relation=str(uuid.uuid4()),
                                             reverse_relation=str(uuid.uuid4())))

        # An unregistered label pair is rejected.
        self.assertFalse(processor.filter(start_label='foz',
                                          end_label='baz',
                                          start_key=str(uuid.uuid4()),
                                          end_key=str(uuid.uuid4()),
                                          relation=str(uuid.uuid4()),
                                          reverse_relation=str(uuid.uuid4())))
# Allow running this test module directly (e.g. `python test_neo4j_preprocessor.py`).
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment