# Source code for neurio.tasks.classification.shdClassification
#!/usr/bin/env python
"""
Author: Romain Gaulier
Email: romain.gaulier@csem.ch
Copyright: CSEM, 2023
Creation: 26.04.23
Description: Classification task pipeline for the Spiking Heidelberg Digits (SHD) dataset
"""
import numpy as np
import tensorflow as tf
from neurio.tasks.task import Task
class SHDClassification(Task):
    """Classification task on the Spiking Heidelberg Digits (SHD) dataset.

    On construction the train and test splits are loaded and converted into:

    * ``self.x`` / ``self.y`` -- training inputs (float32) and one-hot labels,
    * ``self.x_test`` / ``self.y_test`` -- the same for the test split,
    * ``self.train_batches`` / ``self.test_batches`` -- batched ``tf.data``
      pipelines over those arrays (the training pipeline is shuffled).

    All constructor arguments are keyword-only and optional; the defaults
    reproduce the original hard-coded behaviour, except that ``num_classes``
    is now 20.

    Args:
        train_ds: Optional pre-loaded training split. Must support ``len()``
            and integer indexing that returns ``(input, label)`` pairs. When
            omitted, ``spikedata.SHD(root, train=True)`` is used.
        test_ds: Optional pre-loaded test split, with the same protocol as
            ``train_ds``. When omitted, ``spikedata.SHD(root, train=False)``
            is used.
        root: Directory passed to ``spikedata.SHD`` for any split that is not
            supplied.
        batch_size: Batch size of the ``tf.data`` pipelines.
        num_classes: Number of classes used for one-hot encoding. SHD has 20
            (spoken digits 0-9 in English and German); the previous
            placeholder value of 10 was wrong.
    """

    def __init__(self, *, train_ds=None, test_ds=None, root="dataset/shd",
                 batch_size=64, num_classes=20):
        super().__init__()
        self.x = None
        self.y = None
        self.x_test = None
        self.y_test = None
        self.train_batches = None
        self.test_batches = None
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.metric = tf.keras.metrics.CategoricalCrossentropy()
        # NOTE(review): `spikedata` is not imported in this module, so the
        # default loaders below raise NameError unless it is brought into
        # scope (presumably `from snntorch.spikevision import spikedata` --
        # confirm). Passing `train_ds` / `test_ds` explicitly bypasses them.
        if train_ds is None:
            train_ds = spikedata.SHD(root, train=True)
        if test_ds is None:
            test_ds = spikedata.SHD(root, train=False)
        self.train_ds = train_ds
        self.test_ds = test_ds
        self.prepare_data()

    def prepare_data(self):
        """Convert both splits to NumPy arrays and build batched tf.data pipelines.

        Populates ``self.x``, ``self.y``, ``self.x_test``, ``self.y_test``,
        ``self.train_batches`` and ``self.test_batches``. Labels are one-hot
        encoded with ``self.num_classes`` columns.

        Raises:
            ValueError: if a split is empty, its inputs cannot be stacked into
                a single array, or a label lies outside ``[0, num_classes)``.
        """
        self.x, self.y = self._to_arrays(self.train_ds, self.num_classes)
        self.x_test, self.y_test = self._to_arrays(self.test_ds, self.num_classes)
        # Inputs are used as-is. They are presumably spike data rather than
        # 8-bit images, so the former "/ 255.0" image normalisation is dropped.
        self.train_batches = (
            tf.data.Dataset.from_tensor_slices((self.x, self.y))
            .shuffle(len(self.x))  # full-buffer shuffle, as before
            .batch(self.batch_size)
        )
        self.test_batches = (
            tf.data.Dataset.from_tensor_slices((self.x_test, self.y_test))
            .batch(self.batch_size)
        )

    @staticmethod
    def _to_arrays(dataset, num_classes):
        """Return ``(inputs, one_hot_labels)`` float32 arrays for every sample of *dataset*."""
        # Fetch each sample once; the old code indexed every sample twice.
        samples = [dataset[i] for i in range(len(dataset))]
        if not samples:
            raise ValueError("SHD split is empty")
        # NOTE(review): assumes every input converts to a dense array of the
        # same shape (e.g. fixed time bins x channels) -- confirm for the loader.
        # np.stack raises ValueError on mismatched shapes.
        inputs = np.stack(
            [np.asarray(sample[0], dtype=np.float32) for sample in samples]
        )
        # NOTE(review): assumes labels are integer class indices (or 0-d
        # tensors), not already one-hot vectors -- confirm.
        labels = np.array([int(sample[1]) for sample in samples], dtype=np.int64)
        # Validate explicitly: np.eye indexing would wrap negative labels
        # silently and raise an opaque IndexError for labels that are too large.
        if labels.min() < 0 or labels.max() >= num_classes:
            raise ValueError(
                f"labels must lie in [0, {num_classes}); "
                f"got range [{labels.min()}, {labels.max()}]"
            )
        one_hot = np.eye(num_classes, dtype=np.float32)[labels]
        return inputs, one_hot

    def evaluate(self, y_train, y_pred):
        """Return the categorical cross-entropy of *y_pred* against *y_train*.

        Args:
            y_train: Ground-truth one-hot labels. The parameter name is kept
                for backward compatibility.
            y_pred: Predicted class probabilities with the same shape.

        Returns:
            The scalar metric value as a NumPy value.
        """
        # Reset first so the result reflects only this call's inputs, not a
        # running mean over every previous evaluate() call.
        self.metric.reset_state()
        self.metric.update_state(y_train, y_pred, sample_weight=None)
        return self.metric.result().numpy()