Unverified Commit bb53d88a authored by James Davidheiser's avatar James Davidheiser Committed by GitHub

fix: Set snowflake extractor database to be consistent with other extractors (#283)

* Update sample_data_loader.py

* Update sample_data_loader.py

* Update sample_data_loader.py

* Set snowflake extractor database to be consistent with other extractors

This is a proposed fix for the bug described in https://github.com/lyft/amundsen/issues/494 - it adds a new configuration key, SNOWFLAKE_DATABASE_KEY, and uses it to set the database that metadata should be extracted from.  The DATABASE_KEY reverts back to simply describing the database, with a default of 'snowflake'.

* .

* .

* tests

* lint

* lint

* update env vars
parent 55aeaac8
...@@ -302,7 +302,7 @@ The SQL query driving the extraction is defined [here](https://github.com/lyft/a ...@@ -302,7 +302,7 @@ The SQL query driving the extraction is defined [here](https://github.com/lyft/a
```python ```python
job_config = ConfigFactory.from_dict({ job_config = ConfigFactory.from_dict({
'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.DATABASE_KEY): 'YourDbName', 'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.SNOWFLAKE_DATABASE_KEY): 'YourDbName',
'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY): where_clause_suffix, 'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY): where_clause_suffix,
'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME): True, 'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME): True,
'extractor.postgres_metadata.extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING): connection_string()}) 'extractor.postgres_metadata.extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING): connection_string()})
......
import logging import logging
import six import six
from collections import namedtuple from collections import namedtuple
...@@ -51,7 +52,10 @@ class SnowflakeMetadataExtractor(Extractor): ...@@ -51,7 +52,10 @@ class SnowflakeMetadataExtractor(Extractor):
WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix' WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix'
CLUSTER_KEY = 'cluster_key' CLUSTER_KEY = 'cluster_key'
USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name' USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name'
# Database Key, used to identify the database type in the UI.
DATABASE_KEY = 'database_key' DATABASE_KEY = 'database_key'
# Snowflake Database Key, used to determine which Snowflake database to connect to.
SNOWFLAKE_DATABASE_KEY = 'snowflake_database'
# Default values # Default values
DEFAULT_CLUSTER_NAME = 'master' DEFAULT_CLUSTER_NAME = 'master'
...@@ -60,7 +64,8 @@ class SnowflakeMetadataExtractor(Extractor): ...@@ -60,7 +64,8 @@ class SnowflakeMetadataExtractor(Extractor):
{WHERE_CLAUSE_SUFFIX_KEY: ' ', {WHERE_CLAUSE_SUFFIX_KEY: ' ',
CLUSTER_KEY: DEFAULT_CLUSTER_NAME, CLUSTER_KEY: DEFAULT_CLUSTER_NAME,
USE_CATALOG_AS_CLUSTER_NAME: True, USE_CATALOG_AS_CLUSTER_NAME: True,
DATABASE_KEY: 'prod'} DATABASE_KEY: 'snowflake',
SNOWFLAKE_DATABASE_KEY: 'prod'}
) )
def init(self, conf): def init(self, conf):
...@@ -74,13 +79,16 @@ class SnowflakeMetadataExtractor(Extractor): ...@@ -74,13 +79,16 @@ class SnowflakeMetadataExtractor(Extractor):
cluster_source = "'{}'".format(self._cluster) cluster_source = "'{}'".format(self._cluster)
self._database = conf.get_string(SnowflakeMetadataExtractor.DATABASE_KEY) self._database = conf.get_string(SnowflakeMetadataExtractor.DATABASE_KEY)
self._snowflake_database = conf.get_string(SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY)
if six.PY2: if six.PY2:
self._database = self._database.encode('utf-8', 'ignore') self._database = self._database.encode('utf-8', 'ignore')
self._snowflake_database = self._snowflake_database.encode('utf-8', 'ignore')
self.sql_stmt = SnowflakeMetadataExtractor.SQL_STATEMENT.format( self.sql_stmt = SnowflakeMetadataExtractor.SQL_STATEMENT.format(
where_clause_suffix=conf.get_string(SnowflakeMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY), where_clause_suffix=conf.get_string(SnowflakeMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY),
cluster_source=cluster_source, cluster_source=cluster_source,
database=self._database database=self._snowflake_database
) )
LOGGER.info('SQL for snowflake metadata: {}'.format(self.sql_stmt)) LOGGER.info('SQL for snowflake metadata: {}'.format(self.sql_stmt))
......
...@@ -18,6 +18,7 @@ https://github.com/lyft/amundsendatabuilder#list-of-extractors ...@@ -18,6 +18,7 @@ https://github.com/lyft/amundsendatabuilder#list-of-extractors
""" """
import logging import logging
import os
import sqlite3 import sqlite3
import sys import sys
import textwrap import textwrap
...@@ -37,22 +38,25 @@ from databuilder.publisher.neo4j_csv_publisher import Neo4jCsvPublisher ...@@ -37,22 +38,25 @@ from databuilder.publisher.neo4j_csv_publisher import Neo4jCsvPublisher
from databuilder.task.task import DefaultTask from databuilder.task.task import DefaultTask
from databuilder.transformer.base_transformer import NoopTransformer from databuilder.transformer.base_transformer import NoopTransformer
es_host = None es_host = os.getenv('CREDENTIALS_ELASTICSEARCH_PROXY_HOST', 'localhost')
neo_host = None neo_host = os.getenv('CREDENTIALS_NEO4J_PROXY_HOST', 'localhost')
es_port = os.getenv('CREDENTIALS_ELASTICSEARCH_PROXY_PORT', 9200)
neo_port = os.getenv('CREDENTIALS_NEO4J_PROXY_PORT', 7687)
if len(sys.argv) > 1: if len(sys.argv) > 1:
es_host = sys.argv[1] es_host = sys.argv[1]
if len(sys.argv) > 2: if len(sys.argv) > 2:
neo_host = sys.argv[2] neo_host = sys.argv[2]
es = Elasticsearch([ es = Elasticsearch([
{'host': es_host if es_host else 'localhost'}, {'host': es_host, 'port': es_port},
]) ])
DB_FILE = '/tmp/test.db' DB_FILE = '/tmp/test.db'
SQLITE_CONN_STRING = 'sqlite:////tmp/test.db' SQLITE_CONN_STRING = 'sqlite:////tmp/test.db'
Base = declarative_base() Base = declarative_base()
NEO4J_ENDPOINT = 'bolt://{}:7687'.format(neo_host if neo_host else 'localhost') NEO4J_ENDPOINT = 'bolt://{}:{}'.format(neo_host, neo_port)
neo4j_endpoint = NEO4J_ENDPOINT neo4j_endpoint = NEO4J_ENDPOINT
......
...@@ -79,7 +79,7 @@ def create_sample_snowflake_job(): ...@@ -79,7 +79,7 @@ def create_sample_snowflake_job():
job_config = ConfigFactory.from_dict({ job_config = ConfigFactory.from_dict({
'extractor.snowflake.extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING): connection_string(), 'extractor.snowflake.extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING): connection_string(),
'extractor.snowflake.{}'.format(SnowflakeMetadataExtractor.DATABASE_KEY): SNOWFLAKE_DATABASE_KEY, 'extractor.snowflake.{}'.format(SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY): SNOWFLAKE_DATABASE_KEY,
'extractor.snowflake.{}'.format(SnowflakeMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY): where_clause, 'extractor.snowflake.{}'.format(SnowflakeMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY): where_clause,
'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.NODE_DIR_PATH): node_files_folder, 'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.NODE_DIR_PATH): node_files_folder,
'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.RELATION_DIR_PATH): relationship_files_folder, 'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.RELATION_DIR_PATH): relationship_files_folder,
......
...@@ -22,7 +22,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase): ...@@ -22,7 +22,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
'MY_CLUSTER', 'MY_CLUSTER',
'extractor.snowflake_metadata.{}'.format(SnowflakeMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME): 'extractor.snowflake_metadata.{}'.format(SnowflakeMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME):
False, False,
'extractor.snowflake_metadata.{}'.format(SnowflakeMetadataExtractor.DATABASE_KEY): 'extractor.snowflake_metadata.{}'.format(SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY):
'prod' 'prod'
} }
self.conf = ConfigFactory.from_dict(config_dict) self.conf = ConfigFactory.from_dict(config_dict)
...@@ -90,7 +90,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase): ...@@ -90,7 +90,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
extractor = SnowflakeMetadataExtractor() extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf) extractor.init(self.conf)
actual = extractor.extract() actual = extractor.extract()
expected = TableMetadata('prod', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing', expected = TableMetadata('snowflake', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing',
[ColumnMetadata('col_id1', 'description of id1', 'number', 0), [ColumnMetadata('col_id1', 'description of id1', 'number', 0),
ColumnMetadata('col_id2', 'description of id2', 'number', 1), ColumnMetadata('col_id2', 'description of id2', 'number', 1),
ColumnMetadata('is_active', None, 'boolean', 2), ColumnMetadata('is_active', None, 'boolean', 2),
...@@ -189,7 +189,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase): ...@@ -189,7 +189,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
extractor = SnowflakeMetadataExtractor() extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf) extractor.init(self.conf)
expected = TableMetadata('prod', expected = TableMetadata('snowflake',
self.conf['extractor.snowflake_metadata.{}'.format( self.conf['extractor.snowflake_metadata.{}'.format(
SnowflakeMetadataExtractor.CLUSTER_KEY)], SnowflakeMetadataExtractor.CLUSTER_KEY)],
'test_schema1', 'test_table1', 'test table 1', 'test_schema1', 'test_table1', 'test table 1',
...@@ -202,7 +202,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase): ...@@ -202,7 +202,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
ColumnMetadata('ds', None, 'varchar', 5)]) ColumnMetadata('ds', None, 'varchar', 5)])
self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) self.assertEqual(expected.__repr__(), extractor.extract().__repr__())
expected = TableMetadata('prod', expected = TableMetadata('snowflake',
self.conf['extractor.snowflake_metadata.{}'.format( self.conf['extractor.snowflake_metadata.{}'.format(
SnowflakeMetadataExtractor.CLUSTER_KEY)], SnowflakeMetadataExtractor.CLUSTER_KEY)],
'test_schema1', 'test_table2', 'test table 2', 'test_schema1', 'test_table2', 'test table 2',
...@@ -210,7 +210,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase): ...@@ -210,7 +210,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)]) ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)])
self.assertEqual(expected.__repr__(), extractor.extract().__repr__()) self.assertEqual(expected.__repr__(), extractor.extract().__repr__())
expected = TableMetadata('prod', expected = TableMetadata('snowflake',
self.conf['extractor.snowflake_metadata.{}'.format( self.conf['extractor.snowflake_metadata.{}'.format(
SnowflakeMetadataExtractor.CLUSTER_KEY)], SnowflakeMetadataExtractor.CLUSTER_KEY)],
'test_schema2', 'test_table3', 'test table 3', 'test_schema2', 'test_table3', 'test table 3',
...@@ -281,12 +281,37 @@ class TestSnowflakeMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase): ...@@ -281,12 +281,37 @@ class TestSnowflakeMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase):
self.assertTrue(self.cluster_key in extractor.sql_stmt) self.assertTrue(self.cluster_key in extractor.sql_stmt)
class TestSnowflakeMetadataExtractorDefaultSnowflakeDatabaseKey(unittest.TestCase):
# test when SNOWFLAKE_DATABASE_KEY is specified
def setUp(self):
# type: () -> None
logging.basicConfig(level=logging.INFO)
self.snowflake_database_key = "not_prod"
config_dict = {
SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY: self.snowflake_database_key,
'extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING):
'TEST_CONNECTION'
}
self.conf = ConfigFactory.from_dict(config_dict)
def test_sql_statement(self):
# type: () -> None
"""
Test Extraction with empty result from query
"""
with patch.object(SQLAlchemyExtractor, '_get_connection'):
extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf)
self.assertTrue(self.snowflake_database_key in extractor.sql_stmt)
class TestSnowflakeMetadataExtractorDefaultDatabaseKey(unittest.TestCase): class TestSnowflakeMetadataExtractorDefaultDatabaseKey(unittest.TestCase):
# test when DATABASE_KEY is specified # test when DATABASE_KEY is specified
def setUp(self): def setUp(self):
# type: () -> None # type: () -> None
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
self.database_key = "not_prod" self.database_key = 'not_snowflake'
config_dict = { config_dict = {
SnowflakeMetadataExtractor.DATABASE_KEY: self.database_key, SnowflakeMetadataExtractor.DATABASE_KEY: self.database_key,
...@@ -303,7 +328,38 @@ class TestSnowflakeMetadataExtractorDefaultDatabaseKey(unittest.TestCase): ...@@ -303,7 +328,38 @@ class TestSnowflakeMetadataExtractorDefaultDatabaseKey(unittest.TestCase):
with patch.object(SQLAlchemyExtractor, '_get_connection'): with patch.object(SQLAlchemyExtractor, '_get_connection'):
extractor = SnowflakeMetadataExtractor() extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf) extractor.init(self.conf)
self.assertTrue(self.database_key in extractor.sql_stmt) self.assertFalse(self.database_key in extractor.sql_stmt)
def test_extraction_with_database_specified(self):
# type: () -> None
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection:
connection = MagicMock()
mock_connection.return_value = connection
sql_execute = MagicMock()
connection.execute = sql_execute
sql_execute.return_value = [
{'schema': 'test_schema',
'name': 'test_table',
'description': 'a table for testing',
'cluster': 'MY_CLUSTER',
'is_view': 'false',
'col_name': 'ds',
'col_type': 'varchar',
'col_description': None,
'col_sort_order': 0}
]
extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf)
actual = extractor.extract()
expected = TableMetadata(
self.database_key, 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing',
[ColumnMetadata('ds', None, 'varchar', 0)]
)
self.assertEqual(expected.__repr__(), actual.__repr__())
self.assertIsNone(extractor.extract())
class TestSnowflakeMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase): class TestSnowflakeMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment