SKLearn Linear Regression on Grouped Pandas Dataframe without aggregation?

Question:

I'm trying to fit a linear regression within each group of a DataFrame and put the resulting coefficients on every row of that group, without performing an aggregation (the equivalent of a window function in SQL).

I’m banging my head against a wall here.

In a for loop this works just fine…

Input:

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
    
def ger(df):
   x=np.array(df['x']).reshape(-1,1)
   y=np.array(df['y']).reshape(-1,1)
   w = np.array(df['w'])
   r = LinearRegression().fit(x,y,w)
   df['a'] = r.coef_[0]
   df['b'] = r.intercept_[0]

g = ['a','a','a','b','b','b']
x = [1,2,3,2,4,6]
y = [1,2,3,4,8,12]
w = [1,1,1,1,1,1]
    
df = pd.DataFrame()
df['g'] = g
df['x'] = x
df['y'] = y
df['w'] = w

df.groupby(['g']).apply(lambda x : ger(x))

Expected Outcome:

df
       g  x   y  w  a  b
    0  a  1   1  1  1  0
    1  a  2   2  1  1  0
    2  a  3   3  1  1  0
    3  b  2   4  1  2  0
    4  b  4   8  1  2  0
    5  b  6  12  1  2  0

Actual Outcome:

Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/home/username/.local/lib/python3.10/site-packages/pandas/core/groupby/groupby.py", line 1567, in apply
        result = self._python_apply_general(f, self._selected_obj)
      File "/home/username/.local/lib/python3.10/site-packages/pandas/core/groupby/groupby.py", line 1629, in _python_apply_general
        values, mutated = self.grouper.apply(f, data, self.axis)
      File "/home/username/.local/lib/python3.10/site-packages/pandas/core/groupby/ops.py", line 839, in apply
        res = f(group)
      File "<stdin>", line 1, in <lambda>
      File "<stdin>", line 6, in ger
      File "/home/username/.local/lib/python3.10/site-packages/pandas/core/frame.py", line 3980, in __setitem__
        self._set_item(key, value)
      File "/home/username/.local/lib/python3.10/site-packages/pandas/core/frame.py", line 4174, in _set_item
        value = self._sanitize_column(value)
      File "/home/username/.local/lib/python3.10/site-packages/pandas/core/frame.py", line 4915, in _sanitize_column
        com.require_length_match(value, self.index)
      File "/home/username/.local/lib/python3.10/site-packages/pandas/core/common.py", line 571, in require_length_match
        raise ValueError(
    ValueError: Length of values (1) does not match length of index (3)

What am I doing wrong here? Any guidance would be greatly appreciated.

In R w/ tidyverse this would look like:

df %>%
    group_by(g) %>%
    mutate(a = lm_a(x, y, w),
           b = lm_b(x, y, w)) ->
    df

Where lm_a and lm_b return fit coefficients and intercepts respectively.

Asked By: MusLearning


Answers:

There are a few issues with the code:

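  1. r.coef_ is a 2D array here (because y was reshaped to 2D), so r.coef_[0] is still a length-1 array rather than a scalar. That is what triggers the length-mismatch error. You can either do:
    • df[['a']] = r.coef_[0] or
    • df['a'] = r.coef_[0][0]
  2. ger() needs to return df, otherwise apply gets nothing back from each group.
  3. pandas raises a FutureWarning for this groupby(...).apply call; passing group_keys=False keeps the current behaviour and silences it.
  4. Not really an issue, but .apply(ger) is enough here instead of .apply(lambda x : ger(x)).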
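
To make point 1 concrete, here is a small sketch that fits the same kind of model on the first group only and prints the shapes involved. It only assumes the x and y values of group a from the question; nothing else is new.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Same data as group 'a' in the question, with y reshaped to 2D as in ger()
x = np.array([1, 2, 3]).reshape(-1, 1)
y = np.array([1, 2, 3]).reshape(-1, 1)

r = LinearRegression().fit(x, y)

print(r.coef_.shape)       # (1, 1): one row per target, one column per feature
print(r.coef_[0].shape)    # (1,): still an array, so df['a'] = r.coef_[0] fails
print(r.intercept_.shape)  # (1,): r.intercept_[0] is a scalar, so df['b'] works
```

A length-1 array cannot be assigned to a column of a 3-row group, whereas a scalar is broadcast to every row.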

Below is the corrected code:

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
    
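def ger(df):
    # scikit-learn expects a 2D feature matrix
    x = np.array(df['x']).reshape(-1, 1)
    y = np.array(df['y']).reshape(-1, 1)
    w = np.array(df['w'])
    r = LinearRegression().fit(x, y, sample_weight=w)
    # r.coef_[0] and r.intercept_[0] are broadcast to every row of the group
    df[['a']] = r.coef_[0]
    df[['b']] = r.intercept_[0]
    # return the modified group so that apply can reassemble the frame
    return df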


g = ['a','a','a','b','b','b']
x = [1,2,3,2,4,6]
y = [1,2,3,4,8,12]
w = [1,1,1,1,1,1]
    
df = pd.DataFrame()
df['g'] = g
df['x'] = x
df['y'] = y
df['w'] = w

df = df.groupby('g', group_keys=False).apply(ger)

Output:

    g   x   y   w   a    b
0   a   1   1   1   1.0  4.440892e-16
1   a   2   2   1   1.0  4.440892e-16
2   a   3   3   1   1.0  4.440892e-16
3   b   2   4   1   2.0  1.776357e-15
4   b   4   8   1   2.0  1.776357e-15
5   b   6   12  1   2.0  1.776357e-15

Because of floating-point rounding, column b is not exactly 0. If needed, you can round it, for example to 10 decimal places: df[['b']] = round(r.intercept_[0], 10).
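
As an alternative that does not modify the group inside apply, one possible sketch is to compute one row of coefficients per group and then join it back onto the original rows, which is closer to the SQL window-function idea. The helper name fit_group and the variables coefs and result are made up for this illustration; y is left 1D here, so coef_ is 1D and intercept_ is a scalar.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

def fit_group(group):
    # Hypothetical helper: fit one weighted regression and return two scalars
    model = LinearRegression().fit(
        group[['x']].to_numpy(),
        group['y'].to_numpy(),
        sample_weight=group['w'].to_numpy(),
    )
    return pd.Series({'a': model.coef_[0], 'b': model.intercept_})

df = pd.DataFrame({
    'g': ['a', 'a', 'a', 'b', 'b', 'b'],
    'x': [1, 2, 3, 2, 4, 6],
    'y': [1, 2, 3, 4, 8, 12],
    'w': [1, 1, 1, 1, 1, 1],
})

# One row per group, indexed by g, with columns a and b
coefs = df.groupby('g')[['x', 'y', 'w']].apply(fit_group)

# Broadcast the per-group values back to every original row
result = df.join(coefs, on='g')
print(result)
```

The original row order and index are preserved, and the per-group table coefs stays available if you also want the aggregated view.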

Answered By: Mattravel