mirror of
https://forge.apps.education.fr/phroy/mes-scripts-de-ml.git
synced 2024-01-27 11:30:36 +01:00
Keras : regression
This commit is contained in:
parent
abbc09de30
commit
abe646de60
@ -68,9 +68,18 @@ t_debut = time.time()
|
||||
# Init des plots
|
||||
fig = plt.figure(figsize=(15, 5))
|
||||
fig.suptitle("Réseaux de neurones avec Keras - Regression")
|
||||
model_ax = fig.add_subplot(131) # Modèle
|
||||
apts_ax = fig.add_subplot(132) # Courbes d'apprentissage
|
||||
donnees_ax = fig.add_subplot(133) # Observations : x1,x2 et cibles : y
|
||||
model_ax = fig.add_subplot(121) # Modèle
|
||||
apts_ax = fig.add_subplot(122) # Courbes d'apprentissage
|
||||
# donnees_ax = fig.add_subplot(133) # Observations : x1,x2 et cibles : y
|
||||
|
||||
# Logs
|
||||
root_logdir = os.path.join(os.curdir, "keras_logs")
|
||||
|
||||
def get_run_logdir():
|
||||
run_id = time.strftime("run_%Y_%m_%d-%H_%M_%S")
|
||||
return os.path.join(root_logdir, run_id)
|
||||
|
||||
run_logdir = get_run_logdir()
|
||||
|
||||
###############################################################################
|
||||
# Observations
|
||||
@ -78,8 +87,10 @@ donnees_ax = fig.add_subplot(133) # Observations : x1,x2 et cibles : y
|
||||
|
||||
# Observations d'apprentissage, de validation et de test
|
||||
housing = sklearn.datasets.fetch_california_housing() # Jeu de données California housing
|
||||
X, X_test, y, y_test = sklearn.model_selection.train_test_split(housing.data, housing.target, random_state=42)
|
||||
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=42)
|
||||
X, X_test, y, y_test = sklearn.model_selection.train_test_split(housing.data, housing.target)
|
||||
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
|
||||
# X, X_test, y, y_test = sklearn.model_selection.train_test_split(housing.data, housing.target, random_state=42)
|
||||
# X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=42)
|
||||
|
||||
# Normalisation
|
||||
scaler = sklearn.preprocessing.StandardScaler()
|
||||
@ -101,16 +112,17 @@ perte="mse" # Type de perte (hyperparamètre)
|
||||
# perte="sparse_categorical_crossentropy"
|
||||
|
||||
keras.backend.clear_session()
|
||||
np.random.seed(42)
|
||||
tf.random.set_seed(42)
|
||||
# np.random.seed(42)
|
||||
# tf.random.set_seed(42)
|
||||
model = keras.models.Sequential() # Modèle de reseau de neurones
|
||||
model.add(keras.layers.Dense(30, input_shape=X_train.shape[1:], activation="relu")) # Couche 1 : 30 nodes
|
||||
model.add(keras.layers.Dense(1)) # Couche de sortie : 1 node par classe
|
||||
|
||||
optimiseur=keras.optimizers.SGD(learning_rate= eta)
|
||||
model.compile(loss=perte, optimizer=optimiseur) # Compilation du modèle
|
||||
|
||||
apts = model.fit(X_train, y_train, epochs=n, batch_size=lot, validation_data=(X_valid, y_valid)) # Entrainement
|
||||
checkpoint_cb = keras.callbacks.ModelCheckpoint("my_keras_model.h5")
|
||||
tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)
|
||||
apts = model.fit(X_train, y_train, epochs=n, batch_size=lot, validation_data=(X_valid, y_valid), callbacks=[checkpoint_cb, tensorboard_cb]) # Entrainement
|
||||
|
||||
###############################################################################
|
||||
# Phase d'inférence
|
||||
@ -139,7 +151,8 @@ apts_ax.set(ylim=(-0.05, 1.05))
|
||||
apts_ax.set_xlabel("Époque")
|
||||
apts_ax.legend()
|
||||
|
||||
# # Plot des données
|
||||
# Plot des données
|
||||
# FIXME : mettre des graphiques de prédiction
|
||||
# donnees_ax.set_title("Données")
|
||||
# plot_i=[]
|
||||
# plot_x1=[]
|
||||
|
@ -23,3 +23,8 @@
|
||||
### Réseaux de neurones avec Keras - Classificateur : Points sur spirales
|
||||
|
||||
![capture d'écran](img/06-keras-tf_playground-spiral-v2.png)
|
||||
|
||||
### Réseaux de neurones avec Keras - Regression
|
||||
|
||||
![capture d'écran](img/07-keras-regression.png)
|
||||
|
||||
|
BIN
02-intro_rna/img/07-keras-regression.png
Normal file
BIN
02-intro_rna/img/07-keras-regression.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 356 KiB |
Loading…
Reference in New Issue
Block a user