ValueError: bad input shape (1, 4) in sklearn.naive_bayes.GaussianNB

Question:

I have started learning machine learning and am currently working on Naive Bayes.

My Python script:

import numpy as np
x = np.array([[0,0],[1,1],[0,1],[1,0]])
y = np.array([0,0,1,1])
print(x)
from sklearn.naive_bayes import GaussianNB
clf = GaussianNB()
x = x.reshape(1,-1)
y = y.reshape(1,-1)
clf.fit(x,y)
a = clf.predict([[1,1]])
print(a)

Error

The error is:

[[0 0]
[1 1]
[0 1]
[1 0]]
Traceback (most recent call last):
  File "ex.py", line 9, in <module>
    clf.fit(x,y)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/naive_bayes.py", line 182, in fit
    X, y = check_X_y(X, y)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.py", line 526, in check_X_y
    y = column_or_1d(y, warn=True)
  File "/usr/local/lib/python2.7/dist-packages/sklearn/utils/validation.py", line 562, in column_or_1d
    raise ValueError("bad input shape {0}".format(shape))
ValueError: bad input shape (1, 4)

What should I do?

Asked By: Adit Srivastava

||

Answers:

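As I said in the comments, there is no need to reshape. x already has shape (4, 2) (4 samples, 2 features) and y has shape (4,) (one label per sample), which is exactly what fit expects. The reshape turns y into a (1, 4) row, and fit rejects that with this error. As discussed in the comments, your code was working once the reshape was removed.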
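For reference, here is a sketch of the question's script with the two reshape lines removed. The only additions are the print calls showing the shapes fit expects, and the shapes the reshape produced instead.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

# Same data as in the question: 4 samples, 2 features each
x = np.array([[0, 0], [1, 1], [0, 1], [1, 0]])
y = np.array([0, 0, 1, 1])

# No reshape: X must be (n_samples, n_features) and y must be (n_samples,)
print(x.shape, y.shape)  # (4, 2) (4,)

# What the reshape did: 4 samples became 1 sample with 8 features,
# and the labels became a (1, 4) row instead of a 1-D array
print(x.reshape(1, -1).shape, y.reshape(1, -1).shape)  # (1, 8) (1, 4)

clf = GaussianNB()
clf.fit(x, y)

# predict also expects a 2-D input: one row per sample to classify
a = clf.predict([[1, 1]])
print(a)
```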

As for always getting 0 as the prediction, that is due to your data. Naive Bayes needs more samples to separate the classes; two samples per class for a non-linear problem (your labels follow the XOR pattern of the two features) is not sufficient.
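One way to see what the model is working with here (a sketch using plain NumPy, not scikit-learn internals): GaussianNB describes each feature within each class by a mean and a variance only. For this XOR-style data, both classes end up with exactly the same means and variances, so the model has nothing to tell them apart. Every sample then gets equal scores for both classes, and predict (which takes an argmax) falls back to the first class, 0.

```python
import numpy as np

x = np.array([[0, 0], [1, 1], [0, 1], [1, 0]])
y = np.array([0, 0, 1, 1])

# Per-class, per-feature mean and variance: the only statistics a
# Gaussian Naive Bayes model keeps about each feature
stats = {}
for c in np.unique(y):
    xc = x[y == c]
    stats[int(c)] = (xc.mean(axis=0), xc.var(axis=0))
    print(c, stats[int(c)][0], stats[int(c)][1])
```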

import numpy as np
from sklearn.naive_bayes import GaussianNB

def GNB(x,y):
    # Fit a Gaussian Naive Bayes model and print its predictions on the training data itself
    clf = GaussianNB()
    clf.fit(x,y)
    a = clf.predict(x)
    print(a)

x = np.array([[0,0],[1,1],[0,1],[1,0]])
y = np.array([0,0,1,1])
GNB(x,y)
# Output: [0 0 0 0]

x = np.array([[0,0],[0,1],[1,1],[1,0],[3,4],[-2,2],[-3,2],[-4,-2]])
y = np.array([0,0,0,0,1,1,1,1])
GNB(x,y)
# Output: [0 0 0 0 1 1 1 1]

Look at my two examples. In the first case (the one you provided), NB does not manage to separate the data. In the second example, NB returns the right classes because there is a sufficient number of samples.
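The difference between the two examples also shows up in the model's own confidence. Here is a sketch that fits GaussianNB to each dataset and prints predict_proba on the training samples. In the first case, every row comes out as roughly [0.5, 0.5], meaning a tie. In the second case, the most probable class matches the label for every sample.

```python
import numpy as np
from sklearn.naive_bayes import GaussianNB

# First example (the question's data)
x1 = np.array([[0, 0], [1, 1], [0, 1], [1, 0]])
y1 = np.array([0, 0, 1, 1])

# Second example (more samples per class)
x2 = np.array([[0, 0], [0, 1], [1, 1], [1, 0], [3, 4], [-2, 2], [-3, 2], [-4, -2]])
y2 = np.array([0, 0, 0, 0, 1, 1, 1, 1])

probas = []
for x, y in [(x1, y1), (x2, y2)]:
    clf = GaussianNB().fit(x, y)
    # One row per training sample: [P(class 0), P(class 1)]
    p = clf.predict_proba(x)
    probas.append(p)
    print(p.round(3))
```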

I created a function for clarity, but you can simply add more samples to your example and you will see that it works.
Hope this helps and solves your problem.

Answered By: Nuageux