*** Wartungsfenster jeden ersten Mittwoch vormittag im Monat ***

Skip to content
Snippets Groups Projects
Commit 784522c4 authored by Harrison, Simeon's avatar Harrison, Simeon
Browse files

Added notebook TensorFlow distributed for multi-node, multi-GPU setup.

parent e2d342d4
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
%%writefile ./tf_distr_slurm.sh
#!/bin/bash
#SBATCH --job-name=tf_distr_example
#SBATCH --account=p70824 # training account, please uncomment for training
#SBATCH --nodes=2 # Number of nodes
#SBATCH --ntasks-per-node=1 # Number of tasks per node
#SBATCH --cpus-per-task=256 # Number of CPU cores per task (including hyperthreading if needed)
#SBATCH --partition=zen3_0512_a100x2
#SBATCH --qos=zen3_0512_a100x2 # qos for training
#SBATCH --gres=gpu:2 # Number of GPUs per node
#SBATCH --output=./output/%x-%j.out # Output file
#SBATCH --time=00:10:00
######################
### Set Environment ###
######################
module load miniconda3
eval "$(conda shell.bash hook)"
source /opt/sw/jupyterhub/envs/conda/vsc5/jupyterhub-huggingface-v2/modules # Activate the conda environment
#source /opt/sw/jupyterhub/envs/conda/vsc5/jupyterhub-llm-training-v3
######################
#### Set Network #####
######################
# Get the IP address of the master node (head node)
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
nodes_array=($nodes)
export NODE_0=${nodes_array[0]}
export NODE_1=${nodes_array[1]}
export MASTER_PORT=29500
NUM_PROCESSES=$(( SLURM_NNODES * SLURM_GPUS_ON_NODE ))
######################
### Launch Training ###
######################
srun python3 tf_distr.py
```
%% Output
Overwriting ./tf_distr_slurm.sh
%% Cell type:code id: tags:
``` python
%%writefile ./tf_distr.py
import os
import json
import numpy as np
import tensorflow as tf
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# Dynamically set TF_CONFIG for distributed TensorFlow
def set_tf_config():
# Retrieve master and worker nodes from environment variables
node_0 = os.environ['NODE_0']
node_1 = os.environ['NODE_1']
task_id = int(os.environ['SLURM_PROCID']) # SLURM task ID determines worker index
# Create a list of workers
worker_hosts = [f"{node_0}:29500", f"{node_1}:29500"]
# Construct TF_CONFIG
tf_config = {
"cluster": {"worker": worker_hosts},
"task": {"type": "worker", "index": task_id}
}
os.environ['TF_CONFIG'] = json.dumps(tf_config)
print("TF_CONFIG set to:", json.dumps(tf_config, indent=4))
# Call the TF_CONFIG setup function
set_tf_config()
# Set up distributed strategy
strategy = tf.distribute.MultiWorkerMirroredStrategy()
# Load and preprocess the MNIST dataset
mnist = fetch_openml('mnist_784', as_frame=False)
X, y = mnist.data, mnist.target
# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
X_train = X_train.reshape(-1, 28, 28) / 255.0
X_test = X_test.reshape(-1, 28, 28) / 255.0
y_train = np.array(y_train, dtype="int32")
y_test = np.array(y_test, dtype="int32")
# Create a distributed dataset
batch_size = 64
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(10000).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(batch_size)
# Define the model within the strategy's scope
with strategy.scope():
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation="softmax")
])
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy']
)
# Train the model
model.fit(train_dataset, epochs=20, validation_data=test_dataset)
# Evaluate the model
y_pred = model.predict(X_test, batch_size=batch_size)
y_preds = tf.argmax(y_pred, axis=1)
# Print predictions and ground truth
print("Predictions:", y_preds.numpy())
print("Ground Truth:", y_test)
# Plot confusion matrix
ConfusionMatrixDisplay.from_predictions(y_test, y_preds)
plt.show()
```
%% Output
Overwriting ./tf_distr.py
%% Cell type:code id: tags:
``` python
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment