[28]:
import os
import site
import sys

# Report where this notebook is running (working directory, interpreter,
# site-packages), then make a "site-packages" folder inside the working
# directory importable by appending it to the module search path.
cwd = os.getcwd()

print(f"Current directory: {cwd}")
print(f"Current Python version: {'.'.join(map(str, sys.version_info[:3]))}")
print(f"Current Python interpreter: {sys.executable}")
print(f"Current site-packages: {site.getsitepackages()}")

local_site_packages_dir = os.path.join(cwd, "site-packages")
sys.path.append(local_site_packages_dir)
del local_site_packages_dir
Current directory: /home/jovyan/docs/source/04-Typing
Current Python version: 3.10.5
Current Python interpreter: /opt/conda/bin/python
Current site-packages: ['/opt/conda/lib/python3.10/site-packages']
[29]:
# First, get a Spark session (creating one if none exists yet); every
# later cell in this notebook uses `spark`.
from pyspark.sql import SparkSession

spark = (
    SparkSession
    .builder
    .getOrCreate()
)
spark
[29]:

SparkSession - in-memory

SparkContext

Spark UI

Version
v3.3.0
Master
local[*]
AppName
pyspark-shell

Typing in PySpark

[30]:
import json
from pathlib import Path
import pyspark.sql.functions as F
import pyspark.sql.types as T

Mixed Types in One Column

PySpark 的 DataFrame 是 Schema Enforced 的. 也就是说同一个 Column 中所有值的 Type 必须一致. 如果你是直接用 Python 对象创建 DataFrame (`spark.createDataFrame`), 而同一个 Column 里又传入了不同的 Type, 那么 Spark 在推断 Schema 时就会直接抛出 `TypeError`, 不让你创建.

[31]:
# Building a DataFrame directly from Python objects runs schema inference
# over every row. Here `col` holds both ints (LongType) and strs
# (StringType), so inference cannot merge the two and raises TypeError.
# That error IS the demonstration, so it is caught and shown instead of
# aborting "Restart & Run All" for the remaining cells.
try:
    pdf = spark.createDataFrame(
        [(1,), (2,), (3,), (4,), (5,), ("a",), ("b",)],
        ("col",),
    )
except TypeError as e:
    print(f"{type(e).__name__}: {e}")
else:
    # Only reached if the DataFrame was actually created. Keeping show()
    # out of the try block avoids masking unrelated TypeErrors.
    pdf.show()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [31], in <cell line: 1>()
----> 1 pdf = spark.createDataFrame(
      2     [
      3         (1, ),
      4         (2, ),
      5         (3, ),
      6         (4, ),
      7         (5, ),
      8         ("a", ),
      9         ("b", ),
     10     ],
     11     ("col",)
     12 )
     13 pdf.show()

File /usr/local/spark/python/pyspark/sql/session.py:894, in SparkSession.createDataFrame(self, data, schema, samplingRatio, verifySchema)
    889 if has_pandas and isinstance(data, pandas.DataFrame):
    890     # Create a DataFrame from pandas DataFrame.
    891     return super(SparkSession, self).createDataFrame(  # type: ignore[call-overload]
    892         data, schema, samplingRatio, verifySchema
    893     )
--> 894 return self._create_dataframe(
    895     data, schema, samplingRatio, verifySchema  # type: ignore[arg-type]
    896 )

File /usr/local/spark/python/pyspark/sql/session.py:936, in SparkSession._create_dataframe(self, data, schema, samplingRatio, verifySchema)
    934     rdd, struct = self._createFromRDD(data.map(prepare), schema, samplingRatio)
    935 else:
--> 936     rdd, struct = self._createFromLocal(map(prepare, data), schema)
    937 assert self._jvm is not None
    938 jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())

File /usr/local/spark/python/pyspark/sql/session.py:631, in SparkSession._createFromLocal(self, data, schema)
    628     data = list(data)
    630 if schema is None or isinstance(schema, (list, tuple)):
--> 631     struct = self._inferSchemaFromList(data, names=schema)
    632     converter = _create_converter(struct)
    633     tupled_data: Iterable[Tuple] = map(converter, data)

File /usr/local/spark/python/pyspark/sql/session.py:517, in SparkSession._inferSchemaFromList(self, data, names)
    515 infer_dict_as_struct = self._jconf.inferDictAsStruct()
    516 prefer_timestamp_ntz = is_timestamp_ntz_preferred()
--> 517 schema = reduce(
    518     _merge_type,
    519     (_infer_schema(row, names, infer_dict_as_struct, prefer_timestamp_ntz) for row in data),
    520 )
    521 if _has_nulltype(schema):
    522     raise ValueError("Some of types cannot be determined after inferring")

File /usr/local/spark/python/pyspark/sql/types.py:1383, in _merge_type(a, b, name)
   1381 if isinstance(a, StructType):
   1382     nfs = dict((f.name, f.dataType) for f in cast(StructType, b).fields)
-> 1383     fields = [
   1384         StructField(
   1385             f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), name=new_name(f.name))
   1386         )
   1387         for f in a.fields
   1388     ]
   1389     names = set([f.name for f in fields])
   1390     for n in nfs:

File /usr/local/spark/python/pyspark/sql/types.py:1385, in <listcomp>(.0)
   1381 if isinstance(a, StructType):
   1382     nfs = dict((f.name, f.dataType) for f in cast(StructType, b).fields)
   1383     fields = [
   1384         StructField(
-> 1385             f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), name=new_name(f.name))
   1386         )
   1387         for f in a.fields
   1388     ]
   1389     names = set([f.name for f in fields])
   1390     for n in nfs:

File /usr/local/spark/python/pyspark/sql/types.py:1378, in _merge_type(a, b, name)
   1375     return b
   1376 elif type(a) is not type(b):
   1377     # TODO: type cast (such as int -> long)
-> 1378     raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))
   1380 # same type
   1381 if isinstance(a, StructType):

TypeError: field col: Can not merge type <class 'pyspark.sql.types.LongType'> and <class 'pyspark.sql.types.StringType'>

如果你是通过 IO 从文件中读取数据, 而存在 Type 不匹配的情况, 那么 Spark 会尝试将所有数据 Convert 成同一个 Type, 通常是字符串. 如果实在不行则报错. 这个机制也有副作用, 当你期望的数值是 “非字符串” 时, 你往往想要把错误的 “字符串” 所在的行丢弃, 而 Spark 默认的行为则是全部转化成 “字符串”.

[ ]:
# Write the same mixed-type values to a JSON Lines file (one JSON object
# per line). The data then goes through Spark's file reader, which infers
# the schema from the file, instead of through createDataFrame.
p_json = Path(cwd).joinpath("tmp.json")

records = [{"col": value} for value in (1, 2, 3, 4, 5, "a", "b")]
content = "\n".join(map(json.dumps, records))
p_json.write_text(content)

# No schema is given, so Spark infers the type of `col` from the file.
pdf = spark.read.json(f"{p_json}")
pdf.show()
[ ]:
# Print the schema Spark inferred for the mixed-type `col` column.
pdf.printSchema()
[ ]:
# Pull every row back to the driver as a list of Row objects. This is
# acceptable here because the source file holds only a handful of records.
pdf.collect()

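为了解决 “错误的数据类型会连累正确的数据类型被强制转化成字符串” 的问题, 你可以在 IO 的时候手动指定 schema, 并允许 nullable = True. 这样 Spark 就不再根据数据去推断类型. 在默认的 PERMISSIVE 解析模式下, 凡是无法解析成指定类型的值都会被自动设为 Null, 而其余正确的值保持指定的类型. 手动指定 schema 的 API 可以参考 PySpark 官方文档中的 `pyspark.sql.types.StructType`, 以及 `DataFrameReader.json` 的 `schema` 和 `mode` 参数.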

[ ]:
# Declare the expected schema up front: a single nullable integer column.
schema = T.StructType(
    [T.StructField("col", T.IntegerType(), nullable=True)]
)

# With an explicit schema Spark does not infer types from the data. In the
# PERMISSIVE parse mode (Spark's default, spelled out here), values that
# cannot be parsed as the declared type are read as null.
pdf = spark.read.json(f"{p_json}", schema=schema, mode="PERMISSIVE")

pdf.show()
[ ]: