Unverified Commit bb53d88a authored by James Davidheiser's avatar James Davidheiser Committed by GitHub

fix: Set snowflake extractor database to be consistent with other extractors (#283)

* Update sample_data_loader.py

* Update sample_data_loader.py

* Update sample_data_loader.py

* Set snowflake extractor database to be consistent with other extractors

This is a proposed fix for the bug described in https://github.com/lyft/amundsen/issues/494 - it adds a new configuration key, SNOWFLAKE_DATABASE_KEY, and uses it to set the database that metadata should be extracted from.  The DATABASE_KEY reverts back to simply describing the database, with a default of 'snowflake'.

* .

* .

* tests

* lint

* lint

* update env vars
parent 55aeaac8
......@@ -302,7 +302,7 @@ The SQL query driving the extraction is defined [here](https://github.com/lyft/a
```python
job_config = ConfigFactory.from_dict({
'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.DATABASE_KEY): 'YourDbName',
'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.SNOWFLAKE_DATABASE_KEY): 'YourDbName',
'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY): where_clause_suffix,
'extractor.postgres_metadata.{}'.format(PostgresMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME): True,
'extractor.postgres_metadata.extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING): connection_string()})
......
import logging
import six
from collections import namedtuple
......@@ -51,7 +52,10 @@ class SnowflakeMetadataExtractor(Extractor):
WHERE_CLAUSE_SUFFIX_KEY = 'where_clause_suffix'
CLUSTER_KEY = 'cluster_key'
USE_CATALOG_AS_CLUSTER_NAME = 'use_catalog_as_cluster_name'
# Database Key, used to identify the database type in the UI.
DATABASE_KEY = 'database_key'
# Snowflake Database Key, used to determine which Snowflake database to connect to.
SNOWFLAKE_DATABASE_KEY = 'snowflake_database'
# Default values
DEFAULT_CLUSTER_NAME = 'master'
......@@ -60,7 +64,8 @@ class SnowflakeMetadataExtractor(Extractor):
{WHERE_CLAUSE_SUFFIX_KEY: ' ',
CLUSTER_KEY: DEFAULT_CLUSTER_NAME,
USE_CATALOG_AS_CLUSTER_NAME: True,
DATABASE_KEY: 'prod'}
DATABASE_KEY: 'snowflake',
SNOWFLAKE_DATABASE_KEY: 'prod'}
)
def init(self, conf):
......@@ -74,13 +79,16 @@ class SnowflakeMetadataExtractor(Extractor):
cluster_source = "'{}'".format(self._cluster)
self._database = conf.get_string(SnowflakeMetadataExtractor.DATABASE_KEY)
self._snowflake_database = conf.get_string(SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY)
if six.PY2:
self._database = self._database.encode('utf-8', 'ignore')
self._snowflake_database = self._snowflake_database.encode('utf-8', 'ignore')
self.sql_stmt = SnowflakeMetadataExtractor.SQL_STATEMENT.format(
where_clause_suffix=conf.get_string(SnowflakeMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY),
cluster_source=cluster_source,
database=self._database
database=self._snowflake_database
)
LOGGER.info('SQL for snowflake metadata: {}'.format(self.sql_stmt))
......
......@@ -18,6 +18,7 @@ https://github.com/lyft/amundsendatabuilder#list-of-extractors
"""
import logging
import os
import sqlite3
import sys
import textwrap
......@@ -37,22 +38,25 @@ from databuilder.publisher.neo4j_csv_publisher import Neo4jCsvPublisher
from databuilder.task.task import DefaultTask
from databuilder.transformer.base_transformer import NoopTransformer
es_host = None
neo_host = None
es_host = os.getenv('CREDENTIALS_ELASTICSEARCH_PROXY_HOST', 'localhost')
neo_host = os.getenv('CREDENTIALS_NEO4J_PROXY_HOST', 'localhost')
es_port = os.getenv('CREDENTIALS_ELASTICSEARCH_PROXY_PORT', 9200)
neo_port = os.getenv('CREDENTIALS_NEO4J_PROXY_PORT', 7687)
if len(sys.argv) > 1:
es_host = sys.argv[1]
if len(sys.argv) > 2:
neo_host = sys.argv[2]
es = Elasticsearch([
{'host': es_host if es_host else 'localhost'},
{'host': es_host, 'port': es_port},
])
DB_FILE = '/tmp/test.db'
SQLITE_CONN_STRING = 'sqlite:////tmp/test.db'
Base = declarative_base()
NEO4J_ENDPOINT = 'bolt://{}:7687'.format(neo_host if neo_host else 'localhost')
NEO4J_ENDPOINT = 'bolt://{}:{}'.format(neo_host, neo_port)
neo4j_endpoint = NEO4J_ENDPOINT
......
......@@ -79,7 +79,7 @@ def create_sample_snowflake_job():
job_config = ConfigFactory.from_dict({
'extractor.snowflake.extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING): connection_string(),
'extractor.snowflake.{}'.format(SnowflakeMetadataExtractor.DATABASE_KEY): SNOWFLAKE_DATABASE_KEY,
'extractor.snowflake.{}'.format(SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY): SNOWFLAKE_DATABASE_KEY,
'extractor.snowflake.{}'.format(SnowflakeMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY): where_clause,
'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.NODE_DIR_PATH): node_files_folder,
'loader.filesystem_csv_neo4j.{}'.format(FsNeo4jCSVLoader.RELATION_DIR_PATH): relationship_files_folder,
......
......@@ -22,7 +22,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
'MY_CLUSTER',
'extractor.snowflake_metadata.{}'.format(SnowflakeMetadataExtractor.USE_CATALOG_AS_CLUSTER_NAME):
False,
'extractor.snowflake_metadata.{}'.format(SnowflakeMetadataExtractor.DATABASE_KEY):
'extractor.snowflake_metadata.{}'.format(SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY):
'prod'
}
self.conf = ConfigFactory.from_dict(config_dict)
......@@ -90,7 +90,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf)
actual = extractor.extract()
expected = TableMetadata('prod', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing',
expected = TableMetadata('snowflake', 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing',
[ColumnMetadata('col_id1', 'description of id1', 'number', 0),
ColumnMetadata('col_id2', 'description of id2', 'number', 1),
ColumnMetadata('is_active', None, 'boolean', 2),
......@@ -189,7 +189,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf)
expected = TableMetadata('prod',
expected = TableMetadata('snowflake',
self.conf['extractor.snowflake_metadata.{}'.format(
SnowflakeMetadataExtractor.CLUSTER_KEY)],
'test_schema1', 'test_table1', 'test table 1',
......@@ -202,7 +202,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
ColumnMetadata('ds', None, 'varchar', 5)])
self.assertEqual(expected.__repr__(), extractor.extract().__repr__())
expected = TableMetadata('prod',
expected = TableMetadata('snowflake',
self.conf['extractor.snowflake_metadata.{}'.format(
SnowflakeMetadataExtractor.CLUSTER_KEY)],
'test_schema1', 'test_table2', 'test table 2',
......@@ -210,7 +210,7 @@ class TestSnowflakeMetadataExtractor(unittest.TestCase):
ColumnMetadata('col_name2', 'description of col_name2', 'varchar', 1)])
self.assertEqual(expected.__repr__(), extractor.extract().__repr__())
expected = TableMetadata('prod',
expected = TableMetadata('snowflake',
self.conf['extractor.snowflake_metadata.{}'.format(
SnowflakeMetadataExtractor.CLUSTER_KEY)],
'test_schema2', 'test_table3', 'test table 3',
......@@ -281,12 +281,37 @@ class TestSnowflakeMetadataExtractorClusterKeyNoTableCatalog(unittest.TestCase):
self.assertTrue(self.cluster_key in extractor.sql_stmt)
class TestSnowflakeMetadataExtractorDefaultSnowflakeDatabaseKey(unittest.TestCase):
# test when SNOWFLAKE_DATABASE_KEY is specified
def setUp(self):
# type: () -> None
logging.basicConfig(level=logging.INFO)
self.snowflake_database_key = "not_prod"
config_dict = {
SnowflakeMetadataExtractor.SNOWFLAKE_DATABASE_KEY: self.snowflake_database_key,
'extractor.sqlalchemy.{}'.format(SQLAlchemyExtractor.CONN_STRING):
'TEST_CONNECTION'
}
self.conf = ConfigFactory.from_dict(config_dict)
def test_sql_statement(self):
# type: () -> None
"""
Test Extraction with empty result from query
"""
with patch.object(SQLAlchemyExtractor, '_get_connection'):
extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf)
self.assertTrue(self.snowflake_database_key in extractor.sql_stmt)
class TestSnowflakeMetadataExtractorDefaultDatabaseKey(unittest.TestCase):
# test when DATABASE_KEY is specified
def setUp(self):
# type: () -> None
logging.basicConfig(level=logging.INFO)
self.database_key = "not_prod"
self.database_key = 'not_snowflake'
config_dict = {
SnowflakeMetadataExtractor.DATABASE_KEY: self.database_key,
......@@ -303,7 +328,38 @@ class TestSnowflakeMetadataExtractorDefaultDatabaseKey(unittest.TestCase):
with patch.object(SQLAlchemyExtractor, '_get_connection'):
extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf)
self.assertTrue(self.database_key in extractor.sql_stmt)
self.assertFalse(self.database_key in extractor.sql_stmt)
def test_extraction_with_database_specified(self):
# type: () -> None
with patch.object(SQLAlchemyExtractor, '_get_connection') as mock_connection:
connection = MagicMock()
mock_connection.return_value = connection
sql_execute = MagicMock()
connection.execute = sql_execute
sql_execute.return_value = [
{'schema': 'test_schema',
'name': 'test_table',
'description': 'a table for testing',
'cluster': 'MY_CLUSTER',
'is_view': 'false',
'col_name': 'ds',
'col_type': 'varchar',
'col_description': None,
'col_sort_order': 0}
]
extractor = SnowflakeMetadataExtractor()
extractor.init(self.conf)
actual = extractor.extract()
expected = TableMetadata(
self.database_key, 'MY_CLUSTER', 'test_schema', 'test_table', 'a table for testing',
[ColumnMetadata('ds', None, 'varchar', 0)]
)
self.assertEqual(expected.__repr__(), actual.__repr__())
self.assertIsNone(extractor.extract())
class TestSnowflakeMetadataExtractorNoClusterKeyNoTableCatalog(unittest.TestCase):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment