PySpark cast from float to double is imprecise

Question:

I grouped by and summed a float column, and the result is not what I expected.

It happens not only when grouping, but also when I simply cast float to double.

Here is a code example:

>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import *
>>> schema = StructType([ 
...     StructField("firstname",StringType(),True), 
...     StructField("middlename",StringType(),True), 
...     StructField("v",FloatType(),True)])
>>>
>>> df = spark.createDataFrame([["a","b",1.12],["a","b",2.23],["a","c",7.78]],schema=schema)
>>> df.show()
+---------+----------+----+
|firstname|middlename|   v|
+---------+----------+----+
|        a|         b|1.12|
|        a|         b|2.23|
|        a|         c|7.78|
+---------+----------+----+
>>> df.groupBy("firstname","middlename").agg(sum("v")).show()
+---------+----------+-----------------+
|firstname|middlename|           sum(v)|
+---------+----------+-----------------+
|        a|         b|3.350000023841858|
|        a|         c| 7.78000020980835|
+---------+----------+-----------------+
>>> df.groupBy("firstname","middlename").agg(sum("v").cast("float")).show()
+---------+----------+---------------------+
|firstname|middlename|CAST(sum(v) AS FLOAT)|
+---------+----------+---------------------+
|        a|         b|                 3.35|
|        a|         c|                 7.78|
+---------+----------+---------------------+
>>> df.select(col("v"), col("v").cast("double")).show()
+----+------------------+
|   v|                 v|
+----+------------------+
|1.12|1.1200000047683716|
|2.23|2.2300000190734863|
|7.78|  7.78000020980835|
+----+------------------+

I think that’s because of the type precision (float is 4 bytes, double is 8 bytes), but it still looks like a bug to me, because the value of a float should be preserved when it is cast to double.
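
For what it's worth, the same widening can be reproduced in plain Python by round-tripping the value through a 32-bit float (standard library only, no Spark involved):

import struct

# pack 1.12 as a 32-bit float, then read it back as a Python float (a 64-bit double)
v32 = struct.unpack("f", struct.pack("f", 1.12))[0]
print(v32)  # 1.1200000047683716 -- the closest 32-bit float to 1.12, shown with double precision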

As shown above, casting back to float after the aggregation works, but I don't think that is a clean solution.

Is there any fancy solution for this?

Asked By: 에스파파


Answers:

I found an answer: cast column v to string, and then to double, before aggregating. Casting the float to a string gives its short decimal form (e.g. "1.12"), and casting that string to double parses that decimal directly, so the sum displays as expected.

Example:

>>> from pyspark.sql import functions as F
>>> (df.withColumn("v", F.col("v").cast("string").cast("double"))
...     .groupBy("firstname","middlename").agg(F.sum("v")).show())
+---------+----------+------+
|firstname|middlename|sum(v)|
+---------+----------+------+
|        a|         b|  3.35|
|        a|         c|  7.78|
+---------+----------+------+
>>> (df.withColumn("v", F.col("v").cast("string").cast("double"))
...     .groupBy("firstname","middlename").agg(F.sum("v")).printSchema())
root
 |-- firstname: string (nullable = true)
 |-- middlename: string (nullable = true)
 |-- sum(v): double (nullable = true)
Answered By: 에스파파
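
Another option along the same lines (a sketch; the decimal(10,2) precision and the two-decimal rounding are assumptions based on the sample data, not something from the thread): round the summed double back to the input's scale, or cast to a decimal type before aggregating so the sums are exact.

from pyspark.sql import functions as F

# Option 1: keep the float column and round the double sum back to the input's scale
df.groupBy("firstname", "middlename") \
    .agg(F.round(F.sum("v"), 2).alias("sum_v")) \
    .show()

# Option 2: cast to a decimal type before aggregating so the arithmetic is exact
df.withColumn("v", F.col("v").cast("decimal(10,2)")) \
    .groupBy("firstname", "middlename") \
    .agg(F.sum("v").alias("sum_v")) \
    .show()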