fix: Convert Snowflake connect_args to a dict before passing to Ibis #1431

Open · wants to merge 2 commits into base: develop
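As the title says, the --connect-args value arrives as a string and is now parsed into a dict before it reaches Ibis. Below is a minimal illustrative sketch of that conversion (not part of the diff itself), using only the helper and the test value added in this change.

from data_validation.util import dvt_config_string_to_dict

# The string form stored for a connection (taken from the new unit test below).
connect_args_str = '{"private_key_file": "/dir/rsa_key.p8", "private_key_file_pwd": "p@1"}'

# snowflake_connect() now applies this conversion before calling ibis.snowflake.connect(),
# so Ibis receives a mapping rather than a raw string.
connect_args = dvt_config_string_to_dict(connect_args_str)
print(connect_args["private_key_file"])  # /dir/rsa_key.p8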
18 changes: 18 additions & 0 deletions data_validation/util.py
@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import re
import time

from data_validation import exceptions


def timed_call(log_txt, fn, *args, **kwargs):
t0 = time.time()
@@ -44,3 +47,18 @@ def split_not_in_quotes(
return [t for t in re.split(pattern, to_split) if t]
else:
return re.split(pattern, to_split)


def dvt_config_string_to_dict(config_string: str) -> dict:
"""Convert JSON in a string to a dict."""
if not config_string:
return None
if isinstance(config_string, dict):
return config_string
try:
param_dict = json.loads(config_string.replace("'", '"'))
return param_dict
except json.JSONDecodeError as exc:
raise exceptions.ValidationException(
f"Invalid JSON format in connection parameter dictionary string: {config_string}"
) from exc
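A short sketch of the new helper's contract, mirroring the parametrized unit test further down: JSON strings (double- or single-quoted) and ready-made dicts are accepted, an empty value yields None, and malformed input raises ValidationException. Illustrative only; it assumes the module paths shown in this diff.

from data_validation import exceptions
from data_validation.util import dvt_config_string_to_dict

assert dvt_config_string_to_dict('{"a": 123}') == {"a": 123}  # standard JSON string
assert dvt_config_string_to_dict("{'a': 123}") == {"a": 123}  # single quotes are normalised
assert dvt_config_string_to_dict({"a": 123}) == {"a": 123}    # a dict passes through unchanged
assert dvt_config_string_to_dict("") is None                  # empty input returns None

try:
    dvt_config_string_to_dict("not json")
except exceptions.ValidationException:
    pass  # malformed input surfaces as a DVT ValidationException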
74 changes: 73 additions & 1 deletion tests/unit/test_cli_tools.py
@@ -86,6 +86,37 @@
"https://mybq.p.googleapis.com",
]


SNOWFLAKE_CONNECTION_ARGS_DICT_STR = (
'{"private_key_file": "/dir/rsa_key.p8", "private_key_file_pwd": "p@1"}'
)
CLI_ADD_SNOWFLAKE_CONNECTION_ARGS = [
"connections",
"add",
"--connection-name",
"snowflake_conn",
"Snowflake",
"--user=dvtuserp8",
"--password=",
"--account=some-str",
"--database=pso_data_validator",
f"--connect-args={SNOWFLAKE_CONNECTION_ARGS_DICT_STR}",
]

TERADATA_CONNECTION_ARGS_DICT_STR = '{"a": "1", "b": 2}'
CLI_ADD_TERADATA_CONNECTION_ARGS = [
"connections",
"add",
"--connection-name",
"teradata_conn",
"Teradata",
"--host=host_name",
"--port=123",
"--user-name=dvt_user",
"--password=dvt_pass",
f"--json-params={TERADATA_CONNECTION_ARGS_DICT_STR}",
]

CLI_ADD_ORACLE_STD_CONNECTION_ARGS = [
"connections",
"add",
@@ -277,7 +308,7 @@ def test_create_bq_connection(caplog, fs):
bq_conn = cli_tools.get_connection(args.connection_name)
assert bq_conn["source_type"] == "BigQuery"

conn_from_file = cli_tools.get_connection("test_with_endpoint")
conn_from_file = cli_tools.get_connection(args.connection_name)
assert conn_from_file["api_endpoint"] == "https://mybq.p.googleapis.com"


@@ -300,6 +331,47 @@ def test_create_connections_oracle(mock_write_file):
cli_tools.store_connection(args.connection_name, conn)


def test_create_snowflake_connection(caplog, fs):
caplog.set_level(logging.INFO)
# Create Connection
parser = cli_tools.configure_arg_parser()
args = parser.parse_args(CLI_ADD_SNOWFLAKE_CONNECTION_ARGS)
conn = cli_tools.get_connection_config_from_args(args)
cli_tools.store_connection(args.connection_name, conn)

assert gcs_helper.WRITE_SUCCESS_STRING in caplog.records[0].msg

conn = cli_tools.get_connection(args.connection_name)
assert conn["source_type"] == "Snowflake"
assert conn["user"] == args.user
assert conn["password"] == args.password
assert conn["account"] == args.account

conn_from_file = cli_tools.get_connection(args.connection_name)
assert conn_from_file["connect_args"] == SNOWFLAKE_CONNECTION_ARGS_DICT_STR


def test_create_teradata_connection(caplog, fs):
caplog.set_level(logging.INFO)
# Create Connection
parser = cli_tools.configure_arg_parser()
args = parser.parse_args(CLI_ADD_TERADATA_CONNECTION_ARGS)
conn = cli_tools.get_connection_config_from_args(args)
cli_tools.store_connection(args.connection_name, conn)

assert gcs_helper.WRITE_SUCCESS_STRING in caplog.records[0].msg

conn = cli_tools.get_connection(args.connection_name)
assert conn["source_type"] == "Teradata"
assert conn["host"] == args.host
assert conn["port"] == args.port
assert conn["user_name"] == args.user_name
assert conn["password"] == args.password

conn_from_file = cli_tools.get_connection(args.connection_name)
assert conn_from_file["json_params"] == TERADATA_CONNECTION_ARGS_DICT_STR


def test_configure_arg_parser_list_and_run_validation_configs():
"""Test configuring arg parse in different ways."""
parser = cli_tools.configure_arg_parser()
15 changes: 15 additions & 0 deletions tests/unit/test_util.py
@@ -143,3 +143,18 @@ def test_split_not_in_quotes(
else:
result = module_under_test.split_not_in_quotes(test_input)
assert result == expected


@pytest.mark.parametrize(
"test_input,expected",
[
("", None),
(None, None),
({"a": 123}, {"a": 123}),
('{"a": 123}', {"a": 123}),
("{'a': 123}", {"a": 123}),
],
)
def test_dvt_config_string_to_dict(module_under_test, test_input: str, expected):
result = module_under_test.dvt_config_string_to_dict(test_input)
assert result == expected
5 changes: 4 additions & 1 deletion third_party/ibis/ibis_snowflake/api.py
@@ -22,6 +22,7 @@

logging.getLogger("snowflake.connector").setLevel(logging.WARNING)

from data_validation.util import dvt_config_string_to_dict
import third_party.ibis.ibis_snowflake.datatypes


@@ -30,8 +31,10 @@ def snowflake_connect(
password: str,
account: str,
database: str,
connect_args: Mapping[str, Any] = None,
connect_args: str = None,
):
if connect_args:
connect_args = dvt_config_string_to_dict(connect_args)
return ibis.snowflake.connect(
user=user,
password=password,
13 changes: 3 additions & 10 deletions third_party/ibis/ibis_teradata/__init__.py
@@ -14,13 +14,13 @@
import pandas
import warnings

import json
import teradatasql
import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from typing import Mapping, Any
import ibis.expr.schema as sch
from ibis.backends.base.sql import BaseSQLBackend

from third_party.ibis.ibis_teradata.compiler import TeradataCompiler
from third_party.ibis.ibis_teradata.datatypes import (
TeradataTypeTranslator,
@@ -41,7 +41,7 @@ def do_connect(
port: int = 1025,
logmech: str = "TD2",
use_no_lock_tables: str = "False",
json_params: str = None,
json_params: Mapping[str, Any] = None,
) -> None:
self.teradata_config = {
"host": host,
@@ -50,15 +50,8 @@
"dbs_port": port,
"logmech": logmech,
}

if json_params:
try:
param_dict = json.loads(json_params.replace("'", '"'))
self.teradata_config.update(param_dict)
except json.JSONDecodeError:
print(
f"Invalid JSON format in the parameter dictionary string: {json_params}"
)
self.teradata_config.update(json_params)

self.client = teradatasql.connect(**self.teradata_config)
self.con = self.client.cursor()
5 changes: 5 additions & 0 deletions third_party/ibis/ibis_teradata/api.py
@@ -14,6 +14,8 @@
from third_party.ibis.ibis_teradata import Backend as TeradataBackend
import teradatasql # NOQA fail early if the package is missing

from data_validation.util import dvt_config_string_to_dict


def teradata_connect(
host: str = "localhost",
@@ -25,6 +27,9 @@
json_params: str = None,
):
backend = TeradataBackend()
if json_params:
json_params = dvt_config_string_to_dict(json_params)

backend.do_connect(
host=host,
user_name=user_name,