Unverified Commit 975a59a2 authored by Jin Hyuk Chang's avatar Jin Hyuk Chang Committed by GitHub

Add description to schema (#242)

* Add description to schema

* Update

* Update
parent cb6dc2ed
...@@ -32,7 +32,7 @@ class GenericExtractor(Extractor): ...@@ -32,7 +32,7 @@ class GenericExtractor(Extractor):
self._iter = iter(results) self._iter = iter(results)
else: else:
raise RuntimeError('model class needs to be provided!') self._iter = iter(self.values)
def extract(self): def extract(self):
# type: () -> Any # type: () -> Any
......
from typing import Dict, Any, Union, Iterator # noqa: F401
from databuilder.models.neo4j_csv_serde import (
Neo4jCsvSerializable, NODE_LABEL, NODE_KEY)
from databuilder.models.schema.schema_constant import SCHEMA_NODE_LABEL, SCHEMA_NAME_ATTR
from databuilder.models.table_metadata import DescriptionMetadata
class SchemaModel(Neo4jCsvSerializable):
def __init__(self,
schema_key,
schema,
description=None,
description_source=None,
**kwargs):
self._schema_key = schema_key
self._schema = schema
self._description = DescriptionMetadata.create_description_metadata(text=description,
source=description_source) \
if description else None
self._node_iterator = self._create_node_iterator()
self._relation_iterator = self._create_relation_iterator()
def create_next_node(self):
# type: () -> Union[Dict[str, Any], None]
try:
return next(self._node_iterator)
except StopIteration:
return None
def _create_node_iterator(self):
# type: () -> Iterator[[Dict[str, Any]]]
yield {
NODE_LABEL: SCHEMA_NODE_LABEL,
NODE_KEY: self._schema_key,
SCHEMA_NAME_ATTR: self._schema,
}
if self._description:
yield self._description.get_node_dict(self._get_description_node_key())
def create_next_relation(self):
# type: () -> Union[Dict[str, Any], None]
try:
return next(self._relation_iterator)
except StopIteration:
return None
def _get_description_node_key(self):
return '{}/{}'.format(self._schema_key, self._description.get_description_id())
def _create_relation_iterator(self):
# type: () -> Iterator[[Dict[str, Any]]]
if self._description:
yield self._description.get_relation(start_node=SCHEMA_NODE_LABEL,
start_key=self._schema_key,
end_key=self._get_description_node_key())
SCHEMA_NODE_LABEL = 'Schema'
SCHEMA_NAME_ATTR = 'name'
SCHEMA_RELATION_TYPE = 'SCHEMA'
SCHEMA_REVERSE_RELATION_TYPE = 'SCHEMA_OF'
DATABASE_SCHEMA_KEY_FORMAT = '{db}://{cluster}.{schema}'
...@@ -8,6 +8,7 @@ from databuilder.models.neo4j_csv_serde import ( ...@@ -8,6 +8,7 @@ from databuilder.models.neo4j_csv_serde import (
Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL,
RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE)
from databuilder.publisher.neo4j_csv_publisher import UNQUOTED_SUFFIX from databuilder.publisher.neo4j_csv_publisher import UNQUOTED_SUFFIX
from databuilder.models.schema import schema_constant
DESCRIPTION_NODE_LABEL_VAL = 'Description' DESCRIPTION_NODE_LABEL_VAL = 'Description'
DESCRIPTION_NODE_LABEL = DESCRIPTION_NODE_LABEL_VAL DESCRIPTION_NODE_LABEL = DESCRIPTION_NODE_LABEL_VAL
...@@ -208,11 +209,11 @@ class TableMetadata(Neo4jCsvSerializable): ...@@ -208,11 +209,11 @@ class TableMetadata(Neo4jCsvSerializable):
CLUSTER_NODE_LABEL = cluster_constants.CLUSTER_NODE_LABEL CLUSTER_NODE_LABEL = cluster_constants.CLUSTER_NODE_LABEL
CLUSTER_KEY_FORMAT = '{db}://{cluster}' CLUSTER_KEY_FORMAT = '{db}://{cluster}'
CLUSTER_SCHEMA_RELATION_TYPE = 'SCHEMA' CLUSTER_SCHEMA_RELATION_TYPE = schema_constant.SCHEMA_RELATION_TYPE
SCHEMA_CLUSTER_RELATION_TYPE = 'SCHEMA_OF' SCHEMA_CLUSTER_RELATION_TYPE = schema_constant.SCHEMA_REVERSE_RELATION_TYPE
SCHEMA_NODE_LABEL = 'Schema' SCHEMA_NODE_LABEL = schema_constant.SCHEMA_NODE_LABEL
SCHEMA_KEY_FORMAT = '{db}://{cluster}.{schema}' SCHEMA_KEY_FORMAT = schema_constant.DATABASE_SCHEMA_KEY_FORMAT
SCHEMA_TABLE_RELATION_TYPE = 'TABLE' SCHEMA_TABLE_RELATION_TYPE = 'TABLE'
TABLE_SCHEMA_RELATION_TYPE = 'TABLE_OF' TABLE_SCHEMA_RELATION_TYPE = 'TABLE_OF'
......
import abc import abc
from pyhocon import ConfigTree # noqa: F401 from pyhocon import ConfigTree # noqa: F401
from typing import Any, Iterable # noqa: F401 from typing import Any, Iterable, Optional # noqa: F401
from databuilder import Scoped from databuilder import Scoped
...@@ -42,13 +42,18 @@ class ChainedTransformer(Transformer): ...@@ -42,13 +42,18 @@ class ChainedTransformer(Transformer):
""" """
A chained transformer that iterates transformers and transforms a record A chained transformer that iterates transformers and transforms a record
""" """
def __init__(self, transformers): def __init__(self,
# type: (Iterable[Transformer]) -> None transformers,
is_init_transformers=False):
# type: (Iterable[Transformer], Optional[bool]) -> None
self.transformers = transformers self.transformers = transformers
self.is_init_transformers = is_init_transformers
def init(self, conf): def init(self, conf):
# type: (ConfigTree) -> None # type: (ConfigTree) -> None
pass if self.is_init_transformers:
for transformer in self.transformers:
transformer.init(Scoped.get_scoped_conf(conf, transformer.get_scope()))
def transform(self, record): def transform(self, record):
# type: (Any) -> Any # type: (Any) -> Any
...@@ -62,7 +67,7 @@ class ChainedTransformer(Transformer): ...@@ -62,7 +67,7 @@ class ChainedTransformer(Transformer):
def get_scope(self): def get_scope(self):
# type: () -> str # type: () -> str
pass return 'transformer.chained'
def close(self): def close(self):
# type: () -> None # type: () -> None
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
from setuptools import setup, find_packages from setuptools import setup, find_packages
__version__ = '2.5.3' __version__ = '2.5.4'
requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'requirements.txt') requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'requirements.txt')
with open(requirements_path) as requirements_file: with open(requirements_path) as requirements_file:
......
...@@ -8,23 +8,40 @@ from databuilder.extractor.generic_extractor import GenericExtractor ...@@ -8,23 +8,40 @@ from databuilder.extractor.generic_extractor import GenericExtractor
class TestGenericExtractor(unittest.TestCase): class TestGenericExtractor(unittest.TestCase):
def setUp(self): def test_extraction_with_model_class(self):
# type: () -> None # type: () -> None
"""
Test Extraction using model class
"""
config_dict = { config_dict = {
'extractor.generic.extraction_items': [{'timestamp': 10000000}], 'extractor.generic.extraction_items': [{'timestamp': 10000000}],
'extractor.generic.model_class': 'extractor.generic.model_class':
'databuilder.models.neo4j_es_last_updated.Neo4jESLastUpdated', 'databuilder.models.neo4j_es_last_updated.Neo4jESLastUpdated',
} }
conf = ConfigFactory.from_dict(config_dict)
extractor = GenericExtractor()
self.conf = ConfigFactory.from_dict(config_dict) self.conf = ConfigFactory.from_dict(config_dict)
extractor.init(Scoped.get_scoped_conf(conf=conf,
scope=extractor.get_scope()))
def test_extraction_with_model_class(self): result = extractor.extract()
self.assertEquals(result.timestamp, 10000000)
def test_extraction_without_model_class(self):
# type: () -> None # type: () -> None
""" """
Test Extraction using model class Test Extraction using model class
""" """
config_dict = {
'extractor.generic.extraction_items': [{'foo': 1}, {'bar': 2}],
}
conf = ConfigFactory.from_dict(config_dict)
extractor = GenericExtractor() extractor = GenericExtractor()
extractor.init(Scoped.get_scoped_conf(conf=self.conf, self.conf = ConfigFactory.from_dict(config_dict)
extractor.init(Scoped.get_scoped_conf(conf=conf,
scope=extractor.get_scope())) scope=extractor.get_scope()))
result = extractor.extract() self.assertEquals(extractor.extract(), {'foo': 1})
self.assertEquals(result.timestamp, 10000000) self.assertEquals(extractor.extract(), {'bar': 2})
import unittest
from databuilder.models.schema.schema import SchemaModel
class TestSchemaDescription(unittest.TestCase):
def test_create_nodes(self):
# type: () -> None
schema = SchemaModel(schema_key='db://cluster.schema',
schema='schema_name',
description='foo')
self.assertDictEqual(schema.create_next_node(),
{'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'})
self.assertDictEqual(schema.create_next_node(),
{'description_source': 'description', 'description': 'foo',
'KEY': 'db://cluster.schema/_description', 'LABEL': 'Description'})
self.assertIsNone(schema.create_next_node())
def test_create_nodes_no_description(self):
# type: () -> None
schema = SchemaModel(schema_key='db://cluster.schema',
schema='schema_name')
self.assertDictEqual(schema.create_next_node(),
{'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'})
self.assertIsNone(schema.create_next_node())
def test_create_nodes_programmatic_description(self):
# type: () -> None
schema = SchemaModel(schema_key='db://cluster.schema',
schema='schema_name',
description='foo',
description_source='bar')
self.assertDictEqual(schema.create_next_node(),
{'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'})
self.assertDictEqual(schema.create_next_node(),
{'description_source': 'bar', 'description': 'foo',
'KEY': 'db://cluster.schema/_bar_description', 'LABEL': 'Programmatic_Description'})
self.assertIsNone(schema.create_next_node())
def test_create_relation(self):
# type: () -> None
schema = SchemaModel(schema_key='db://cluster.schema',
schema='schema_name',
description='foo')
actual = schema.create_next_relation()
expected = {'END_KEY': 'db://cluster.schema/_description', 'START_LABEL': 'Schema', 'END_LABEL': 'Description',
'START_KEY': 'db://cluster.schema', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'}
self.assertEqual(expected, actual)
self.assertIsNone(schema.create_next_relation())
def test_create_relation_no_description(self):
# type: () -> None
schema = SchemaModel(schema_key='db://cluster.schema',
schema='schema_name')
self.assertIsNone(schema.create_next_relation())
def test_create_relation_programmatic_description(self):
# type: () -> None
schema = SchemaModel(schema_key='db://cluster.schema',
schema='schema_name',
description='foo',
description_source='bar')
actual = schema.create_next_relation()
expected = {
'END_KEY': 'db://cluster.schema/_bar_description', 'START_LABEL': 'Schema',
'END_LABEL': 'Programmatic_Description', 'START_KEY': 'db://cluster.schema', 'TYPE': 'DESCRIPTION',
'REVERSE_TYPE': 'DESCRIPTION_OF'
}
self.assertEqual(expected, actual)
self.assertIsNone(schema.create_next_relation())
import unittest
from mock import MagicMock
from pyhocon import ConfigFactory
from databuilder.transformer.base_transformer import ChainedTransformer
class TestChainedTransformer(unittest.TestCase):
def test_init_not_called(self):
# type: () -> None
mock_transformer1 = MagicMock()
mock_transformer2 = MagicMock()
chained_transformer = ChainedTransformer(transformers=[mock_transformer1, mock_transformer2])
config = ConfigFactory.from_dict({})
chained_transformer.init(conf=config)
chained_transformer.transform(
{
'foo': 'bar'
}
)
mock_transformer1.init.assert_not_called()
mock_transformer1.transform.assert_called_once()
mock_transformer2.init.assert_not_called()
mock_transformer2.transform.assert_called_once()
def test_init_called(self):
# type: () -> None
mock_transformer1 = MagicMock()
mock_transformer1.get_scope.return_value = 'foo'
mock_transformer2 = MagicMock()
mock_transformer2.get_scope.return_value = 'bar'
chained_transformer = ChainedTransformer(transformers=[mock_transformer1, mock_transformer2],
is_init_transformers=True)
config = ConfigFactory.from_dict({})
chained_transformer.init(conf=config)
chained_transformer.transform(
{
'foo': 'bar'
}
)
mock_transformer1.init.assert_called_once()
mock_transformer1.transform.assert_called_once()
mock_transformer2.init.assert_called_once()
mock_transformer2.transform.assert_called_once()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment