What I am doing wrong in class definition that my fit is that bad when plotted. I know that it has to be good enough without adjusting number of estimators or decision trees’ depth..
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad
from sklearn.tree import DecisionTreeRegressor
class GradientBoosting:
def __init__(self, n_estimators=10, learning_rate=1., weak_learner=None, ensemble_loss=None):
self.n_estimators = n_estimators
self.learning_rate = learning_rate
self.models = []
self.weak_learner = weak_learner if weak_learner is not None else lambda: DecisionTreeRegressor(max_depth=3)
self.ensemble_loss = ensemble_loss if ensemble_loss is not None else lambda y_pred, y_true: jnp.mean((y_pred - y_true) ** 2)
self.running_pred_sum = None
def fit(self, X, y):
initial_model = self.weak_learner()
initial_model.fit(X, y)
self.models.append(initial_model)
self.running_pred_sum = jnp.array(initial_model.predict(X))
for _ in range(1, self.n_estimators):
y_pred = self.running_pred_sum
value, gradients = value_and_grad(self.ensemble_loss)(y_pred, y)
#print(value)
tree = self.weak_learner()
tree.fit(X, -gradients)
self.models.append(tree)
self.running_pred_sum += self.learning_rate * jnp.array(tree.predict(X))
def predict(self, X):
y_pred = sum(self.learning_rate * jnp.array(model.predict(X)) for model in self.models)
return y_pred
import matplotlib.pyplot as plt
def generate_sine_data():
rng_key = jax.random.PRNGKey(1)
X = jnp.linspace(0, 6, 100).reshape(-1, 1)
y = jnp.sin(X).ravel() + jnp.sin(6 * X).ravel() + jax.random.normal(rng_key, X.shape).ravel() * 0.1
return X, y
X, y = generate_sine_data()
gb = GradientBoosting(n_estimators=25, learning_rate=1)
gb.fit(X, y)
X_test = jnp.linspace(X.min(), X.max(), 100).reshape(-1, 1)
predictions = gb.predict(X_test)
plt.figure(figsize=(10, 6))
plt.scatter(X, y, color='blue', label='Points de données')
plt.plot(X_test, predictions, color='red', label='Ajustement Gradient Boosting')
plt.title('Ajustement de la régression Gradient Boosting')
plt.xlabel('Caractéristique')
plt.ylabel('Cible')
plt.legend()
plt.show()
Thanks!
I tried to do it step by step instead of defining class.. The print(value) output for Loss values should be as follows:
5.240905
5.010452
4.757929
4.438871
4.3248005
4.0045547
3.8759923
3.5886106
3.252979
3.0276678
2.876125
2.5988436
2.4220018
2.210686
2.1285996
2.0060892
1.8828202
1.7201844
1.6591048
1.6228759
1.4467268
1.3575808
1.315804
1.2892128
But I am getting them all under 1..