diff --git a/CHANGELOG.md b/CHANGELOG.md index 88ae7ea7b..0ba5552b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ - #1413 Fix for null tests 13 and 23 of windowFunctionTest - #1416 Fix full join when both tables contains nulls - #1423 Fix temporary directory for hive partition test +- #1422 Fix file type prediction when a directory is provided as input ## Deprecated Features diff --git a/pyblazing/pyblazing/apiv2/context.py b/pyblazing/pyblazing/apiv2/context.py index a4c80eb4d..618728102 100644 --- a/pyblazing/pyblazing/apiv2/context.py +++ b/pyblazing/pyblazing/apiv2/context.py @@ -2270,15 +2270,6 @@ def create_table(self, table_name, input, **kwargs): # /path/to/data/folder/ -> name_file = /path/to/data/folder/, extension = '' name_file, extension = os.path.splitext(input[0]) - if not recognized_extension(extension) and file_format_hint == "undefined": - raise Exception( - "ERROR: Your input file doesn't have a recognized extension, " - + "you have to specify the `file_format` parameter. " - + "Recognized extensions are: [orc, parquet, csv, json, psv]." - + "\nFor example if you are using a *.log file you must pass file_format='csv' " - + "with all the needed extra parameters. See https://docs.blazingdb.com/docs/creating-tables" - ) - if ( file_format_hint == "undefined" and extension == "" @@ -2296,6 +2287,28 @@ def create_table(self, table_name, input, **kwargs): kwargs["names"].pop(id) kwargs["dtype"].pop(id) + # if the input is a directory and files do not have extension we want to raise an error + if name_file[-1] == "/" and extension == "": + all_files = os.listdir(name_file) + + if len(all_files) == 0: + raise Exception( + "ERROR: You need to ensure the current directory is not empty." 
+ ) + + first_file, extension_file = os.path.splitext(all_files[0]) + if ( + not recognized_extension(extension_file) + and file_format_hint == "undefined" + ): + raise Exception( + "ERROR: Your input file doesn't have a recognized extension, " + + "you have to specify the `file_format` parameter. " + + "Recognized extensions are: [orc, parquet, csv, json, psv]." + + "\nFor example if you are using a *.log file you must pass file_format='csv' " + + "with all the needed extra parameters. See https://docs.blazingdb.com/docs/creating-tables" + ) + parsedSchema, parsed_mapping_files = self._parseSchema( input, file_format_hint,