[28]:
import os
import sys
import site
cwd = os.getcwd()
print(f"Current directory: {cwd}")
print(f"Current Python version: {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}")
print(f"Current Python interpreter: {sys.executable}")
print(f"Current site-packages: {site.getsitepackages()}")
# Make packages vendored in ./site-packages importable
sys.path.append(os.path.join(cwd, "site-packages"))
Current directory: /home/jovyan/docs/source/04-Typing
Current Python version: 3.10.5
Current Python interpreter: /opt/conda/bin/python
Current site-packages: ['/opt/conda/lib/python3.10/site-packages']
[29]:
# First, create a Spark Session
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
spark
[29]:
SparkSession - in-memory
Typing in PySpark
[30]:
import json
from pathlib import Path
import pyspark.sql.functions as F
import pyspark.sql.types as T
Mixed Types in One Column
PySpark DataFrames are schema enforced, which means that all values in a column must share the same type. If you create a DataFrame directly and pass in values of different types, Spark raises an error and refuses to create it.
[31]:
pdf = spark.createDataFrame(
    [
        (1, ),
        (2, ),
        (3, ),
        (4, ),
        (5, ),
        ("a", ),
        ("b", ),
    ],
    ("col",)
)
pdf.show()
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 pdf = spark.createDataFrame(
2 [
3 (1, ),
4 (2, ),
5 (3, ),
6 (4, ),
7 (5, ),
8 ("a", ),
9 ("b", ),
10 ],
11 ("col",)
12 )
13 pdf.show()
File /usr/local/spark/python/pyspark/sql/session.py:894, in SparkSession.createDataFrame(self, data, schema, samplingRatio, verifySchema)
889 if has_pandas and isinstance(data, pandas.DataFrame):
890 # Create a DataFrame from pandas DataFrame.
891 return super(SparkSession, self).createDataFrame( # type: ignore[call-overload]
892 data, schema, samplingRatio, verifySchema
893 )
--> 894 return self._create_dataframe(
895 data, schema, samplingRatio, verifySchema # type: ignore[arg-type]
896 )
File /usr/local/spark/python/pyspark/sql/session.py:936, in SparkSession._create_dataframe(self, data, schema, samplingRatio, verifySchema)
934 rdd, struct = self._createFromRDD(data.map(prepare), schema, samplingRatio)
935 else:
--> 936 rdd, struct = self._createFromLocal(map(prepare, data), schema)
937 assert self._jvm is not None
938 jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
File /usr/local/spark/python/pyspark/sql/session.py:631, in SparkSession._createFromLocal(self, data, schema)
628 data = list(data)
630 if schema is None or isinstance(schema, (list, tuple)):
--> 631 struct = self._inferSchemaFromList(data, names=schema)
632 converter = _create_converter(struct)
633 tupled_data: Iterable[Tuple] = map(converter, data)
File /usr/local/spark/python/pyspark/sql/session.py:517, in SparkSession._inferSchemaFromList(self, data, names)
515 infer_dict_as_struct = self._jconf.inferDictAsStruct()
516 prefer_timestamp_ntz = is_timestamp_ntz_preferred()
--> 517 schema = reduce(
518 _merge_type,
519 (_infer_schema(row, names, infer_dict_as_struct, prefer_timestamp_ntz) for row in data),
520 )
521 if _has_nulltype(schema):
522 raise ValueError("Some of types cannot be determined after inferring")
File /usr/local/spark/python/pyspark/sql/types.py:1383, in _merge_type(a, b, name)
1381 if isinstance(a, StructType):
1382 nfs = dict((f.name, f.dataType) for f in cast(StructType, b).fields)
-> 1383 fields = [
1384 StructField(
1385 f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), name=new_name(f.name))
1386 )
1387 for f in a.fields
1388 ]
1389 names = set([f.name for f in fields])
1390 for n in nfs:
File /usr/local/spark/python/pyspark/sql/types.py:1385, in <listcomp>(.0)
1381 if isinstance(a, StructType):
1382 nfs = dict((f.name, f.dataType) for f in cast(StructType, b).fields)
1383 fields = [
1384 StructField(
-> 1385 f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), name=new_name(f.name))
1386 )
1387 for f in a.fields
1388 ]
1389 names = set([f.name for f in fields])
1390 for n in nfs:
File /usr/local/spark/python/pyspark/sql/types.py:1378, in _merge_type(a, b, name)
1375 return b
1376 elif type(a) is not type(b):
1377 # TODO: type cast (such as int -> long)
-> 1378 raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))
1380 # same type
1381 if isinstance(a, StructType):
TypeError: field col: Can not merge type <class 'pyspark.sql.types.LongType'> and <class 'pyspark.sql.types.StringType'>
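If the mixed values are intentional and you really want them in one column, one way out is to normalize the types yourself before handing the data to Spark. Below is a minimal sketch, assuming you decide to treat every value as a string; the names raw_values, string_schema, and pdf_str are only illustrative.
[ ]:
# Sketch: normalize everything to str in Python first, so the column has a single type.
raw_values = [1, 2, 3, 4, 5, "a", "b"]  # hypothetical data, mirroring the failing example above
string_schema = T.StructType([
    T.StructField("col", T.StringType(), True),
])
pdf_str = spark.createDataFrame(
    [(str(v),) for v in raw_values],
    schema=string_schema,
)
pdf_str.show()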
If you instead read the data from a file via IO and the types do not match, Spark tries to convert all the data to a single type, usually string, and only raises an error when that is impossible. This mechanism has a side effect: when the values you expect are non-strings, you usually want to drop the rows containing the bad string values, but Spark's default behavior is to convert everything to strings instead.
[ ]:
p_json = Path(cwd) / "tmp.json"
records = [
    {"col": 1},
    {"col": 2},
    {"col": 3},
    {"col": 4},
    {"col": 5},
    {"col": "a"},
    {"col": "b"},
]
content = "\n".join([
    json.dumps(record)
    for record in records
])
p_json.write_text(content)
pdf = spark.read.json(
    f"{p_json}",
)
pdf.show()
[ ]:
pdf.printSchema()
[ ]:
pdf.collect()
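Because schema inference falls back to string here, a common follow-up is to cast the column back to the type you actually want: values that cannot be cast become null, and those rows can then be dropped. A minimal sketch, assuming the cells above ran and pdf has a string column col:
[ ]:
# Sketch: cast the inferred string column back to int; "a" and "b" become null, then drop them.
pdf_int = (
    pdf
    .withColumn("col", F.col("col").cast(T.IntegerType()))
    .filter(F.col("col").isNotNull())
)
pdf_int.show()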
To avoid the problem of wrong data types dragging the correct ones into being coerced to strings, you can manually specify the schema at read time and allow nullable = True; any row whose value does not match the declared type is then automatically set to null. For the API used to manually define a schema, see here.
[ ]:
# Define the schema manually; nullable=True lets mismatched rows become null
schema = T.StructType([
    T.StructField("col", T.IntegerType(), True),
])
pdf = spark.read.json(f"{p_json}", schema=schema)
pdf.show()
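With the explicit schema, the rows whose values could not be parsed come back as null. If you only want the valid rows, a short sketch like the following drops them:
[ ]:
# Sketch: keep only the rows where "col" parsed successfully.
pdf.dropna(subset=["col"]).show()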
[ ]: